代码拉取完成,页面将自动刷新
# --- Imports --- #
import torch
import torch.nn.functional as F
# --- Perceptual loss network --- #
class LossNetwork(torch.nn.Module):
def __init__(self, vgg_model):
super(LossNetwork, self).__init__()
self.vgg_layers = vgg_model
self.layer_name_mapping = {
'3': "relu1_2",
'8': "relu2_2",
'15': "relu3_3"
}
def output_features(self, x):
output = {}
for name, module in self.vgg_layers._modules.items():
x = module(x)
if name in self.layer_name_mapping:
output[self.layer_name_mapping[name]] = x
return list(output.values())
def forward(self, pred_im, gt):
loss = []
pred_im_features = self.output_features(pred_im)
gt_features = self.output_features(gt)
for pred_im_feature, gt_feature in zip(pred_im_features, gt_features):
loss.append(F.mse_loss(pred_im_feature, gt_feature))
return sum(loss)/len(loss)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。