import torch
import torch.nn as nn
import torch.nn.functional as F
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# loss function
def KL(alpha, c):
    """KL divergence between Dir(alpha) and the uniform Dirichlet Dir(1, ..., 1)."""
    beta = torch.ones((1, c)).to(device)
    S_alpha = torch.sum(alpha, dim=1, keepdim=True)
    S_beta = torch.sum(beta, dim=1, keepdim=True)
    lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
    lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)
    dg0 = torch.digamma(S_alpha)
    dg1 = torch.digamma(alpha)
    kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni
    return kl
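# KL() above evaluates the closed form of the divergence between Dir(alpha) and
# the uniform Dirichlet over c classes, with S = sum_k alpha_k and psi the
# digamma function:
#
#   KL(Dir(alpha) || Dir(1)) = ln Gamma(S) - sum_k ln Gamma(alpha_k) - ln Gamma(c)
#                              + sum_k (alpha_k - 1) * (psi(alpha_k) - psi(S))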
def ce_loss(p, alpha, c, global_step, annealing_step):
    """Adjusted cross-entropy on Dirichlet parameters with an annealed KL penalty."""
    S = torch.sum(alpha, dim=1, keepdim=True)
    E = alpha - 1
    label = F.one_hot(p, num_classes=c)
    # Expected cross-entropy under Dir(alpha).
    A = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
    # The KL weight ramps linearly from 0 to 1 over the first `annealing_step` epochs.
    annealing_coef = min(1, global_step / annealing_step)
    # Strip the evidence of the true class before regularizing.
    alp = E * (1 - label) + 1
    B = annealing_coef * KL(alp, c)
    return A + B
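# In symbols, with y the one-hot label, S = sum_k alpha_k, psi the digamma
# function, lambda_t = min(1, global_step / annealing_step), and the true-class
# evidence removed via alpha~ = y + (1 - y) * alpha, ce_loss computes
#
#   L = sum_k y_k * (psi(S) - psi(alpha_k)) + lambda_t * KL(Dir(alpha~) || Dir(1))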
class TMC(nn.Module):

    def __init__(self, classes, views, classifier_dims, lambda_epochs=1):
        """
        :param classes: Number of classification categories
        :param views: Number of views
        :param classifier_dims: Per-view layer widths of the classifiers
        :param lambda_epochs: KL divergence annealing epochs during training
        """
        super(TMC, self).__init__()
        self.views = views
        self.classes = classes
        self.lambda_epochs = lambda_epochs
        # One evidence classifier per view.
        self.Classifiers = nn.ModuleList([Classifier(classifier_dims[i], self.classes) for i in range(self.views)])
    def DS_Combin(self, alpha):
        """
        :param alpha: All Dirichlet distribution parameters.
        :return: Combined Dirichlet distribution parameters.
        """
        def DS_Combin_two(alpha1, alpha2):
            """
            :param alpha1: Dirichlet distribution parameters of view 1
            :param alpha2: Dirichlet distribution parameters of view 2
            :return: Combined Dirichlet distribution parameters
            """
            alpha = dict()
            alpha[0], alpha[1] = alpha1, alpha2
            b, S, E, u = dict(), dict(), dict(), dict()
            for v in range(2):
                S[v] = torch.sum(alpha[v], dim=1, keepdim=True)
                E[v] = alpha[v] - 1                      # evidence
                b[v] = E[v] / (S[v].expand(E[v].shape))  # belief masses
                u[v] = self.classes / S[v]               # uncertainty mass
            # outer product b^0 x b^1 of the two belief-mass vectors
            bb = torch.bmm(b[0].view(-1, self.classes, 1), b[1].view(-1, 1, self.classes))
            # b^0 * u^1
            uv1_expand = u[1].expand(b[0].shape)
            bu = torch.mul(b[0], uv1_expand)
            # b^1 * u^0
            uv_expand = u[0].expand(b[0].shape)
            ub = torch.mul(b[1], uv_expand)
            # conflict C: belief mass the two views assign to different classes
            bb_sum = torch.sum(bb, dim=(1, 2))
            bb_diag = torch.diagonal(bb, dim1=-2, dim2=-1).sum(-1)
            C = bb_sum - bb_diag
            # combined belief b^a
            b_a = (torch.mul(b[0], b[1]) + bu + ub) / ((1 - C).view(-1, 1).expand(b[0].shape))
            # combined uncertainty u^a
            u_a = torch.mul(u[0], u[1]) / ((1 - C).view(-1, 1).expand(u[0].shape))
            # new Dirichlet strength S^a
            S_a = self.classes / u_a
            # new evidence e^a and Dirichlet parameters alpha^a
            e_a = torch.mul(b_a, S_a.expand(b_a.shape))
            alpha_a = e_a + 1
            return alpha_a

        # Fold the pairwise combination over all views.
        for v in range(len(alpha) - 1):
            if v == 0:
                alpha_a = DS_Combin_two(alpha[0], alpha[1])
            else:
                alpha_a = DS_Combin_two(alpha_a, alpha[v + 1])
        return alpha_a
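    # DS_Combin_two above is the reduced Dempster's rule of combination from
    # subjective logic: with per-view belief masses b^1, b^2, uncertainty
    # masses u^1, u^2, and conflict C = sum_{i != j} b_i^1 * b_j^2,
    #
    #   b_k^a = (b_k^1 * b_k^2 + b_k^1 * u^2 + b_k^2 * u^1) / (1 - C)
    #   u^a   = (u^1 * u^2) / (1 - C)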
    def forward(self, X, y, global_step):
        evidence = self.infer(X)
        loss = 0
        alpha = dict()
        for v_num in range(len(X)):
            # Evidence is non-negative; the Dirichlet parameters are alpha = e + 1.
            alpha[v_num] = evidence[v_num] + 1
            loss += ce_loss(y, alpha[v_num], self.classes, global_step, self.lambda_epochs)
        # Fuse the per-view opinions into a single Dirichlet distribution.
        alpha_a = self.DS_Combin(alpha)
        evidence_a = alpha_a - 1
        loss += ce_loss(y, alpha_a, self.classes, global_step, self.lambda_epochs)
        loss = torch.mean(loss)
        return evidence, evidence_a, loss
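    # At inference time the combined alpha_a defines Dir(alpha_a): the expected
    # class probabilities are p_k = alpha_a[k] / S_a and the uncertainty mass is
    # u = classes / S_a, where S_a = sum_k alpha_a[k].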
    def infer(self, input):
        """
        :param input: Multi-view data
        :return: evidence of every view
        """
        evidence = dict()
        for v_num in range(self.views):
            evidence[v_num] = self.Classifiers[v_num](input[v_num])
        return evidence
class Classifier(nn.Module):

    def __init__(self, classifier_dims, classes):
        super(Classifier, self).__init__()
        self.num_layers = len(classifier_dims)
        self.fc = nn.ModuleList()
        for i in range(self.num_layers - 1):
            self.fc.append(nn.Linear(classifier_dims[i], classifier_dims[i + 1]))
        self.fc.append(nn.Linear(classifier_dims[self.num_layers - 1], classes))
        # Softplus keeps the output evidence non-negative.
        self.fc.append(nn.Softplus())

    def forward(self, x):
        h = self.fc[0](x)
        for i in range(1, len(self.fc)):
            h = self.fc[i](h)
        return h
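# A minimal smoke test, assuming two views; the feature dimensions (76, 216),
# hidden widths, batch size, and class count below are made up for illustration
# and do not come from the original code.
if __name__ == "__main__":
    num_classes, num_views, batch = 4, 2, 8
    classifier_dims = [[76, 64], [216, 128]]  # hypothetical per-view layer widths
    model = TMC(num_classes, num_views, classifier_dims).to(device)
    X = [torch.randn(batch, 76, device=device),
         torch.randn(batch, 216, device=device)]
    y = torch.randint(0, num_classes, (batch,), device=device)
    evidence, evidence_a, loss = model(X, y, global_step=1)
    print(evidence_a.shape, loss.item())  # torch.Size([8, 4]) and a scalar loss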