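"""train.py

Distributed (DDP) training script for a classifier-free guidance diffusion model on
CIFAR-10: it trains a conditional Unet whose class embeddings are randomly dropped
during training, periodically samples a grid of images (one row per class), and saves
checkpoints together with the optimizer and scheduler state.
"""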
import os
import torch
import argparse
import itertools
import numpy as np
from unet import Unet
from tqdm import tqdm
import torch.optim as optim
from diffusion import GaussianDiffusion
from torchvision.utils import save_image
from utils import get_named_beta_schedule
from embedding import ConditionalEmbedding
from Scheduler import GradualWarmupScheduler
from dataloader_cifar import load_data, transback
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size
def train(params: argparse.Namespace):
    assert params.genbatch % (torch.cuda.device_count() * params.clsnum) == 0, 'please re-set your genbatch!!!'
    # initialize the distributed process group
    init_process_group(backend="nccl")
    # get the rank of this process (with one process per GPU on a single node,
    # this coincides with the local rank / GPU index)
    local_rank = get_rank()
    # set device
    device = torch.device("cuda", local_rank)
    # load data
    dataloader, sampler = load_data(params.batchsize, params.numworkers)
    # initialize models
    net = Unet(
        in_ch = params.inch,
        mod_ch = params.modch,
        out_ch = params.outch,
        ch_mul = params.chmul,
        num_res_blocks = params.numres,
        cdim = params.cdim,
        use_conv = params.useconv,
        droprate = params.droprate,
        dtype = params.dtype
    )
    cemblayer = ConditionalEmbedding(10, params.cdim, params.cdim).to(device)
    # load last epoch
    lastpath = os.path.join(params.moddir, 'last_epoch.pt')
    if os.path.exists(lastpath):
        lastepc = torch.load(lastpath)['last_epoch']
        # load checkpoints
        checkpoint = torch.load(os.path.join(params.moddir, f'ckpt_{lastepc}_checkpoint.pt'), map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        cemblayer.load_state_dict(checkpoint['cemblayer'])
    else:
        lastepc = 0
    betas = get_named_beta_schedule(num_diffusion_timesteps = params.T)  # beta_t
    diffusion = GaussianDiffusion(
        dtype = params.dtype,
        model = net,
        betas = betas,
        w = params.w,
        v = params.v,
        device = device
    )
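    # Note on w and v (assumed to follow the usual classifier-free guidance convention
    # implemented by GaussianDiffusion): at sampling time the noise prediction is
    # combined as (1 + w) * eps(x_t, c) - w * eps(x_t, 0), so w sets the guidance
    # strength, while v interpolates between the lower and upper bounds of the
    # posterior variance.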
    # DDP settings
    diffusion.model = DDP(
        diffusion.model,
        device_ids = [local_rank],
        output_device = local_rank
    )
    cemblayer = DDP(
        cemblayer,
        device_ids = [local_rank],
        output_device = local_rank
    )
    # optimizer settings
    optimizer = torch.optim.AdamW(
        itertools.chain(
            diffusion.model.parameters(),
            cemblayer.parameters()
        ),
        lr = params.lr,
        weight_decay = 1e-4
    )
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer = optimizer,
        T_max = params.epoch,
        eta_min = 0,
        last_epoch = -1
    )
    warmUpScheduler = GradualWarmupScheduler(
        optimizer = optimizer,
        multiplier = params.multiplier,
        warm_epoch = params.epoch // 10,
        after_scheduler = cosineScheduler,
        last_epoch = lastepc
    )
    if lastepc != 0:
        optimizer.load_state_dict(checkpoint['optimizer'])
        warmUpScheduler.load_state_dict(checkpoint['scheduler'])
    # training
    print("Training on {}".format(device))
    cnt = torch.cuda.device_count()
    for epc in range(lastepc, params.epoch):
        # turn into train mode
        diffusion.model.train()
        cemblayer.train()
        sampler.set_epoch(epc)
        # batch iterations
        with tqdm(dataloader, dynamic_ncols=True, disable=(local_rank % cnt != 0)) as tqdmDataLoader:
            for img, lab in tqdmDataLoader:
                b = img.shape[0]
                optimizer.zero_grad()
                x_0 = img.to(device)
                lab = lab.to(device)
                cemb = cemblayer(lab)
                # classifier-free guidance training: with probability params.threshold,
                # zero the conditional embedding so the model also learns the
                # unconditional score
                cemb[np.where(np.random.rand(b) < params.threshold)] = 0
                loss = diffusion.trainloss(x_0, cemb = cemb)
                loss.backward()
                optimizer.step()
                tqdmDataLoader.set_postfix(
                    ordered_dict = {
                        "epoch": epc + 1,
                        "loss": loss.item(),
                        "batch per device": x_0.shape[0],
                        "img shape": x_0.shape[1:],
                        "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                    }
                )
        warmUpScheduler.step()
        # evaluation and checkpoint saving
        if (epc + 1) % params.interval == 0:
            diffusion.model.eval()
            cemblayer.eval()
            # generate samples: genbatch pictures in total (genbatch // clsnum per row),
            # where all pictures in the same row belong to the same class
            all_samples = []
            each_device_batch = params.genbatch // cnt
            with torch.no_grad():
                # build labels [0,...,0, 1,...,1, ..., clsnum-1,...] so each class gets
                # each_device_batch // clsnum samples on this device
                lab = torch.ones(params.clsnum, each_device_batch // params.clsnum).type(torch.long) \
                    * torch.arange(start = 0, end = params.clsnum).reshape(-1, 1)
                lab = lab.reshape(-1, 1).squeeze()
                lab = lab.to(device)
                cemb = cemblayer(lab)
                genshape = (each_device_batch, 3, 32, 32)
                if params.ddim:
                    generated = diffusion.ddim_sample(genshape, params.num_steps, params.eta, params.select, cemb = cemb)
                else:
                    generated = diffusion.sample(genshape, cemb = cemb)
                img = transback(generated)
                img = img.reshape(params.clsnum, each_device_batch // params.clsnum, 3, 32, 32).contiguous()
                # gather the samples from every process so rank 0 can save the full grid
                gathered_samples = [torch.zeros_like(img) for _ in range(get_world_size())]
                all_gather(gathered_samples, img)
                all_samples.extend(gathered_samples)
                samples = torch.concat(all_samples, dim = 1).reshape(params.genbatch, 3, 32, 32)
            if local_rank == 0:
                save_image(samples, os.path.join(params.samdir, f'generated_{epc+1}_pict.png'), nrow = params.genbatch // params.clsnum)
            # save checkpoints
            checkpoint = {
                'net': diffusion.model.module.state_dict(),
                'cemblayer': cemblayer.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': warmUpScheduler.state_dict()
            }
            torch.save({'last_epoch': epc + 1}, os.path.join(params.moddir, 'last_epoch.pt'))
            torch.save(checkpoint, os.path.join(params.moddir, f'ckpt_{epc+1}_checkpoint.pt'))
        torch.cuda.empty_cache()
    destroy_process_group()
def main():
    # hyperparameters for the model and training run
    parser = argparse.ArgumentParser(description='test for diffusion model')
    parser.add_argument('--batchsize', type=int, default=256, help='batch size per device for training the Unet model')
    parser.add_argument('--numworkers', type=int, default=4, help='number of dataloader workers')
    parser.add_argument('--inch', type=int, default=3, help='input channels for the Unet model')
    parser.add_argument('--modch', type=int, default=64, help='model channels for the Unet model')
    parser.add_argument('--T', type=int, default=1000, help='number of diffusion timesteps')
    parser.add_argument('--outch', type=int, default=3, help='output channels for the Unet model')
    parser.add_argument('--chmul', type=list, default=[1,2,2,2], help='channel multipliers for the Unet model')
    parser.add_argument('--numres', type=int, default=2, help='number of resblocks per stage in the Unet model')
    parser.add_argument('--cdim', type=int, default=10, help='dimension of the conditional embedding')
    parser.add_argument('--useconv', type=lambda x: (str(x).lower() in ['true', '1', 'yes']), default=True, help='whether to use convolution in the downsample layers')
    parser.add_argument('--droprate', type=float, default=0.1, help='dropout rate for the model')
    parser.add_argument('--dtype', default=torch.float32)
    parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
    parser.add_argument('--w', type=float, default=1.8, help='classifier-free guidance strength')
    parser.add_argument('--v', type=float, default=0.3, help='hyperparameter for the variance of the posterior distribution')
    parser.add_argument('--epoch', type=int, default=1500, help='number of training epochs')
    parser.add_argument('--multiplier', type=float, default=2.5, help='multiplier for warmup')
    parser.add_argument('--threshold', type=float, default=0.1, help='probability of dropping the condition during training (classifier-free guidance)')
    parser.add_argument('--interval', type=int, default=20, help='epoch interval between two evaluations')
    parser.add_argument('--moddir', type=str, default='model', help='directory for model checkpoints')
    parser.add_argument('--samdir', type=str, default='sample', help='directory for generated samples')
    parser.add_argument('--genbatch', type=int, default=80, help='batch size for the sampling process')
    parser.add_argument('--clsnum', type=int, default=10, help='number of label classes')
    parser.add_argument('--num_steps', type=int, default=50, help='sampling steps for DDIM')
    parser.add_argument('--eta', type=float, default=0, help='eta for the variance during the DDIM sampling process')
    parser.add_argument('--select', type=str, default='linear', help='timestep selection strategy for DDIM')
    parser.add_argument('--ddim', type=lambda x: (str(x).lower() in ['true', '1', 'yes']), default=False, help='whether to use DDIM sampling')
    parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
    args = parser.parse_args()
    train(args)
if __name__ == '__main__':
    # import torch.distributed as dist
    # import os
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '5678'
    #
    # dist.init_process_group(backend='gloo', init_method='env://', rank=rank, world_size=world_size)
    main()
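# Typical launch, assuming a single node with N GPUs (the NCCL process group set up in
# train() expects the usual distributed environment variables from a launcher):
#   torchrun --nproc_per_node=N train.py
# or, with the older launcher that passes --local_rank to the script:
#   python -m torch.distributed.launch --nproc_per_node=N train.py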