代码拉取完成,页面将自动刷新
同步操作将从 xijunjun/Pytorch-UNet 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import sys
from optparse import OptionParser
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from eval import eval_net
from unet import UNet
from utils import *
def _loss_value(loss):
    """Return the Python float stored in a scalar loss.

    ``loss.data[0]`` only works on PyTorch < 0.4; on later releases indexing a
    0-dim tensor raises ``IndexError``. ``.item()`` is the replacement, so it
    is tried first, falling back to the old form for pre-0.4 Variables.
    """
    try:
        return loss.item()
    except AttributeError:
        return loss.data[0]


def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
              cp=True, gpu=False):
    """Train ``net`` with SGD and binary cross-entropy on data/train/.

    Images are read from ``data/train/`` and masks from ``data/train_masks/``;
    the id list is split into train/val partitions by ``split_train_val``.
    At the start of every epoch the network is scored with ``eval_net`` on the
    validation split (so the printed Dice reflects the weights produced by the
    previous epoch), then one pass is made over the training split.

    Args:
        net: model to train; its parameters are updated in place.
        epochs: number of passes over the training split.
        batch_size: number of samples per optimizer step.
        lr: SGD learning rate (momentum 0.9, weight decay 5e-4).
        val_percent: fraction of ids held out for validation, forwarded to
            ``split_train_val``.
        cp: if true, save ``net.state_dict()`` to ``checkpoints/CP<epoch>.pth``
            after each epoch (the directory is created if missing).
        gpu: if true, move each batch to CUDA before the forward pass.
    """
    import os

    import numpy as np

    dir_img = 'data/train/'
    dir_mask = 'data/train_masks/'
    dir_checkpoint = 'checkpoints/'

    ids = get_ids(dir_img)
    ids = split_ids(ids)
    iddataset = split_train_val(ids, val_percent)

    print('''
Starting training:
Epochs: {}
Batch size: {}
Learning rate: {}
Training size: {}
Validation size: {}
Checkpoints: {}
CUDA: {}
'''.format(epochs, batch_size, lr, len(iddataset['train']),
           len(iddataset['val']), str(cp), str(gpu)))

    N_train = len(iddataset['train'])

    optimizer = optim.SGD(net.parameters(),
                          lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))

        # The data sources are rebuilt every epoch ("reset the generators");
        # presumably get_imgs_and_masks returns one-shot generators — confirm.
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)

        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

        # NOTE(review): eval_net may switch the model to eval mode (frozen
        # BatchNorm/Dropout); restore training mode. No-op if it does not.
        net.train()

        epoch_loss = 0
        n_batches = 0
        for i, b in enumerate(batch(train, batch_size)):
            # Each sample in the batch is indexed as (image, mask); distinct
            # comprehension names avoid shadowing the batch index ``i``.
            X = torch.FloatTensor(np.array([sample[0] for sample in b]))
            y = torch.ByteTensor(np.array([sample[1] for sample in b]))

            X = Variable(X)
            y = Variable(y)
            if gpu:
                X = X.cuda()
                y = y.cuda()

            y_pred = net(X)
            # BCELoss expects probabilities, so squash the logits first and
            # compare every pixel against the flattened mask as floats.
            probs_flat = torch.sigmoid(y_pred).view(-1)
            y_flat = y.view(-1)
            loss = criterion(probs_flat, y_flat.float())

            batch_loss = _loss_value(loss)
            epoch_loss += batch_loss
            n_batches += 1
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
                                                     batch_loss))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Mean over the number of batches actually processed; an empty epoch
        # reports 0.0 instead of raising.
        mean_loss = epoch_loss / n_batches if n_batches else 0.0
        print('Epoch finished ! Loss: {}'.format(mean_loss))

        if cp:
            if not os.path.isdir(dir_checkpoint):
                os.makedirs(dir_checkpoint)
            torch.save(net.state_dict(),
                       os.path.join(dir_checkpoint,
                                    'CP{}.pth'.format(epoch + 1)))
            print('Checkpoint {} saved !'.format(epoch + 1))
if __name__ == '__main__':
    # Command-line entry point: parse options, build a 3-channel-in /
    # 1-channel-out UNet, optionally restore weights, and train it. On Ctrl-C
    # the current weights are written to INTERRUPTED.pth before exiting.
    import os  # needed for os._exit below; not imported explicitly at file top

    parser = OptionParser()
    parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
                      help='number of epochs')
    parser.add_option('-b', '--batch-size', dest='batchsize', default=10,
                      type='int', help='batch size')
    parser.add_option('-l', '--learning-rate', dest='lr', default=0.1,
                      type='float', help='learning rate')
    parser.add_option('-g', '--gpu', action='store_true', dest='gpu',
                      default=False, help='use cuda')
    parser.add_option('-c', '--load', dest='load',
                      default=False, help='load file model')

    (options, _args) = parser.parse_args()

    net = UNet(3, 1)

    if options.load:
        # torch.load unpickles the file: only load checkpoints you trust.
        # Mapping every storage to the CPU lets GPU-saved checkpoints load on
        # CPU-only machines; the model is moved to CUDA below when requested.
        state = torch.load(options.load,
                           map_location=lambda storage, loc: storage)
        net.load_state_dict(state)
        print('Model loaded from {}'.format(options.load))

    if options.gpu:
        net.cuda()
        cudnn.benchmark = True

    try:
        train_net(net, options.epochs, options.batchsize, options.lr,
                  gpu=options.gpu)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
        # sys.exit raises SystemExit; if that is intercepted, hard-exit with
        # os._exit (skips cleanup handlers) so the process still terminates.
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。