代码拉取完成,页面将自动刷新
# 模型库
import torch
import torchvision
# 进度条显示
from tqdm import tqdm
# 绘图库
import matplotlib.pyplot as plt
# 命令行参数获取
import utils.parameters
# 网络结构
from models.net import Net
if __name__ == "__main__":
# 如果网络能在GPU中训练,就使用GPU;否则使用CPU进行训练
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# 定义 图像预处理器 对象
transform = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(), # 将图片转换为张量(多维数组,存储浮点数)
torchvision.transforms.Normalize(mean=[0.5], std=[0.5]), # 归一化处理
]
)
# 读取命令行参数
params_parser = utils.parameters.get_train_args()
# 设置批大小和训练轮数
# 数据集的图片按照批大小分批,训练完所有批次后,算为一轮
BATCH_SIZE = params_parser.batch_size
EPOCHS = params_parser.epochs
# 是否只是下载数据集
IS_DOWNLOAD_DATASET = params_parser.download_dataset
# 加载训练和测试数据
train_dataset = torchvision.datasets.MNIST(
"./data/", train=True, transform=transform, download=IS_DOWNLOAD_DATASET
)
test_dataset = torchvision.datasets.MNIST(
"./data/", train=False, transform=transform, download=IS_DOWNLOAD_DATASET
)
if IS_DOWNLOAD_DATASET:
print("数据集下载完成!")
exit()
# 建立数据迭代器,随机加载
# 装载训练集对象
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
# 装载测试集
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True
)
# DataLoader容器为封装后的线性表,当成链表用即可
# 每个元素是一个元组,包含一个批次的数据和标签,以张量形式存储
# 共ceil(len/BATCH_SIZE)个元组,每个元组包含一个批次的数据和标签
# 元素的数据结构如下:
# ( 图像张量, 标签张量 )
# 图像张量数据结构:
# double[批大小][通道数][图片高度][图片宽度]
# 标签张量数据结构:
# double[批大小]
# 是否只是查看数据的形状
if params_parser.show_datashape:
# 获取第一个批次的数据及其对应的索引。
# `example_data` 是图像张量,`example_targets` 是对应的标签
(example_data, example_targets) = next(iter(train_loader))
# 创建一个新的图形对象用于绘图
fig = plt.figure()
# 使用 for 循环绘制当前批次的前 6 个图像
for i in range(6):
# 在 2x3 的网格中创建子图,并指定当前子图的位置
plt.subplot(2, 3, i + 1)
# 调整子图之间的间距以避免重叠
plt.tight_layout()
# 显示图像,使用灰度颜色映射并且不使用插值
plt.imshow(example_data[i][0], cmap="gray", interpolation="none")
# 设置子图标题为该样本的真实标签
plt.title("Label: {}".format(example_targets[i]))
# 隐藏 x 轴和 y 轴的刻度
plt.xticks([])
plt.yticks([])
# 显示所有绘制的图像。
plt.show()
# 打印出当前批次数据的形状,了解数据的维度信息。
print("数据形状(批次大小, 通道数, 图片高度, 图片宽度):", example_data.shape)
exit()
# 构建模型实例
net = Net()
# 设置网络实例在设备上运行
net = net.to(device)
# 交叉熵损失函数
loss_fun = torch.nn.CrossEntropyLoss()
# Adam优化器
optimizer = torch.optim.Adam(net.parameters())
# 训练网络
# 记录训练过程中的损失和准确率
history = {"Test Loss": [], "Test Accuracy": []}
# 训练循环
for epoch in range(1, EPOCHS + 1):
# 进度条对象,遍历一遍train_loader就是一轮(epoch)
process_bar = tqdm(train_loader, unit="step")
# 切换网络为训练模式
net.train(True)
# 每次循环都能训练一个批次,进度条将在循环中更新
for train_imgs, labels in process_bar:
# 将 图片张量 和 标签张量 设置到设备上
train_imgs = train_imgs.to(device)
labels = labels.to(device)
# 清零梯度,防止累积
net.zero_grad()
# 前向传播一个批次的图像数据
outputs = net(train_imgs)
# 计算本批次的损失
loss = loss_fun(outputs, labels)
# 获取预测结果,dim=1表示取第2维的最大值,也就是选出可能性最大的那个类别
predictions = torch.argmax(outputs, dim=1)
# 计算准确率,将批次内预测正确的样本数除以总样本数
accuracy = torch.true_divide(
torch.sum(predictions == labels), labels.shape[0]
)
# 反向传播,计算梯度
loss.backward()
# 更新参数
optimizer.step()
# 更新进度条显示
# 显示当前批次的损失和准确率
process_bar.set_description(
"[%d/%d] Loss: %.4f, Acc: %.4f, Progress"
% (epoch, EPOCHS, loss.item(), accuracy.item())
)
# 训练完最后一个批次后,使用测试集评估模型效果
correct, total_loss = 0, 0
# 切换网络为推理模式
net.train(False)
# 在作用域内不计算梯度,节省内存
with torch.no_grad():
process_bar = tqdm(test_loader, unit="step")
# 执行一轮测试
for i, (test_imgs, labels) in enumerate(process_bar):
# 将 图片张量 和 标签张量 设置到设备上
test_imgs = test_imgs.to(device)
labels = labels.to(device)
# 前向传播一个批次的图像数据
outputs = net(test_imgs)
loss = loss_fun(outputs, labels)
predictions = torch.argmax(outputs, dim=1)
# 求损失
total_loss += loss
# 累加正确个数
correct += torch.sum(predictions == labels)
# 显示进度
process_bar.set_description("Testing, Progress")
# 计算测试集的准确率,正确个数除以总样本数
test_accuracy = torch.true_divide(correct, (BATCH_SIZE * len(test_loader)))
# 计算每个批次的平均损失
test_loss = torch.true_divide(total_loss, len(test_loader))
# 加入训练过程中的损失和准确率列表中
history["Test Loss"].append(test_loss.item())
history["Test Accuracy"].append(test_accuracy.item())
# 显示最终结果
print(
"Epoch[%d] Loss: %.4f, Acc: %.4f, Test Loss: %.4f, Test Acc: %.4f"
% (
epoch,
loss.item(),
accuracy.item(),
test_loss.item(),
test_accuracy.item(),
)
)
# 是否静默保存模型
if params_parser.quiet_save_model:
print("Test Loss:", history["Test Loss"])
print("Test Accuracy:", history["Test Accuracy"])
torch.save(net.state_dict(), "./checkpoints/" + params_parser.save_model_name)
exit()
# 绘制训练过程中的损失和准确率曲线
# 对测试Loss进行可视化
plt.plot(history["Test Loss"], label="Test Loss")
plt.legend(loc="best")
plt.grid(True)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
# 对测试准确率进行可视化
plt.plot(history["Test Accuracy"], color="red", label="Test Accuracy")
plt.legend(loc="best")
plt.grid(True)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()
print("save model?(y/n)")
is_save_model = input()
if is_save_model == "y":
print("saving model...")
torch.save(net.state_dict(), "./checkpoints/" + params_parser.save_model_name)
print("save model at:", "/checkpoints/" + params_parser.save_model_name)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。