2 Star 2 Fork 1

cheneyxu/基于Transform的机器翻译系统

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_evaluate.py 1.85 KB
一键复制 编辑 原始数据 按行查看 历史
#!usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: admin
@file: train_evaluate.py
@time: 2021/09/02
@desc:
"""
import time
import torch
from model import config
from model.loss import SimpleLossCompute
def run_epoch(data, model, loss_compute, epoch):
start = time.time()
total_tokens = 0.
total_loss = 0.
tokens = 0.
for i, batch in enumerate(data):
out = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
loss = loss_compute(out, batch.trg_y, batch.ntokens)
total_loss += loss
total_tokens += batch.ntokens
tokens += batch.ntokens
if i % 50 == 1:
elapsed = time.time() - start
print("Epoch %d Batch: %d Loss: %f Tokens per Sec: %fs" % (
epoch, i - 1, loss / batch.ntokens, (tokens.float() / elapsed / 1000.)))
start = time.time()
tokens = 0
return total_loss / total_tokens
def train(data, model, criterion, optimizer):
"""
训练并保存模型
"""
# 初始化模型在dev集上的最优Loss为一个较大值
best_dev_loss = 1e5
for epoch in range(config.EPOCHS):
# 模型训练
model.train()
run_epoch(data.train_data, model, SimpleLossCompute(model.generator, criterion, optimizer), epoch)
model.eval()
# 在dev集上进行loss评估
print('>>>>> Evaluate')
dev_loss = run_epoch(data.dev_data, model, SimpleLossCompute(model.generator, criterion, None), epoch)
print('<<<<< Evaluate loss: %f' % dev_loss)
# 如果当前epoch的模型在dev集上的loss优于之前记录的最优loss则保存当前模型,并更新最优loss值
if dev_loss < best_dev_loss:
torch.save(model.state_dict(), config.SAVE_FILE)
best_dev_loss = dev_loss
print('****** Save model done... ******')
print()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/cheneyxym/Transformer_translate_en2ch.git
git@gitee.com:cheneyxym/Transformer_translate_en2ch.git
cheneyxym
Transformer_translate_en2ch
基于Transform的机器翻译系统
main

搜索帮助