代码拉取完成,页面将自动刷新
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from nets.dcgan import discriminator, generator
from utils.utils import get_lr_scheduler, set_optimizer_lr
from utils.utils_fit import fit_one_epoch
if __name__ == "__main__":
#卷积通道数的设置
channel = 64
#图像大小的设置,如[128, 128]
input_shape = [64,64]
#训练参数设置
Init_Epoch = 0
Epoch = 500
batch_size = 1
#Init_lr 模型的最大学习率
Init_lr = 2e-3
# Min_lr 模型的最小学习率,默认为最大学习率的0.01
Min_lr = Init_lr * 0.01
# adam优化器
optimizer_type = "adam"
momentum = 0.5
weight_decay = 0
# lr_decay_type 使用到的学习率下降方式
lr_decay_type = "cos"
#------------------------------------------------------------------#
# save_dir 权值与日志文件保存的文件夹
#------------------------------------------------------------------#
save_dir = 'logs'
#------------------------------------------------------------------#
# num_workers 用于设置是否使用多线程读取数据
# 开启后会加快数据读取速度,但是会占用更多内存
# 内存较小的电脑可以设置为2或者0
#------------------------------------------------------------------#
num_workers = 0
#------------------------------#
# 每隔50个step保存一次图片
#------------------------------#
photo_save_step = 50
#------------------------------------------#
# 获得图片路径
#------------------------------------------#
annotation_path = "train_lines.txt"
#------------------------------------------------------#
# 设置用到的显卡
#------------------------------------------------------#
ngpus_per_node = torch.cuda.device_count()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
local_rank = 0
#------------------------------------------#
# 生成网络和评价网络
#------------------------------------------#
G_model = generator(channel, input_shape)
D_model = discriminator(channel, input_shape)
#------------------------------------------#
# 将训练好的模型重新载入
#------------------------------------------#
#----------------------#
# 获得损失函数
#----------------------#
BCE_loss = nn.BCEWithLogitsLoss()
scaler = None
G_model_train = G_model.train()
D_model_train = D_model.train()
cudnn.benchmark = True
G_model_train = torch.nn.DataParallel(G_model)
G_model_train = G_model_train.cuda()
D_model_train = torch.nn.DataParallel(D_model)
D_model_train = D_model_train.cuda()
with open(annotation_path) as f:
lines = f.readlines()
num_train = len(lines)
#------------------------------------------------------#
# Init_Epoch为起始世代
# Epoch总训练世代
#------------------------------------------------------#
if True:
#---------------------------------------#
# 根据optimizer_type选择优化器
#---------------------------------------#
G_optimizer = {
'adam' : optim.Adam(G_model_train.parameters(), lr=Init_lr, betas=(momentum, 0.999), weight_decay = weight_decay),
'sgd' : optim.SGD(G_model_train.parameters(), Init_lr, momentum = momentum, nesterov=True)
}[optimizer_type]
D_optimizer = {
'adam' : optim.Adam(D_model_train.parameters(), lr=Init_lr, betas=(momentum, 0.999), weight_decay = weight_decay),
'sgd' : optim.SGD(D_model_train.parameters(), Init_lr, momentum = momentum, nesterov=True)
}[optimizer_type]
# 获得学习率下降的公式
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr, Min_lr, Epoch)
# 判断每一个世代的长度
epoch_step = num_train // batch_size
# 开始模型训练-#
for epoch in range(Init_Epoch, Epoch):
set_optimizer_lr(G_optimizer, lr_scheduler_func, epoch)
set_optimizer_lr(D_optimizer, lr_scheduler_func, epoch)
fit_one_epoch(G_model_train, D_model_train, G_model, D_model, G_optimizer, D_optimizer, BCE_loss,
epoch, epoch_step, gen, Epoch, scaler, save_dir, photo_save_step, local_rank)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。