{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": true, "pycharm": { "name": "#%% md\n" } }, "source": [ "# 学习微调已经预训练好的网络\n", "\n", "调整一个预训练好的VGG16网络,使其能够分类kaggle数据库中的猴子数据。" ] }, { "cell_type": "code", "execution_count": 2, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PACKAGE LOADING SUCCESSFUL...\n", "\n", "\n" ] } ], "source": [ "%matplotlib inline\n", "\n", "import time\n", "\n", "import PIL\n", "import sklearn\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "import torch\n", "import torchviz\n", "import torchvision\n", "import torch.utils.data as torch_data\n", "\n", "sns.set()\n", "plt.rc('font', family='SimHei')\n", "plt.rc('axes', unicode_minus=False)\n", "assert torch.cuda.is_available(), Exception('CUDA IS NOT AVAILABLE!')\n", "print('PACKAGE LOADING SUCCESSFUL...\\n\\n')\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 加载预训练的模型用以定义自己的模型" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "从本地加载模型\n", "MyVGG16(\n", " (vgg): Sequential(\n", " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (6): ReLU(inplace=True)\n", " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (8): ReLU(inplace=True)\n", " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (11): ReLU(inplace=True)\n", " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (13): ReLU(inplace=True)\n", " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (15): ReLU(inplace=True)\n", " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (18): ReLU(inplace=True)\n", " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (20): ReLU(inplace=True)\n", " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (22): ReLU(inplace=True)\n", " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (25): ReLU(inplace=True)\n", " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (27): ReLU(inplace=True)\n", " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (29): ReLU(inplace=True)\n", " (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (classifier): Sequential(\n", " (0): Linear(in_features=25088, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): Dropout(p=0.5, inplace=False)\n", " (3): Linear(in_features=512, out_features=256, bias=True)\n", " (4): ReLU()\n", " (5): Dropout(p=0.5, inplace=False)\n", " (6): Linear(in_features=256, out_features=10, bias=True)\n", " (7): Softmax(dim=1)\n", " )\n", ")\n" ] } ], "source": [ "class MyVGG16(torch.nn.Module):\n", " def __init__(self):\n", " super(MyVGG16, self).__init__()\n", "\n", " try:\n", " vgg16 = torchvision.models.vgg16(pretrained=False)\n", " vgg16.load_state_dict(torch.load('E:/MyResources/PyTorch深度学习入门与实战/Torch预训练模型/vgg16-397923af.pth'))\n", " print('从本地加载模型')\n", " except FileNotFoundError as e:\n", " vgg16 = torchvision.models.vgg16(pretrained=True)\n", " print('从网络获取模型')\n", "\n", " vgg = vgg16.features\n", " for p in vgg.parameters(): # 冻结特征提取层的参数,不进行更新\n", " p.requires_grad_(False)\n", "\n", " self.vgg = vgg\n", " self.classifier = torch.nn.Sequential(\n", " torch.nn.Linear(25088, 512),\n", " torch.nn.ReLU(),\n", " torch.nn.Dropout(p=0.5),\n", " torch.nn.Linear(512, 256),\n", " torch.nn.ReLU(),\n", " torch.nn.Dropout(p=0.5),\n", " torch.nn.Linear(256, 10),\n", " torch.nn.Softmax(dim=1),\n", " )\n", "\n", " def forward(self, x):\n", " x = self.vgg(x)\n", " x = torch.reshape(x, (x.shape[0], -1))\n", " return self.classifier(x)\n", "\n", "\n", "my_vgg16 = MyVGG16()\n", "print(my_vgg16)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 训练数据的准备" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "len(train_data.targets)=1097, len(validation_data.targets)=272\n" ] } ], "source": [ "train_transform = torchvision.transforms.Compose([\n", " torchvision.transforms.RandomResizedCrop(224),\n", " torchvision.transforms.RandomHorizontalFlip(p=0.5),\n", " torchvision.transforms.ToTensor(),\n", " torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", "])\n", "\n", "validation_transform = torchvision.transforms.Compose([\n", " torchvision.transforms.Resize(256),\n", " torchvision.transforms.CenterCrop(224),\n", " torchvision.transforms.ToTensor(),\n", " torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", "])\n", "\n", "train_data = torchvision.datasets.ImageFolder(\n", " root='E:/MyResources/PyTorch深度学习入门与实战/程序(有数据)/data/chap6/10-monkey-species/training',\n", " transform=train_transform,\n", ")\n", "\n", "validation_data = torchvision.datasets.ImageFolder(\n", " root='E:/MyResources/PyTorch深度学习入门与实战/程序(有数据)/data/chap6/10-monkey-species/validation',\n", " transform=validation_transform,\n", ")\n", "\n", "train_loader = torch_data.DataLoader(\n", " dataset=train_data,\n", " batch_size=32,\n", " shuffle=True,\n", " num_workers=2,\n", ")\n", "\n", "validation_loader = torch_data.DataLoader(\n", " dataset=validation_data,\n", " batch_size=32,\n", " shuffle=True,\n", " num_workers=2,\n", ")\n", "\n", "print(f'{len(train_data.targets)=}, {len(validation_data.targets)=}')\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 开始微调网络" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "第1轮训练,耗时19.502秒。\n", "训练集损失:2.147,验证集损失:1.706,训练集精度:0.321,验证集精度:0.835。\n", "-------------------- \n", "\n", "第2轮训练,耗时18.160秒。\n", "训练集损失:1.819,验证集损失:1.536,训练集精度:0.678,验证集精度:0.952。\n", "-------------------- \n", "\n", "第3轮训练,耗时20.053秒。\n", "训练集损失:1.702,验证集损失:1.508,训练集精度:0.784,验证集精度:0.974。\n", "-------------------- \n", "\n", "第4轮训练,耗时18.104秒。\n", "训练集损失:1.645,验证集损失:1.485,训练集精度:0.834,验证集精度:0.978。\n", "-------------------- \n", "\n", "第5轮训练,耗时18.577秒。\n", "训练集损失:1.606,验证集损失:1.484,训练集精度:0.872,验证集精度:0.978。\n", "-------------------- \n", "\n", "第6轮训练,耗时18.712秒。\n", "训练集损失:1.601,验证集损失:1.483,训练集精度:0.870,验证集精度:0.985。\n", "-------------------- \n", "\n", "第7轮训练,耗时19.952秒。\n", "训练集损失:1.582,验证集损失:1.478,训练集精度:0.891,验证集精度:0.985。\n", "-------------------- \n", "\n", "第8轮训练,耗时18.402秒。\n", "训练集损失:1.587,验证集损失:1.477,训练集精度:0.881,验证集精度:0.985。\n", "-------------------- \n", "\n", "第9轮训练,耗时18.641秒。\n", "训练集损失:1.580,验证集损失:1.479,训练集精度:0.884,验证集精度:0.989。\n", "-------------------- \n", "\n", "第10轮训练,耗时19.492秒。\n", "训练集损失:1.546,验证集损失:1.481,训练集精度:0.930,验证集精度:0.978。\n", "-------------------- \n", "\n", "训练总计耗时:189.59948468208313秒。\n" ] } ], "source": [ "net = my_vgg16.to('cuda')\n", "optimizer = torch.optim.Adam(net.parameters(), lr=0.0003)\n", "loss_func = torch.nn.CrossEntropyLoss().to('cuda')\n", "\n", "since = time.time()\n", "train_loss_all, train_accuracy_all = [], []\n", "validation_loss_all, validation_accuracy_all = [], []\n", "\n", "for epoch in range(10):\n", " epoch_since = time.time()\n", " train_loss, validation_loss, train_accuracy, validation_accuracy = 0, 0, 0, 0\n", "\n", " net.train()\n", " for step, (X, y) in enumerate(train_loader):\n", " X, y = X.to('cuda'), y.to('cuda')\n", " out = net(X)\n", " predict = torch.argmax(out, dim=1)\n", " loss = loss_func(out, y)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " train_loss += loss.item() * X.shape[0]\n", " train_accuracy += torch.sum(predict == y)\n", "\n", " net.eval()\n", " for step, (X, y) in enumerate(validation_loader):\n", " X, y = X.to('cuda'), y.to('cuda')\n", " out = net(X)\n", " predict = torch.argmax(out, dim=1)\n", " loss = loss_func(out, y)\n", " validation_loss += loss.item() * X.shape[0]\n", " validation_accuracy += torch.sum(predict == y)\n", "\n", " train_loss_all.append(train_loss / len(train_data.targets))\n", " train_accuracy_all.append(train_accuracy / len(train_data.targets))\n", " validation_loss_all.append(validation_loss / len(validation_data.targets))\n", " validation_accuracy_all.append(validation_accuracy / len(validation_data.targets))\n", "\n", " print(f'第{epoch + 1}轮训练,耗时{time.time() - epoch_since:.3f}秒。')\n", " print(f'训练集损失:{train_loss_all[-1]:.3f},验证集损失:{validation_loss_all[-1]:.3f},'\n", " f'训练集精度:{train_accuracy_all[-1]:.3f},验证集精度:{validation_accuracy_all[-1]:.3f}。')\n", " print('--' * 10, '\\n')\n", "\n", "print(f'训练总计耗时:{time.time() - since}秒。')\n", "del net, optimizer, loss_func, since, epoch, epoch_since, train_loss, validation_loss, train_accuracy, validation_accuracy,\\\n", " step, X, y, out, predict, loss\n", "torch.cuda.empty_cache()\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 可视化一下训练的结果" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "data": { "text/plain": "<matplotlib.legend.Legend at 0x246c841ce80>" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": "<Figure size 720x576 with 1 Axes>", "image/png": "\n" }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": "<Figure size 720x576 with 1 Axes>", "image/png": "\n" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 8))\n", "plt.plot(train_loss_all, 'ro-', label='训练集损失')\n", "plt.plot(validation_loss_all, 'gs-', label='测试集损失')\n", "plt.legend()\n", "\n", "plt.figure(figsize=(10, 8))\n", "plt.plot(train_accuracy_all, 'ro-', label='训练集精度')\n", "plt.plot(validation_accuracy_all, 'gs-', label='测试集精度')\n", "plt.legend()" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }