1 Star 0 Fork 1

MagiCodeX/transformer-pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
translation_predict.py 3.47 KB
一键复制 编辑 原始数据 按行查看 历史
MagiCodeX 提交于 2024-06-20 18:23 . 修改注释上的错误
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import List, Dict, Tuple
import scripts.cmn_eng as cmn_eng
from modules.common import constants
from modules.helper.model_helper import ModelFactory, ModelRunner
# Number of heads in multi-head attention.
HEAD_NUM = constants.HEAD_NUM_DEFAULT
# Number of encoder / decoder layers.
LAYER_NUM = constants.LAYER_NUM_DEFAULT
# Dimension of the word embeddings.
EMBEDDING_SIZE = constants.EMBEDDING_SIZE_DEFAULT
# Hidden-layer size of the feed-forward fully connected sublayer.
HIDDEN_SIZE = constants.HIDDEN_SIZE_DEFAULT
# Minimum / maximum source sequence length used at prediction time.
MIN_PREDICT_SOURCE_SEQUENCE_LENGTH = constants.MIN_PREDICT_SOURCE_SEQUENCE_LENGTH_DEFAULT
MAX_PREDICT_SOURCE_SEQUENCE_LENGTH = constants.MAX_PREDICT_SOURCE_SEQUENCE_LENGTH_DEFAULT
# Maximum length of the predicted result sequence. Deliberately hard-coded here
# instead of using constants.MAX_PREDICT_RESULT_SEQUENCE_LENGTH_DEFAULT.
MAX_PREDICT_RESULT_SEQUENCE_LENGTH = 30
# Prefix and suffix used to build the paths of the cached preprocessing files.
PREPROCESSION_CACHE_PATH_PREFIX = './resources/cache/'
PREPROCESSION_CACHE_PATH_SUFFIX = '.cache'
# Path of the corpus text file.
CORPUS_DATA_PATH = './resources/corpus/cmn-eng.txt'
# Path where the trained model parameters are saved.
MODEL_DICT_PATH = './resources/model/translation-latest.pth'
# Preset Chinese strings translated once at startup.
TEST_LIST = ['你好', '等一下', '不会吧', '我赢了', '开始',
             '我知道', '听着', '我退出', '冷静点', '住手']

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

corpus_dict_cache_path = f'{PREPROCESSION_CACHE_PATH_PREFIX}corpus_dict{PREPROCESSION_CACHE_PATH_SUFFIX}'
corpus_dataset_cache_path = f'{PREPROCESSION_CACHE_PATH_PREFIX}corpus_dataset{PREPROCESSION_CACHE_PATH_SUFFIX}'
# Load the cached preprocessing data.
# torch.load unpickles the file contents, so these files must come from a
# trusted source. map_location=device lets files written on a GPU machine be
# loaded on a CPU-only machine (without it torch.load raises there).
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# presumably rejects the custom objects stored in these caches — confirm the
# installed version and pass weights_only=False if loading fails.
corpus_dict = torch.load(corpus_dict_cache_path, map_location=device)
# NOTE(review): `dataset` is not referenced anywhere else in this script.
dataset = torch.load(corpus_dataset_cache_path, map_location=device)
model_factory = ModelFactory(max_sequence_length=MAX_PREDICT_SOURCE_SEQUENCE_LENGTH, pad_index=0)
# Create the model; vocabulary sizes are taken from the corpus dictionaries.
model = model_factory.create_model(head_num=HEAD_NUM, layer_num=LAYER_NUM,
                                   embedding_size=EMBEDDING_SIZE, hidden_size=HIDDEN_SIZE,
                                   source_vocab_size=len(corpus_dict.source_word2index_dict),
                                   target_vocab_size=len(corpus_dict.target_word2index_dict))
# Load the trained parameters.
state_dict = torch.load(MODEL_DICT_PATH, map_location=device)
model.load_state_dict(state_dict)
# Switch to evaluation mode for inference so layers that behave differently
# during training (such as dropout, if the model has any) are disabled.
model.eval()
# 打印翻译结果
def print_result(input_sequence, target_index, target_str):
if target_index == -1:
input_str = ''.join(input_sequence)
print(f'\n{input_str} => ', end='')
else:
print(target_str, end=' ')
# Runner that wraps the model together with the corpus dictionaries and the
# sequence-length limits configured above.
model_runner = ModelRunner(model, corpus_dict=corpus_dict, device=device,
                           min_source_sequence_length=MIN_PREDICT_SOURCE_SEQUENCE_LENGTH,
                           max_source_sequence_length=MAX_PREDICT_SOURCE_SEQUENCE_LENGTH,
                           max_result_sequence_length=MAX_PREDICT_RESULT_SEQUENCE_LENGTH)


def _translate(text):
    """Tokenize a Chinese string and run the model, printing via print_result."""
    input_sequence = cmn_eng.tokenize_cmn(text)
    # Run model inference; results are streamed to print_result.
    model_runner.run_model(input_sequence=input_sequence, output_callback=print_result)


# Translate the preset test strings first.
print('预置的测试字符串: ', end='')
for predict_str in TEST_LIST:
    _translate(predict_str)

# Interactive loop: a blank (or whitespace-only) line ends it.
while True:
    try:
        predict_str = input('\n\n输入待翻译的字符串(空行结束): ')
    except (EOFError, KeyboardInterrupt):
        # stdin was closed (Ctrl-D / end of piped input) or Ctrl-C was pressed:
        # finish the prompt line and exit cleanly instead of crashing.
        print()
        break
    predict_str = predict_str.strip()
    if not predict_str:
        break
    _translate(predict_str)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/magicodex/transformer-pytorch.git
[email protected]:magicodex/transformer-pytorch.git
magicodex
transformer-pytorch
transformer-pytorch
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385