1 Star 0 Fork 1

元原子/gpt2

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
多显卡数据加载器.py 1.59 KB
一键复制 编辑 原始数据 按行查看 历史
元原子 提交于 2024-11-22 15:49 . 添加了多显卡训练;
# pip install tiktoken 它提供了一种简单的方式来计数和分割文本为tokens。
import tiktoken
import torch
class 多显卡轻量数据加载器:
def __init__(self, , , 进程班号, 进程数量):
"""
这是一个自定义的数据加载器
:param 批:
:param 序:
"""
self. =
self. =
self.进程数量 = 进程数量
self.进程班号 = 进程班号
# 在初始化时,将字词从磁盘中载入到内存
with open("英文训练集.txt", 'r', encoding="utf8") as 文件:
文本 = 文件.read()
编码器 = tiktoken.get_encoding("gpt2")
字词 = 编码器.encode(文本)
self.字词 = torch.tensor(字词)
print(f"载入了 {len(self.字词)} 字词。")
print(f"每轮有 {len(self.字词) // ( * )} 批数据。")
self.当前位置 = self. * self. * self.进程班号
def 下一批(self):
"""
:return:
"""
= self.
= self.
缓存 = self.字词[self.当前位置:self.当前位置 + * + 1]
# 输入
x = 缓存[:-1].view(, )
# 目标
y = 缓存[1:].view(, )
# 在字词张量中前进的位置
self.当前位置 += * * self.进程数量
# 如果加载下一个批次超出范围,则重置
if self.当前位置 + ( * * self.进程数量 + 1) > len(self.字词):
self.当前位置 = self. * self. * self.进程班号
return x, y
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zozero/gpt2.git
[email protected]:zozero/gpt2.git
zozero
gpt2
gpt2
master

搜索帮助