3 Star 4 Fork 0

Gitee 极速下载/Syn2Real

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/rajeevyasarla/Syn2Real
克隆/下载
perceptual.py 1012 Bytes
一键复制 编辑 原始数据 按行查看 历史
rajeevyasarla 提交于 2020-06-09 19:02 . syn2real files
# --- Imports --- #
import torch
import torch.nn.functional as F
# --- Perceptual loss network --- #
class LossNetwork(torch.nn.Module):
def __init__(self, vgg_model):
super(LossNetwork, self).__init__()
self.vgg_layers = vgg_model
self.layer_name_mapping = {
'3': "relu1_2",
'8': "relu2_2",
'15': "relu3_3"
}
def output_features(self, x):
output = {}
for name, module in self.vgg_layers._modules.items():
x = module(x)
if name in self.layer_name_mapping:
output[self.layer_name_mapping[name]] = x
return list(output.values())
def forward(self, pred_im, gt):
loss = []
pred_im_features = self.output_features(pred_im)
gt_features = self.output_features(gt)
for pred_im_feature, gt_feature in zip(pred_im_features, gt_features):
loss.append(F.mse_loss(pred_im_feature, gt_feature))
return sum(loss)/len(loss)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mirrors/Syn2Real.git
[email protected]:mirrors/Syn2Real.git
mirrors
Syn2Real
Syn2Real
master

搜索帮助