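# NOTE: web-page banner captured together with this file, not part of the program:
# "Code pull complete; the page will refresh automatically."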
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import List, Dict, Tuple
import scripts.cmn_eng as cmn_eng
from modules.common import constants
from modules.helper.model_helper import ModelFactory, ModelRunner
# --- Model hyper-parameters (taken from the project-wide defaults) ---
# Number of attention heads.
HEAD_NUM = constants.HEAD_NUM_DEFAULT
# Number of encoder/decoder layers.
LAYER_NUM = constants.LAYER_NUM_DEFAULT
# Dimension of the word embeddings.
EMBEDDING_SIZE = constants.EMBEDDING_SIZE_DEFAULT
# Hidden dimension of the feed-forward fully connected layer.
HIDDEN_SIZE = constants.HIDDEN_SIZE_DEFAULT

# --- Sequence length limits used at prediction time ---
# Minimum source sequence length.
MIN_PREDICT_SOURCE_SEQUENCE_LENGTH = constants.MIN_PREDICT_SOURCE_SEQUENCE_LENGTH_DEFAULT
# Maximum source sequence length.
MAX_PREDICT_SOURCE_SEQUENCE_LENGTH = constants.MAX_PREDICT_SOURCE_SEQUENCE_LENGTH_DEFAULT
# Maximum result sequence length; hard-coded here rather than read from
# constants.MAX_PREDICT_RESULT_SEQUENCE_LENGTH_DEFAULT.
MAX_PREDICT_RESULT_SEQUENCE_LENGTH = 30

# --- File locations ---
# Prefix and suffix of the preprocessing cache file paths.
PREPROCESSION_CACHE_PATH_PREFIX = './resources/cache/'
PREPROCESSION_CACHE_PATH_SUFFIX = '.cache'
# Path of the corpus file.
CORPUS_DATA_PATH = './resources/corpus/cmn-eng.txt'
# Path where the model parameters are saved.
MODEL_DICT_PATH = './resources/model/translation-latest.pth'

# Preset strings that are translated once before the interactive prompt starts.
TEST_LIST = [
    '你好',
    '等一下',
    '不会吧',
    '我赢了',
    '开始',
    '我知道',
    '听着',
    '我退出',
    '冷静点',
    '住手',
]
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Paths of the cached preprocessing results (corpus dictionary and dataset).
corpus_dict_cache_path = PREPROCESSION_CACHE_PATH_PREFIX + 'corpus_dict' + PREPROCESSION_CACHE_PATH_SUFFIX
corpus_dataset_cache_path = PREPROCESSION_CACHE_PATH_PREFIX + 'corpus_dataset' + PREPROCESSION_CACHE_PATH_SUFFIX

# Load the cached preprocessing data.
# NOTE(review): torch.load unpickles the file contents, so these cache files
# must come from a trusted source. Newer PyTorch releases default to
# weights_only=True, which may reject pickled project objects - confirm
# against the installed version.
corpus_dict = torch.load(corpus_dict_cache_path)
# NOTE(review): `dataset` is not referenced anywhere in the visible part of
# this script; kept so the module-level name stays available.
dataset = torch.load(corpus_dataset_cache_path)

model_factory = ModelFactory(max_sequence_length=MAX_PREDICT_SOURCE_SEQUENCE_LENGTH, pad_index=0)

# Create the model; vocabulary sizes come from the loaded corpus dictionary.
model = model_factory.create_model(
    head_num=HEAD_NUM,
    layer_num=LAYER_NUM,
    embedding_size=EMBEDDING_SIZE,
    hidden_size=HIDDEN_SIZE,
    source_vocab_size=len(corpus_dict.source_word2index_dict),
    target_vocab_size=len(corpus_dict.target_word2index_dict),
)

# Load the model parameters. map_location remaps the stored tensors onto the
# device selected above, so a checkpoint saved from CUDA tensors can still be
# loaded on a CPU-only machine instead of failing during deserialization.
state_dict = torch.load(MODEL_DICT_PATH, map_location=device)
model.load_state_dict(state_dict)
# 打印翻译结果
def print_result(input_sequence, target_index, target_str):
    """Print one piece of a translation result without a trailing newline.

    Used as the ``output_callback`` passed to ``model_runner.run_model``.

    Args:
        input_sequence: Source tokens; joined into one string for the header.
        target_index: ``-1`` requests the header line (the source text
            followed by ``' => '``); any other value prints ``target_str``.
        target_str: Output token text, printed followed by a single space.
    """
    if target_index != -1:
        print(target_str, end=' ')
        return

    # Header: start a new line and echo the source text before the result.
    input_str = ''.join(input_sequence)
    print(f'\n{input_str} => ', end='')
# Runner that performs inference with the loaded model, using the
# prediction-time sequence length limits defined at the top of the file.
model_runner = ModelRunner(
    model,
    corpus_dict=corpus_dict,
    device=device,
    min_source_sequence_length=MIN_PREDICT_SOURCE_SEQUENCE_LENGTH,
    max_source_sequence_length=MAX_PREDICT_SOURCE_SEQUENCE_LENGTH,
    max_result_sequence_length=MAX_PREDICT_RESULT_SEQUENCE_LENGTH,
)


def _translate_and_print(text):
    """Tokenize *text* and run the model, printing output via print_result."""
    input_sequence = cmn_eng.tokenize_cmn(text)
    # Run inference on the tokenized input.
    model_runner.run_model(input_sequence=input_sequence, output_callback=print_result)


# First translate the preset test strings.
print('预置的测试字符串: ', end='')
for predict_str in TEST_LIST:
    _translate_and_print(predict_str)

# Then translate strings typed by the user until an empty line is entered.
while True:
    try:
        predict_str = input('\n\n输入待翻译的字符串(空行结束): ')
    except EOFError:
        # stdin was closed (Ctrl-D or piped input ran out): end the session
        # the same way an empty line does instead of crashing with a traceback.
        break
    if not predict_str:
        break
    _translate_and_print(predict_str)
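# NOTE: content-moderation notice from the hosting web page, not part of the program:
# "This section may contain content unsuitable for display and is hidden by the page.
# You can review and modify it with the editing tools. If you confirm the content involves
# no inappropriate language, pure advertising, violence, vulgar or pornographic material,
# infringement, piracy, false or worthless content, and nothing that violates national laws
# and regulations, you may submit an appeal and it will be handled as soon as possible."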