PyTorch Image Classifier

Author: 红灯区

Building an image classifier is a common deep learning task, and PyTorch is a library well suited to it. Below is a step-by-step guide to a simple PyTorch image classifier:

Environment setup

First, install PyTorch and the other required libraries. You can install them as follows:

    pip install torch torchvision matplotlib

Steps

  1. Import the libraries

        import torch
        import torch.nn as nn
        import torch.nn.functional as F  # provides F.relu, used in the network's forward()
        import torch.optim as optim
        import torchvision
        import torchvision.transforms as transforms
        import matplotlib.pyplot as plt
        import numpy as np

  2. Prepare the dataset

    Use torchvision to load and preprocess the CIFAR-10 dataset.

        # Convert images to tensors, then scale each channel from [0, 1] to [-1, 1]
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        batch_size = 4

        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                               download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                 shuffle=False, num_workers=2)

        classes = ('plane', 'car', 'bird', 'cat',
                   'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

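To make the preprocessing concrete, here is a minimal sketch of the per-channel arithmetic that `transforms.Normalize` applies, `(x - mean) / std`, written in plain Python. The helper names `normalize` and `unnormalize` are illustrative and are not part of torchvision.

```python
# Per-channel arithmetic applied by transforms.Normalize(mean, std): (x - mean) / std
MEAN, STD = 0.5, 0.5  # the values this article uses for every channel


def normalize(x, mean=MEAN, std=STD):
    """Map a ToTensor() pixel value in [0, 1] to the normalized range."""
    return (x - mean) / std


def unnormalize(y, mean=MEAN, std=STD):
    """Invert normalize(); needed before displaying a normalized image."""
    return y * std + mean


for pixel in (0.0, 0.25, 1.0):
    print(pixel, '->', normalize(pixel))
```

With mean 0.5 and std 0.5, pixel values in [0, 1] land in [-1, 1]. The inverse matters when plotting, because `plt.imshow` expects float images in [0, 1].
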
  3. Build the neural network

    Create a simple convolutional neural network.

        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.conv1 = nn.Conv2d(3, 6, 5)    # 3 input channels, 6 output channels, 5x5 kernel
                self.pool = nn.MaxPool2d(2, 2)     # 2x2 max pooling with stride 2
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16 * 5 * 5, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)       # one output score per CIFAR-10 class

            def forward(self, x):
                x = self.pool(F.relu(self.conv1(x)))
                x = self.pool(F.relu(self.conv2(x)))
                x = x.view(-1, 16 * 5 * 5)         # flatten each sample to a vector
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return x

        net = Net()

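The `16 * 5 * 5` input size of `fc1` follows from the spatial size of the feature map after two conv + pool stages. The sketch below traces that arithmetic in plain Python, assuming the default stride of 1 and no padding for `nn.Conv2d`, as in the network above. The helpers `conv_out` and `pool_out` are illustrative names.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride):
    """Output width/height of max pooling along one spatial dimension."""
    return (size - kernel) // stride + 1


size = 32                     # CIFAR-10 images are 32x32
size = conv_out(size, 5)      # conv1, 5x5 kernel: 32 -> 28
size = pool_out(size, 2, 2)   # pool:              28 -> 14
size = conv_out(size, 5)      # conv2, 5x5 kernel: 14 -> 10
size = pool_out(size, 2, 2)   # pool:              10 -> 5

flat_features = 16 * size * size  # 16 channels from conv2
print(size, flat_features)
```

If you change a kernel size, add padding, or feed images of a different resolution, rerun this calculation and update `fc1` and the `view` call to match.
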
  4. Define the loss function and optimizer

        criterion = nn.CrossEntropyLoss()  # expects raw scores (logits) and integer class labels
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

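As a rough illustration of what these two objects compute, the plain-Python sketch below evaluates cross-entropy for a single sample and performs one SGD-with-momentum update on a single scalar parameter. It is a simplified model and not PyTorch's implementation: `nn.CrossEntropyLoss` additionally averages over the batch by default, and the update shown assumes no dampening, weight decay, or Nesterov momentum. The function names are illustrative.

```python
import math


def cross_entropy(logits, target):
    """-log(softmax(logits)[target]) for one sample, computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def sgd_momentum_step(param, grad, velocity, lr=0.001, momentum=0.9):
    """One update of a scalar parameter; returns (new_param, new_velocity)."""
    velocity = momentum * velocity + grad
    return param - lr * velocity, velocity


# A network that scores all 10 classes equally has a loss of ln(10)
uniform_loss = cross_entropy([0.0] * 10, target=3)
print(f'{uniform_loss:.3f}')

param, velocity = sgd_momentum_step(1.0, grad=2.0, velocity=0.0)
print(param, velocity)
```

The uniform case is a useful sanity check: an untrained 10-class classifier should start with a loss near ln(10), about 2.3, and the printed training loss should fall from there.
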
  5. Train the network

        for epoch in range(2):  # loop over the dataset multiple times
            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                    running_loss = 0.0

        print('Finished Training')

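Before committing to a full CIFAR-10 run, it can help to check the zero_grad / forward / backward / step pattern on data small enough to train in well under a second. The sketch below does that with made-up synthetic data and a single linear layer. The data, the model, and the learning rate of 0.1 are assumptions chosen for the demo and are not taken from the article.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-in data: 64 samples, 8 features, 2 classes.
# The label depends only on the sign of the first feature.
inputs = torch.randn(64, 8)
labels = (inputs[:, 0] > 0).long()

model = nn.Linear(8, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

losses = []
for step in range(50):
    optimizer.zero_grad()                    # clear gradients from the previous step
    loss = criterion(model(inputs), labels)  # forward pass
    loss.backward()                          # backward pass
    optimizer.step()                         # parameter update
    losses.append(loss.item())

print(f'first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}')
```

If the last loss is not lower than the first on a problem this easy, the loop itself is likely wrong, for example a missing `optimizer.step()`.
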
  6. Test the network

        dataiter = iter(testloader)
        images, labels = next(dataiter)

        # show the images; undo Normalize (x * 0.5 + 0.5) so pixel values return to [0, 1]
        grid = torchvision.utils.make_grid(images) / 2 + 0.5
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
        print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

        # the class with the highest score is the prediction
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(batch_size)))

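A single batch of four images says little about overall quality. One way to measure accuracy over a whole loader is sketched below. The helper `accuracy` is an illustrative name, and the toy loader and `nn.Identity` model exist only to make the example self-contained. With the article's objects the call would be `accuracy(net, testloader)`.

```python
import torch
import torch.nn as nn


def accuracy(model, loader):
    """Fraction of samples in `loader` whose top-scoring class equals the label."""
    model.eval()                  # switch layers such as dropout to inference mode
    correct = 0
    total = 0
    with torch.no_grad():         # gradients are not needed for evaluation
        for inputs, labels in loader:
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total


# Toy stand-in: the "model" passes pre-computed scores straight through
toy_loader = [
    (torch.tensor([[2.0, 1.0], [0.0, 3.0]]), torch.tensor([0, 1])),
    (torch.tensor([[1.0, 5.0]]), torch.tensor([0])),
]
toy_accuracy = accuracy(nn.Identity(), toy_loader)
print(f'{toy_accuracy:.3f}')
```

Because the helper calls `model.eval()`, call `model.train()` again before resuming training if the model contains dropout or batch normalization layers.
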
  7. Save the model

    Save the model's weights so they can be loaded and deployed later.

        PATH = './cifar_net.pth'
        torch.save(net.state_dict(), PATH)
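To use the saved weights later, create a fresh instance of the same model class and load the state dict into it. The sketch below shows the round trip with a small hypothetical stand-in model, `TinyNet`, and a temporary file so it runs on its own. With the article's objects the equivalent would be `net = Net()` followed by `net.load_state_dict(torch.load(PATH))`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the article's Net, kept small for the demo."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


trained = TinyNet()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'tiny_net.pth')
    torch.save(trained.state_dict(), path)      # save the weights only

    restored = TinyNet()                        # same architecture, fresh weights
    restored.load_state_dict(torch.load(path))  # overwrite them with the saved ones
    restored.eval()

x = torch.randn(3, 4)
same_output = torch.equal(trained(x), restored(x))
print(same_output)
```

A state dict stores parameters but not the class definition, so the code that defines the model must be available wherever the weights are loaded.
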

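The code above covers a basic image classification workflow. Depending on your needs, you can adjust the network architecture, the optimizer, and hyperparameters such as the learning rate to improve model performance. I hope this example is helpful!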
