1 Star 0 Fork 1

雲图/tacotron2

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
loss_function.py 673 Bytes
一键复制 编辑 原始数据 按行查看 历史
Rafael Valle 提交于 2018-05-03 15:16 . adding python files
from torch import nn
class Tacotron2Loss(nn.Module):
def __init__(self):
super(Tacotron2Loss, self).__init__()
def forward(self, model_output, targets):
mel_target, gate_target = targets[0], targets[1]
mel_target.requires_grad = False
gate_target.requires_grad = False
gate_target = gate_target.view(-1, 1)
mel_out, mel_out_postnet, gate_out, _ = model_output
gate_out = gate_out.view(-1, 1)
mel_loss = nn.MSELoss()(mel_out, mel_target) + \
nn.MSELoss()(mel_out_postnet, mel_target)
gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
return mel_loss + gate_loss
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ytzy/tacotron2.git
[email protected]:ytzy/tacotron2.git
ytzy
tacotron2
tacotron2
master

搜索帮助