
oscarlin/ChineseCasualExtraction

train.py 3.53 KB
oscarlin committed on 2021-04-19 10:11 · update train.py.
from utils import casualdata,parser_layer,lossfunc
from visualdl import LogWriter
from ErnieSEG import ErnieSeg
import paddle.nn as nn
from paddle.optimizer.lr import NoamDecay
from paddle.optimizer import AdamW
import paddle
import numpy as np
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("param_save", help="model save path (used as a prefix for checkpoint files)")
parser.add_argument("-epoch", type=int, default=300, help="number of training epochs")
parser.add_argument("-batchsize", type=int, default=20, help="batch size")
parser.add_argument("-lr", type=float, default=3e-5, help="base learning rate")
args = parser.parse_args()
train_dataset = casualdata("./Dataset",batchsize=args.batchsize)
#parser = parser_layer(debug=False)
model = ErnieSeg()
loss = lossfunc()
# Base learning rate
lr = args.lr
# Number of warmup iterations over which the learning rate ramps up to the base lr configured above
warmup_steps = 5000
# weight_decay coefficient used by the AdamW optimizer
weight_decay = 0.01
# Maximum gradient norm allowed by gradient clipping
max_grad_norm = 0.1
# Initialize the Noam learning-rate decay schedule
lr_scheduler = NoamDecay(1 / (warmup_steps * (lr**2)), warmup_steps)
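# Note (a sketch, assuming Paddle's standard Noam schedule
# lr_t = d_model**-0.5 * min(t**-0.5, t * warmup_steps**-1.5)): with
# d_model = 1 / (warmup_steps * lr**2), the scheduled value at t == warmup_steps is
# sqrt(warmup_steps * lr**2) / sqrt(warmup_steps) == lr, so the schedule warms up
# to exactly the configured base learning rate before decaying.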
# Do not apply weight decay to bias and LayerNorm parameters
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
# Initialize the AdamW optimizer; apply_decay_param_fun restricts weight decay to the names in decay_params
opt = AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params,
    grad_clip=nn.ClipGradByGlobalNorm(max_grad_norm))
with LogWriter(logdir="./log/bs{}-lr{}".format(str(args.batchsize), str(args.lr))) as writer:
    step = 0
    for epoch in range(args.epoch):
        evloss = {"att_loss": [], "rg_loss": [], "cl_loss": [], "total_loss": []}
        for sentences, pad_token, pad_mask, att_mask, rg_mask, class_mask in train_dataset:
            pad_token, pad_mask = paddle.to_tensor(pad_token, dtype="int64"), paddle.to_tensor(pad_mask, dtype="float32")
            pre_seg = model(pad_token, pad_mask)
            att_loss, rg_loss, cl_loss = loss(pre_seg, att_mask, rg_mask, class_mask, 1 - pad_mask.numpy())
            # Combined objective: the attention loss is up-weighted by a factor of 1000
            l = 1000 * att_loss + rg_loss + cl_loss
            opt.clear_grad()
            l.backward()
            opt.step()
            lr_scheduler.step()
            num = len(sentences)
            step += num
            # Track per-sample losses and report the running means for the current epoch
            evloss["att_loss"].append(att_loss.numpy() / num)
            evloss["rg_loss"].append(rg_loss.numpy() / num)
            evloss["cl_loss"].append(cl_loss.numpy() / num)
            evloss["total_loss"].append(l.numpy() / num)
            print("epoch:{},step:{},att_loss:{:.3f},rg_loss:{:.3f},cl_loss:{:.3f},total_loss:{:.3f}".format(
                epoch, step,
                np.mean(evloss["att_loss"]),
                np.mean(evloss["rg_loss"]),
                np.mean(evloss["cl_loss"]),
                np.mean(evloss["total_loss"])))
            writer.add_scalar(tag="att_loss", step=step, value=np.mean(evloss["att_loss"]))
            writer.add_scalar(tag="rg_loss", step=step, value=np.mean(evloss["rg_loss"]))
            writer.add_scalar(tag="cl_loss", step=step, value=np.mean(evloss["cl_loss"]))
            writer.add_scalar(tag="total_loss", step=step, value=np.mean(evloss["total_loss"]))
        # Save a checkpoint every 29 epochs (including epoch 0)
        if epoch % 29 == 0:
            paddle.save(model.state_dict(), args.param_save + "{}".format(str(epoch)))
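Usage (a minimal sketch; `./checkpoints/epoch_` is a hypothetical save prefix, and the ./Dataset directory plus the utils and ErnieSEG modules from this repository must be importable):

python train.py ./checkpoints/epoch_ -epoch 300 -batchsize 20 -lr 3e-5

The positional `param_save` argument is appended with the epoch number at save time, so passing a prefix rather than a fixed filename keeps the periodic checkpoints from overwriting each other.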