代码拉取完成,页面将自动刷新
#!usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: admin
@file: train_evaluate.py
@time: 2021/09/02
@desc:
"""
import time
import torch
from model import config
from model.loss import SimpleLossCompute
def run_epoch(data, model, loss_compute, epoch):
start = time.time()
total_tokens = 0.
total_loss = 0.
tokens = 0.
for i, batch in enumerate(data):
out = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
loss = loss_compute(out, batch.trg_y, batch.ntokens)
total_loss += loss
total_tokens += batch.ntokens
tokens += batch.ntokens
if i % 50 == 1:
elapsed = time.time() - start
print("Epoch %d Batch: %d Loss: %f Tokens per Sec: %fs" % (
epoch, i - 1, loss / batch.ntokens, (tokens.float() / elapsed / 1000.)))
start = time.time()
tokens = 0
return total_loss / total_tokens
def train(data, model, criterion, optimizer):
"""
训练并保存模型
"""
# 初始化模型在dev集上的最优Loss为一个较大值
best_dev_loss = 1e5
for epoch in range(config.EPOCHS):
# 模型训练
model.train()
run_epoch(data.train_data, model, SimpleLossCompute(model.generator, criterion, optimizer), epoch)
model.eval()
# 在dev集上进行loss评估
print('>>>>> Evaluate')
dev_loss = run_epoch(data.dev_data, model, SimpleLossCompute(model.generator, criterion, None), epoch)
print('<<<<< Evaluate loss: %f' % dev_loss)
# 如果当前epoch的模型在dev集上的loss优于之前记录的最优loss则保存当前模型,并更新最优loss值
if dev_loss < best_dev_loss:
torch.save(model.state_dict(), config.SAVE_FILE)
best_dev_loss = dev_loss
print('****** Save model done... ******')
print()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。