import torch
import torch.nn as nn
import data
from models.conv import GatedConv
from tqdm import tqdm
from decoder import GreedyDecoder
from warpctc_pytorch import CTCLoss
import tensorboardX as tensorboard
import torch.nn.functional as F
import json
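# Training script for the MASR GatedConv acoustic model on Common Voice Chinese
# data with CTC loss. Note (my reading of the warp-ctc PyTorch binding, not
# something stated in this file): warpctc_pytorch.CTCLoss takes activations
# shaped (seq_len, batch, num_classes) together with flattened integer label
# sequences and per-utterance lengths, which is why the training loop below
# transposes the model output before computing the loss.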
def train(
    model,
    epochs=1000,
    batch_size=8,
    train_index_path="data_aishell/train-sort.manifest",
    dev_index_path="data_aishell/dev.manifest",
    labels_path="data_aishell/labels.json",
    learning_rate=0.3,
    momentum=0.5,
    max_grad_norm=0.1,
    weight_decay=0,
):
    # Common Voice Chinese clips plus TSV transcripts; the absolute paths below
    # are specific to the original author's machine. The dev set reuses the same
    # clip directories with the test.tsv transcripts.
    train_path = "/home/chenyang/PycharmProjects/common_voice/clips1"
    train_path1 = "/home/chenyang/PycharmProjects/common_voice/clips2"
    label_path = "/home/chenyang/PycharmProjects/common_voice/common_voice_chinese_data/train.tsv"
    train_dataset = data.MASRDataset([train_path, train_path1], label_path)
    # train_dataset = data.MASRDataset(train_index_path, labels_path)
    batchs = (len(train_dataset) + batch_size - 1) // batch_size
    label_path = "/home/chenyang/PycharmProjects/common_voice/common_voice_chinese_data/test.tsv"
    dev_dataset = data.MASRDataset([train_path, train_path1], label_path)
    # dev_dataset = data.MASRDataset(dev_index_path, labels_path)
    train_dataloader = data.MASRDataLoader(
        train_dataset, batch_size=batch_size, num_workers=8
    )
    train_dataloader_shuffle = data.MASRDataLoader(
        train_dataset, batch_size=batch_size, num_workers=8, shuffle=True
    )
    dev_dataloader = data.MASRDataLoader(
        dev_dataset, batch_size=batch_size, num_workers=8
    )
    parameters = model.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=learning_rate,
        momentum=momentum,
        nesterov=True,
        weight_decay=weight_decay,
    )
    # size_average=True averages the CTC loss over the minibatch.
    ctcloss = CTCLoss(size_average=True)
    # lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.985)
    writer = tensorboard.SummaryWriter()
    gstep = 0
    for epoch in range(epochs):
        epoch_loss = 0
        # Keep the first epoch in loader order, then switch to the shuffled loader.
        if epoch > 0:
            train_dataloader = train_dataloader_shuffle
        # lr_sched.step()
        lr = get_lr(optimizer)
        writer.add_scalar("lr/epoch", lr, epoch)
        for i, (x, y, x_lens, y_lens) in enumerate(train_dataloader):
            x = x.to("cuda")
            out, out_lens = model(x, x_lens)
            # Rearrange activations to (time, batch, num_classes) as expected
            # by warp-ctc's CTCLoss.
            out = out.transpose(0, 1).transpose(0, 2)
            loss = ctcloss(out, y, out_lens, y_lens)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item()
            writer.add_scalar("loss/step", loss.item(), gstep)
            gstep += 1
            print(
                "[{}/{}][{}/{}]\tLoss = {}".format(
                    epoch + 1, epochs, i, int(batchs), loss.item()
                )
            )
        epoch_loss = epoch_loss / batchs
        cer = eval(model, dev_dataloader)
        writer.add_scalar("loss/epoch", epoch_loss, epoch)
        writer.add_scalar("cer/epoch", cer, epoch)
        print("Epoch {}: Loss = {}, CER = {}".format(epoch, epoch_loss, cer))
        torch.save(model, "pretrained/model_{}.pth".format(epoch))
def get_lr(optimizer):
    # All parameter groups share one learning rate here; return the first.
    for param_group in optimizer.param_groups:
        return param_group["lr"]
def eval(model, dataloader):
    model.eval()
    decoder = GreedyDecoder(dataloader.dataset.word_to_ix)
    cer = 0
    print("decoding")
    with torch.no_grad():
        for i, (x, y, x_lens, y_lens) in tqdm(enumerate(dataloader)):
            x = x.to("cuda")
            outs, out_lens = model(x, x_lens)
            outs = F.softmax(outs, 1)
            outs = outs.transpose(1, 2)
            # y is a flat concatenation of label sequences; split it back into
            # one reference per utterance using y_lens.
            ys = []
            offset = 0
            for y_len in y_lens:
                ys.append(y[offset: offset + y_len])
                offset += y_len
            out_strings, out_offsets = decoder.decode(outs, out_lens)
            y_strings = decoder.convert_to_strings(ys)
            for pred, truth in zip(out_strings, y_strings):
                trans, ref = pred[0], truth[0]
                # Per-utterance character error rate: edit distance / reference length.
                cer += decoder.cer(trans, ref) / float(len(ref))
    cer /= len(dataloader.dataset)
    model.train()
    return cer
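# Hypothetical standalone evaluation (a sketch, not part of the original script):
# load a checkpoint written by train() and score it with a dev loader built the
# same way as dev_dataloader inside train(). The checkpoint name is an assumption.
#
#   model = torch.load("pretrained/model_0.pth").to("cuda")
#   print("CER = {}".format(eval(model, dev_dataloader)))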
if __name__ == "__main__":
    with open("/home/chenyang/PycharmProjects/masr/data_aishell/vocabs") as f:
        vocabulary = f.read()
    vocabulary = "".join(vocabulary.split("\n"))
    model = GatedConv(vocabulary)
    model.to_train()
    model.to("cuda")
    train(model)
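    # Hypothetical follow-up run (a sketch, not part of the original script):
    # resume from a checkpoint saved by train() and continue with a lower
    # learning rate. The checkpoint path is an assumption.
    # model = torch.load("pretrained/model_10.pth").to("cuda")
    # train(model, epochs=100, learning_rate=0.1)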