发布时间:2024-09-03 23:04:47
本内容由, 集智数据集收集发布,仅供参考学习,不代表集智官方赞同其观点或证实其内容的真实性,请勿用于商业用途。
LeNet是最早的卷积神经网络之一,最初由YannLeCun在1998年设计,用于手写数字识别。LeNet的原始版本主要用于识别邮政编码中的手写数字,但它也是现代卷积神经网络架构的基础之一。 下面是一个使用PyTorch实现的简化版LeNet网络的例子。我们将使用MNIST数据集作为示例,因为LeNet最初是为此类任务设计的。MNIST数据集包含28x28像素的手写数字图片。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 定义LeNet模型
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5) # 输入通道数为1,输出通道数为6,卷积核大小为5x5
self.pool = nn.MaxPool2d(2, 2) # 最大池化层,窗口大小为2x2
self.conv2 = nn.Conv2d(6, 16, 5) # 输入通道数为6,输出通道数为16
self.fc1 = nn.Linear(16 * 4 * 4, 120) # 全连接层
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) # 输出类别数为10
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 初始化模型
model = LeNet()
# 设置设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# 加载数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
trainset = datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=100,
shuffle=True, num_workers=2)
testset = datasets.MNIST(root='./data', train=False,
download=True, transform=transform)
testloader = DataLoader(testset, batch_size=100,
shuffle=False, num_workers=2)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch {epoch + 1}, Loss: {running_loss / (i + 1)}')
print('Finished Training')
请注意,上面的代码中使用了 transforms.Normalize
对输入数据进行了标准化,这对于提高模型的训练效果是非常重要的。此外,代码中还包括了一个简单的训练循环,它会打印出每个 epoch 的平均损失值。
这个模型非常适合于手写数字识别任务,并且可以在 MNIST 数据集上运行得非常好。如果你想要在更复杂的图像数据集上使用类似的设计,可能需要对网络结构做一些调整,比如增加更多的卷积层或全连接层。
本站将定期更新分享一些python机器学习的精选代码