代码拉取完成,页面将自动刷新
# pip install tiktoken 它提供了一种简单的方式来计数和分割文本为tokens。
import tiktoken
import torch
class 多显卡轻量数据加载器:
def __init__(self, 批, 序, 进程班号, 进程数量):
"""
这是一个自定义的数据加载器
:param 批:
:param 序:
"""
self.批 = 批
self.序 = 序
self.进程数量 = 进程数量
self.进程班号 = 进程班号
# 在初始化时,将字词从磁盘中载入到内存
with open("英文训练集.txt", 'r', encoding="utf8") as 文件:
文本 = 文件.read()
编码器 = tiktoken.get_encoding("gpt2")
字词 = 编码器.encode(文本)
self.字词 = torch.tensor(字词)
print(f"载入了 {len(self.字词)} 字词。")
print(f"每轮有 {len(self.字词) // (批 * 序)} 批数据。")
self.当前位置 = self.批 * self.序 * self.进程班号
def 下一批(self):
"""
:return:
"""
批 = self.批
序 = self.序
缓存 = self.字词[self.当前位置:self.当前位置 + 批 * 序 + 1]
# 输入
x = 缓存[:-1].view(批, 序)
# 目标
y = 缓存[1:].view(批, 序)
# 在字词张量中前进的位置
self.当前位置 += 批 * 序 * self.进程数量
# 如果加载下一个批次超出范围,则重置
if self.当前位置 + (批 * 序 * self.进程数量 + 1) > len(self.字词):
self.当前位置 = self.批 * self.序 * self.进程班号
return x, y
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。