1 Star 0 Fork 1

yanlang0123/Pytorch-UNet

forked from xijunjun/Pytorch-UNet 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
myloss.py 1.26 KB
一键复制 编辑 原始数据 按行查看 历史
#
# myloss.py : implementation of the Dice coeff and the associated loss
#
import torch
from torch.autograd import Function, Variable
class DiceCoeff(Function):
    """Dice coefficient for a single (input, target) pair.

    The forward pass computes

        2 * (input . target + eps) / (sum(input) + sum(target) + eps)

    with ``eps = 0.0001``; the backward pass supplies a hand-written
    gradient with respect to ``input`` only.
    """

    def forward(self, input, target):
        """Return the smoothed Dice coefficient of ``input`` and ``target``.

        Both tensors are flattened before the dot product, so any shape is
        accepted as long as the two hold the same number of elements.
        The intermediate intersection and union terms are stored on the
        instance as ``self.inter`` and ``self.union`` for use in ``backward``.
        """
        self.save_for_backward(input, target)
        # torch.dot only accepts 1-D tensors, so flatten both operands first.
        flat_input = input.contiguous().view(-1)
        flat_target = target.contiguous().view(-1)
        # The 0.0001 terms keep the ratio defined when both sums are zero.
        self.inter = torch.dot(flat_input, flat_target) + 0.0001
        self.union = torch.sum(input) + torch.sum(target) + 0.0001
        return 2 * self.inter.float() / self.union.float()

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Return the gradients for ``(input, target)``.

        With ``I = self.inter`` and ``U = self.union`` the derivative of
        ``2 * I / U`` with respect to ``input`` is
        ``2 * (target * U - I) / U ** 2``. No gradient is computed for
        ``target``; its slot is always ``None``.
        """
        _, target = self.saved_variables
        grad_input = grad_target = None
        if self.needs_input_grad[0]:
            # Quotient rule: the denominator is the *square* of the union,
            # hence the explicit parentheses around the product.
            grad_input = grad_output * 2 \
                * (target * self.union - self.inter) \
                / (self.union * self.union)
        return grad_input, grad_target
def dice_coeff(input, target):
    """Return the mean Dice coefficient over a batch.

    ``input`` and ``target`` are iterated in lockstep along their first
    dimension; each pair is scored with ``DiceCoeff`` and the scores are
    averaged.

    Returns:
        A one-element float tensor holding the mean score, moved to the GPU
        when ``input`` is a CUDA tensor.

    Raises:
        ValueError: if the batch contains no examples.
    """
    # Accumulator lives on the GPU whenever the input does.
    total = torch.FloatTensor(1).zero_()
    if input.is_cuda:
        total = total.cuda()
    total = Variable(total)

    count = 0
    for sample, mask in zip(input, target):
        total = total + DiceCoeff().forward(sample, mask)
        count += 1

    if count == 0:
        # An empty batch has no mean; fail with a clear message instead of
        # tripping over a loop variable that was never bound.
        raise ValueError("dice_coeff() requires a non-empty batch")
    return total / count
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yanlang0123/Pytorch-UNet.git
[email protected]:yanlang0123/Pytorch-UNet.git
yanlang0123
Pytorch-UNet
Pytorch-UNet
master

搜索帮助