Pytorch:卷积神经网络简单案例
发布日期:2025-06-18 23:43:29
浏览次数:3
分类:精选文章
本文共 3616 字,大约阅读时间需要 12 分钟。
PyTorch与torchvision的结合应用:CIFAR-10数据集的训练与测试
本文将介绍如何利用PyTorch和torchvision实现CIFAR-10数据集的高效训练与测试,并对模型的性能进行详细分析。
数据集准备与预处理
首先,我们需要准备CIFAR-10数据集。通过torchvision,我们可以直接下载并加载数据集。为了确保模型的泛化能力,需要对图像数据进行标准化处理。具体来说,我们采用如下预处理流程:
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
训练集和测试集的加载
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
模型设计与训练
本文设计了一个简单的卷积神经网络(CNN)作为模型架构:
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.pool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
损失函数与优化器的选择
采用交叉熵损失函数作为模型训练的目标函数,优化器选择随机梯度下降(SGD)算法:
criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
模型的训练过程
for epoch in range(10): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if (i + 1) % 2000 == 0: print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0print("Finished Training") 模型的测试与评估
在测试阶段,我们首先加载测试数据集,并对模型的预测结果进行分析:
with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total)) 分类级别的准确率分析
为了更细致地了解模型在不同类别上的表现,我们对每个类别的准确率进行了统计:
class_correct = list(0. for _ in range(10))class_total = list(0. for _ in range(10))with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) c = (predicted == labels) for i in range(4): class_correct[labels[i]] += c[i].item() class_total[labels[i]] += 1for i in range(10): print('Accuracy of %s : %2d%%' % (classes[i], 100 * class_correct[i] / class_total[i])) GPU加速训练
为了充分发挥PyTorch的优势,我们可以将训练过程迁移到GPU上进行加速:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")net.to(device)# 在训练过程中:for epoch in range(10): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if (i + 1) % 2000 == 0: print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0print("Finished Training") 以上是本文的完整实现过程及结果分析,涵盖了从数据集准备到模型训练与测试的全过程,并对模型的性能进行了详细评估。
发表评论
最新留言
第一次来,支持一个
[***.219.124.196]2026年05月28日 16时56分16秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
PHP版本升级5.4手记
2023-03-01
php版本升级总结
2023-03-01
php版本微信公众号开发
2023-03-01
php版的微信公众号开发演示
2023-03-01
php生成html文件的多种方法介绍
2023-03-01
php生成二维码到图片上
2023-03-01
php生成二维码并下载图片(适应于框架)
2023-03-01
PHP生成及获取JSON文件的方法
2023-03-01
PHP生成唯一不重复的编号
2023-03-01
PHP生成器-动态生成内容的数组
2023-03-01
PHP的ip2long和long2ip升级函数
2023-03-01
php的web路径获取
2023-03-01
php的一些小笔记--字符串
2023-03-01
php的几种运行模式CLI、CGI、FastCGI、mod_php
2023-03-01
php的四大特性八大优势
2023-03-01
RabbitMQ
2023-03-01
PHP的威胁函数与PHP代码审计实战
2023-03-01
PHP的引用举例
2023-03-01