from utils import casualdata,parser_layer,lossfunc
from visualdl import LogWriter
from ErnieSEG import ErnieSeg
import paddle.nn as nn
from paddle.optimizer.lr import NoamDecay
from paddle.optimizer import AdamW
import paddle
import numpy as np
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("param_save", help="model save path")
parser.add_argument("-epoch", type=int, default=300, help="number of training epochs")
parser.add_argument("-batchsize", type=int, default=20, help="batch size")
parser.add_argument("-lr", type=float, default=3e-5, help="learning rate")
args = parser.parse_args()
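# Example invocation (a sketch; the script name "train.py" and the checkpoint prefix are assumed, not taken from the repo):
#   python train.py ./checkpoints/seg_ -epoch 300 -batchsize 20 -lr 3e-5
# The positional "param_save" argument is used below as a prefix to which the epoch number is appended when saving.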
train_dataset = casualdata("./Dataset",batchsize=args.batchsize)
#parser = parser_layer(debug=False)
model = ErnieSeg()
loss = lossfunc()
# base learning rate
lr = args.lr
# number of warm-up iterations needed for the learning rate to ramp up to the base lr configured above
warmup_steps = 5000
# weight_decay coefficient used by the AdamW optimizer
weight_decay = 0.01
# maximum gradient value allowed by gradient clipping
max_grad_norm = 0.1
# initialize the Noam decay learning-rate schedule
lr_scheduler = NoamDecay(1 / (warmup_steps * (lr**2)), warmup_steps)
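# NoamDecay(d_model, warmup_steps) implements the Transformer warm-up schedule,
#   lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5),
# so passing d_model = 1 / (warmup_steps * lr**2) makes the schedule peak at exactly
# the base lr after `warmup_steps` iterations and then decay proportionally to step**-0.5.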
# exclude bias and LayerNorm parameters from weight decay
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
# initialize the AdamW optimizer
opt = AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params,
    grad_clip=nn.ClipGradByGlobalNorm(max_grad_norm))
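# apply_decay_param_fun receives each parameter's name and returns whether weight decay
# should be applied to it, so only the names kept in decay_params (everything except
# bias/norm parameters) are decayed; ClipGradByGlobalNorm rescales all gradients jointly
# whenever their global L2 norm exceeds max_grad_norm.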
with LogWriter(logdir="./log/bs{}-lr{}".format(str(args.batchsize), str(args.lr))) as writer:
    step = 0
    for epoch in range(args.epoch):
        evloss = {"att_loss": [], "rg_loss": [], "cl_loss": [], "total_loss": []}
        for sentences, pad_token, pad_mask, att_mask, rg_mask, class_mask in train_dataset:
            pad_token, pad_mask = paddle.to_tensor(pad_token, dtype="int64"), paddle.to_tensor(pad_mask, dtype="float32")
            pre_seg = model(pad_token, pad_mask)
            att_loss, rg_loss, cl_loss = loss(pre_seg, att_mask, rg_mask, class_mask, 1 - pad_mask.numpy())
            # weighted total loss: the attention loss is scaled up by a factor of 1000
            l = 1000 * att_loss + rg_loss + cl_loss
            opt.clear_grad()
            l.backward()
            opt.step()
            lr_scheduler.step()
            num = len(sentences)
            step += num
            # accumulate per-sample losses for epoch-level logging
            evloss["att_loss"].append(att_loss.numpy() / num)
            evloss["rg_loss"].append(rg_loss.numpy() / num)
            evloss["cl_loss"].append(cl_loss.numpy() / num)
            evloss["total_loss"].append(l.numpy() / num)
        print("epoch:{},step:{},att_loss:{:.3f},rg_loss:{:.3f},cl_loss:{:.3f},total_loss:{:.3f}".format(
            epoch, step,
            np.mean(evloss["att_loss"]),
            np.mean(evloss["rg_loss"]),
            np.mean(evloss["cl_loss"]),
            np.mean(evloss["total_loss"])))
        writer.add_scalar(tag="att_loss", step=step, value=np.mean(evloss["att_loss"]))
        writer.add_scalar(tag="rg_loss", step=step, value=np.mean(evloss["rg_loss"]))
        writer.add_scalar(tag="cl_loss", step=step, value=np.mean(evloss["cl_loss"]))
        writer.add_scalar(tag="total_loss", step=step, value=np.mean(evloss["total_loss"]))
        # save a checkpoint every 29 epochs (including epoch 0)
        if epoch % 29 == 0:
            paddle.save(model.state_dict(), args.param_save + "{}".format(str(epoch)))
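# To reload a saved checkpoint later (a minimal sketch; the epoch suffix "29" is illustrative):
#   state_dict = paddle.load(args.param_save + "29")
#   model.set_state_dict(state_dict)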