import os
import sys
import argparse
import time
import shutil
import tempfile
from datetime import datetime
from collections import OrderedDict
from collections.abc import Callable, Sequence
from typing import Optional

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist  # needed by reduce_value below
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torch.autograd import Variable
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm
from PIL import Image
from skimage import io
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from tensorboardX import SummaryWriter

from monai.utils import pytorch_after, look_up_option
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric  # needed by dice_metric below
from monai.inferers import sliding_window_inference
from monai.transforms import AsDiscrete

import pytorch_ssim
import cfg
from conf import settings
from utils import *
import models.sam.utils.transforms as samtrans
# from dataset import *
# from models.discriminatorlayer import discriminator
# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1
class MultBCEWithLogitsLoss(_Loss):
    """Composite BCE-with-logits loss over artery/vein heads, with extra union
    (whole-vessel) and artery-vein intersection consistency terms."""

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean',
                 pos_weight: Optional[Tensor] = None) -> None:
        super().__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)
        self.weight: Optional[Tensor]
        self.pos_weight: Optional[Tensor]
    def forward(self, input2: Tensor, input3: Tensor, target2: Tensor, target3: Tensor) -> Tensor:
        # Per-class terms: input2/target2 are artery logits/labels,
        # input3/target3 are vein logits/labels.
        loss_art = F.binary_cross_entropy_with_logits(input2, target2,
                                                      self.weight,
                                                      pos_weight=self.pos_weight,
                                                      reduction=self.reduction)
        loss_ven = F.binary_cross_entropy_with_logits(input3, target3,
                                                      self.weight,
                                                      pos_weight=self.pos_weight,
                                                      reduction=self.reduction)
        # Whole-vessel term: the element-wise max acts as a soft union, so the
        # union of the two predictions should match the union of the two labels.
        vessel_gt = torch.max(target2, target3)
        vessel_pred = torch.max(input2, input3)
        loss_vess = F.binary_cross_entropy_with_logits(vessel_pred, vessel_gt,
                                                       self.weight,
                                                       pos_weight=self.pos_weight,
                                                       reduction=self.reduction)
        # Cross-intersection terms: the element-wise min acts as a soft
        # intersection; each prediction's overlap with the other class's label
        # is pushed toward the ground-truth artery-vein overlap.
        inter_gt = torch.min(target2, target3)
        inter_art = torch.min(input2, target3)
        inter_ven = torch.min(target2, input3)
        loss_inter_art = F.binary_cross_entropy_with_logits(inter_art, inter_gt,
                                                            self.weight,
                                                            pos_weight=self.pos_weight,
                                                            reduction=self.reduction)
        loss_inter_ven = F.binary_cross_entropy_with_logits(inter_ven, inter_gt,
                                                            self.weight,
                                                            pos_weight=self.pos_weight,
                                                            reduction=self.reduction)
        loss = loss_art + loss_ven + 0.5 * loss_inter_art + 0.5 * loss_inter_ven + loss_vess
        return loss
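
# Minimal usage sketch for MultBCEWithLogitsLoss (shapes are illustrative, not
# taken from the repo): artery/vein logits and binary masks of shape (B, 1, H, W);
# the call returns a scalar combining the five terms above.
#   crit = MultBCEWithLogitsLoss()
#   logits_art = torch.randn(2, 1, 64, 64)
#   logits_ven = torch.randn(2, 1, 64, 64)
#   gt_art = torch.randint(0, 2, (2, 1, 64, 64)).float()
#   gt_ven = torch.randint(0, 2, (2, 1, 64, 64)).float()
#   loss = crit(logits_art, logits_ven, gt_art, gt_ven)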
# NOTE: `args`, `device`, `world_size`, and `rank` are assumed to be provided by
# the repo's entry script (cfg parsing and distributed setup) before this
# module-level code runs.
pos_weight = torch.ones([1]).to(device) * 2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
mult_criterion = MultBCEWithLogitsLoss(pos_weight=pos_weight)
seed = torch.randint(1, 11, (args.b, 7))
torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
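
# Sketch of how the MONAI pieces above are typically wired together during a
# sliding-window validation pass (tensor names are hypothetical; decollate_batch
# is from monai.data):
#   val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, net)
#   labels = [post_label(t) for t in decollate_batch(val_labels)]
#   preds = [post_pred(t) for t in decollate_batch(val_outputs)]
#   dice_metric(y_pred=preds, y=labels)
#   mean_dice = dice_metric.aggregate().item()
#   dice_metric.reset()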
def train_sam(args, net: nn.Module, optimizer, train_loader,
              epoch, writer, schedulers=None, vis=2, device=None):
    hard = 0
    epoch_loss = 0
    ind = 0
    vis = vis * 10
    # train mode
    net.train()
    optimizer.zero_grad()
    ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    n_tra = len(train_loader)
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')  # combines Dice loss (a segmentation overlap measure) with cross-entropy
    else:
        lossfunc = criterion_G  # binary cross-entropy with logits (BCEWithLogitsLoss)
with tqdm(total=n_tra, desc=f'Epoch {epoch}', unit='img') as pbar:
for pack in train_loader:
imgs = pack['image'].to(dtype = torch.float32, device = device)
masks = pack['label'].to(dtype = torch.float32, device = device)
if 'pt' not in pack:
imgs, pt, masks = generate_click_prompt(imgs, masks)
else:
pt = pack['pt']
point_labels = pack['p_label']
name = pack['image_meta_dict']['filename_or_obj']
if args.thd:
print('Creating-----------------')
# pt = rearrange(pt, 'b n d -> (b d) n')
imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
# print(imgs.shape)
# print(pt.shape)
imgs = imgs.repeat(1, 3, 1, 1)
# point_labels = torch.ones(imgs.size(0))
imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)
showp = pt
mask_type = torch.float32
ind += 1
b_size, c, w, h = imgs.size()
# c, w, h = imgs.size()
longsize = w if w >= h else h
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
            '''preprocess imgs and true_mask_ave'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgs = imgs.to(dtype=mask_type, device=device)
'''Train'''
# for n, value in net.image_encoder.named_parameters():
# if "Adapter" not in n:
# value.requires_grad = False
imge = net.module.image_encoder(imgs)
with torch.no_grad():
se, de = net.module.prompt_encoder(
points=pt,
boxes=None,
masks=None,
)
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=net.module.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
loss = lossfunc(pred, masks)
pbar.set_postfix(**{'loss (batch)': loss.item()})
loss.backward()
            # average the loss across all GPUs
loss = reduce_value(loss, world_size, average=True)
epoch_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
            if rank == 0:
                # visualize every `vis` batches on the main process
                if ind % vis == 0:
namecat = 'Train'
for na in name:
namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
# vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + args.dataset + 'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
            # compute IoU and Dice, then sum across GPUs and average
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
iou, dice = metrics_tensor.tolist()
pbar.update()
# print('mix_res', mix_res)
# print('iou and dice', tuple([a/n_tra for a in mix_res]))
return epoch_loss/n_tra, tuple([iou/n_tra, dice/n_tra])
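
# Typical call from the entry script, matching the signatures above and below
# (loader/optimizer/writer names are hypothetical):
#   loss, (iou, dice) = train_sam(args, net, optimizer, train_loader,
#                                 epoch, writer, device=device)
#   val_loss, (val_iou, val_dice) = validation_sam(args, val_loader, epoch, net)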
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False):
# eval mode
net.eval()
mask_type = torch.float32
n_val = len(val_loader) # the number of batch
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
rater_res = [(0, 0, 0, 0) for _ in range(6)]
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
if args.thd:
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
else:
lossfunc = criterion_G
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype=torch.float32, device=device)
masksw = pack['label'].to(dtype=torch.float32, device=device)
if 'pt' not in pack:
imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
else:
ptw = pack['pt']
point_labels = pack['p_label']
name = pack['image_meta_dict']['filename_or_obj']
buoy = 0
if args.evl_chunk:
evl_ch = int(args.evl_chunk)
else:
evl_ch = int(imgsw.size(-1))
while (buoy + evl_ch) <= imgsw.size(-1):
if args.thd:
pt = ptw[:, :, buoy: buoy + evl_ch]
else:
pt = ptw
imgs = imgsw[..., buoy:buoy + evl_ch]
masks = masksw[..., buoy:buoy + evl_ch]
buoy += evl_ch
if args.thd:
pt = rearrange(pt, 'b n d -> (b d) n')
imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
imgs = imgs.repeat(1, 3, 1, 1)
point_labels = torch.ones(imgs.size(0))
imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)
showp = pt
mask_type = torch.float32
ind += 1
b_size, c, w, h = imgs.size()
# c, w, h = imgs.size()
longsize = w if w >= h else h
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
'''init'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
# true_mask_ave = cons_tensor(true_mask_ave)
# imgs = imgs.to(dtype = mask_type,device = GPUdevice)
imgs = imgs.to(dtype=mask_type, device=device)
                '''test'''
                with torch.no_grad():
                    # unwrap the DDP wrapper once instead of branching at every call
                    model = net.module if args.distributed != 'none' else net
                    imge = model.image_encoder(imgs)
                    se, de = model.prompt_encoder(
                        points=pt,
                        boxes=None,
                        masks=None,
                    )
                    pred, _ = model.mask_decoder(
                        image_embeddings=imge,
                        image_pe=model.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=se,
                        dense_prompt_embeddings=de,
                        multimask_output=False,
                    )
if args.distributed != 'none':
tot += reduce_value(lossfunc(pred, masks), world_size, average=True)
else:
tot += lossfunc(pred, masks)
'''vis images'''
# if rank == 0 and ind % args.vis == 0:
if ind % args.vis == 0:
namecat = 'Test'
for na in name:
img_name = na.split('/')[-1].split('.')[0]
namecat = namecat + img_name + '+'
# vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + args.dataset + 'epoch+' + str(epoch) + '.jpg'), reverse=False)
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
iou, dice = metrics_tensor.tolist()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
if args.evl_chunk:
n_val = n_val * (imgsw.size(-1) // evl_ch)
return tot / n_val, tuple([iou / n_val, dice / n_val]),
def train_sam_lite(args, net: nn.Module, optimizer, train_loader,
epoch, writer, schedulers=None, vis=2, device=None, fake_prompt=None):
hard = 0
ind = 0
vis = vis * 10
# train mode
net.train()
optimizer.zero_grad()
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
epoch_loss = 0
n_tra = len(train_loader)
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True,
                              reduction='mean')  # combines Dice loss with cross-entropy
    else:
        lossfunc = criterion_G  # binary cross-entropy with logits (BCEWithLogitsLoss)
with tqdm(total=n_tra, desc=f'Epoch {epoch}', unit='img') as pbar:
for pack in train_loader:
imgs = pack['image'].to(dtype=torch.float32, device=device)
masks = pack['label'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
mask_type = torch.float32
ind += 1
            '''preprocess imgs and true_mask_ave'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgs = imgs.to(dtype=mask_type, device=device)
'''Train'''
# for n, value in net.image_encoder.named_parameters():
# if "Adapter" not in n:
# value.requires_grad = False
imge = net.module.image_encoder(imgs)
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=False,
)
loss = lossfunc(pred, masks)
pbar.set_postfix(**{'loss (batch)': loss.item()})
loss.backward()
            # average the loss across all GPUs
loss = reduce_value(loss, world_size, average=True)
epoch_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
            if rank == 0:
                # visualize every `vis` batches on the main process
                if ind % vis == 0:
namecat = 'Train'
for na in name:
namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
# vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'],
namecat + args.dataset + 'epoch+' + str(
epoch) + '.jpg'), reverse=False)
            # compute IoU and Dice, then sum across GPUs and average
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
iou, dice = metrics_tensor.tolist()
pbar.update()
# print('mix_res', mix_res)
# print('iou and dice', tuple([a/n_tra for a in mix_res]))
return epoch_loss / n_tra, tuple([iou / n_tra, dice / n_tra])
def validation_sam_lite(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False, fake_prompt=None):
# eval mode
net.eval()
n_val = len(val_loader) # the number of batch
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
if args.thd:
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
else:
lossfunc = criterion_G
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype=torch.float32, device=device)
masksw = pack['label'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
buoy = 0
if args.evl_chunk:
evl_ch = int(args.evl_chunk)
else:
evl_ch = int(imgsw.size(-1))
while (buoy + evl_ch) <= imgsw.size(-1):
imgs = imgsw[..., buoy:buoy + evl_ch]
masks = masksw[..., buoy:buoy + evl_ch]
buoy += evl_ch
mask_type = torch.float32
ind += 1
'''init'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgs = imgs.to(dtype=mask_type, device=device)
'''test'''
                with torch.no_grad():
                    # unwrap the DDP wrapper once instead of branching at every call
                    model = net.module if args.distributed != 'none' else net
                    imge = model.image_encoder(imgs)
                    pred, _ = model.mask_decoder(
                        image_embeddings=imge,
                        image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
                        sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
                        dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
                        multimask_output=False,
                    )
if args.distributed != 'none':
tot += reduce_value(lossfunc(pred, masks), world_size, average=True)
else:
tot += lossfunc(pred, masks)
'''vis images'''
# if rank == 0 and ind % args.vis == 0:
if ind % args.vis == 0:
namecat = 'Test'
for na in name:
img_name = na.split('/')[-1].split('.')[0]
namecat = namecat + img_name + '+'
# vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + args.dataset + 'epoch+' + str(epoch) + '.jpg'), reverse=False)
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
iou, dice = metrics_tensor.tolist()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
if args.evl_chunk:
n_val = n_val * (imgsw.size(-1) // evl_ch)
return tot / n_val, tuple([iou / n_val, dice / n_val]),
def seg_sam_lite(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False, fake_prompt=None):
# eval mode
net.eval()
n_val = len(val_loader) # the number of batch
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
buoy = 0
if args.evl_chunk:
evl_ch = int(args.evl_chunk)
else:
evl_ch = int(imgsw.size(-1))
while (buoy + evl_ch) <= imgsw.size(-1):
imgs = imgsw[..., buoy:buoy + evl_ch]
buoy += evl_ch
mask_type = torch.float32
ind += 1
imgs = imgs.to(dtype=mask_type, device=device)
'''test'''
                with torch.no_grad():
                    # unwrap the DDP wrapper once instead of branching at every call
                    model = net.module if args.distributed != 'none' else net
                    imge = model.image_encoder(imgs)
                    pred, _ = model.mask_decoder(
                        image_embeddings=imge,
                        image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
                        sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
                        dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
                        multimask_output=False,
                    )
                '''vis images'''
                # per-batch visualization intentionally disabled here; see
                # validation_sam_lite for the live pattern
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
    return 'ok',
def seg_mult_sam_lite(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False, fake_prompt=None):
# eval mode
net.eval()
n_val = len(val_loader) # the number of batch
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgs = pack['image'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
mask_type = torch.float32
ind += 1
imgs = imgs.to(dtype=mask_type, device=device)
'''test'''
            with torch.no_grad():
                # unwrap the DDP wrapper once instead of branching at every call
                model = net.module if args.distributed != 'none' else net
                imge = model.image_encoder(imgs)
                pred, _ = model.mask_decoder(
                    image_embeddings=imge,
                    image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
                    sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
                    dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
                    multimask_output=True,
                )
            '''vis images'''
            # per-batch visualization intentionally disabled here; see
            # validation_sam_lite for the live pattern
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
    return 'ok',
def train_mult_sam_lite(args, net: nn.Module, optimizer, train_loader,
epoch, writer, schedulers=None, vis=2, device=None, fake_prompt=None):
hard = 0
ind = 0
# vis = vis * 10
# train mode
net.train()
optimizer.zero_grad()
# ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
# threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
epoch_loss = 0
n_tra = len(train_loader)
    lossfunc = mult_criterion  # composite BCE over the artery/vein heads (MultBCEWithLogitsLoss)
with tqdm(total=n_tra, desc=f'Epoch {epoch}', unit='img') as pbar:
for pack in train_loader:
imgs = pack['image'].to(dtype=torch.float32, device=device)
mask_vessel = pack['mask_vessel'].to(dtype=torch.float32, device=device)
mask_arteriole = pack['mask_arteriole'].to(dtype=torch.float32, device=device)
mask_venule = pack['mask_venule'].to(dtype=torch.float32, device=device)
if args.net == 'sam_self_with_prompt':
pt = pack['pt']
point_labels = pack['p_label']
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
# pt = (coords_torch, labels_torch)
# print('pt', pt)
name = pack['image_meta_dict']['filename_or_obj']
mask_type = torch.float32
ind += 1
            '''preprocess imgs and true_mask_ave'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgs = imgs.to(dtype=mask_type, device=device)
'''Train'''
# for n, value in net.image_encoder.named_parameters():
# if "Adapter" not in n:
# value.requires_grad = False
            # unwrap the DDP wrapper once instead of branching at every call
            model = net.module if args.distributed != 'none' else net
            if args.net == 'sam_self':
                imgs = model.image_fusion(imgs)
            imge = model.image_encoder(imgs)
            if args.net == 'sam_self_with_prompt':
                with torch.no_grad():
                    se, de = model.prompt_encoder(
                        # points=pt,
                        points=None,
                        boxes=None,
                        masks=None,
                    )
                pred, _ = model.mask_decoder(
                    image_embeddings=imge,
                    image_pe=model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=True,
                )
            else:
                pred, _ = model.mask_decoder(
                    image_embeddings=imge,
                    image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
                    sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
                    dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
                    multimask_output=True,
                )
            # decoder output channels: 0 = vessel, 1 = arteriole, 2 = venule
            # loss = lossfunc(pred[:, 0:1, :, :], pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_vessel, mask_arteriole, mask_venule)
            loss = lossfunc(pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_arteriole,
                            mask_venule)
pbar.set_postfix(**{'loss (batch)': loss.item()})
loss.backward()
            # average the loss across all GPUs
if args.distributed != 'none':
loss = reduce_value(loss, world_size, average=True)
epoch_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
            # Per-batch visualization and per-class IoU/Dice accumulation are
            # disabled during training; validation_mult_sam_lite carries the
            # live versions of both.
pbar.update()
# print('mix_res', mix_res)
# print('iou and dice', tuple([a/n_tra for a in mix_res]))
return epoch_loss / n_tra
def validation_mult_sam_lite(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False, fake_prompt=None):
# eval mode
net.eval()
n_val = len(val_loader) # the number of batch
mix_resvessel = (0, 0, 0, 0)
mix_resvenule = (0, 0, 0, 0)
mix_resarteriole = (0, 0, 0, 0)
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
lossfunc = mult_criterion
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype=torch.float32, device=device)
mask_vessel = pack['mask_vessel'].to(dtype=torch.float32, device=device)
mask_arteriole = pack['mask_arteriole'].to(dtype=torch.float32, device=device)
mask_venule = pack['mask_venule'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
if args.net == 'sam_self_with_prompt':
pt = pack['pt']
point_labels = pack['p_label']
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
buoy = 0
mask_type = torch.float32
ind += 1
'''init'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgsw = imgsw.to(dtype=mask_type, device=device)
'''test'''
            with torch.no_grad():
                # unwrap the DDP wrapper once instead of branching at every call
                model = net.module if args.distributed != 'none' else net
                if args.net == 'sam_self' or args.net == 'sam_self_with_prompt':
                    imgsw = model.image_fusion(imgsw)
                imge = model.image_encoder(imgsw)
                if args.net == 'sam_self_with_prompt':
                    se, de = model.prompt_encoder(
                        # points=pt,
                        points=None,
                        boxes=None,
                        masks=None,
                    )
                    pred, _ = model.mask_decoder(
                        image_embeddings=imge,
                        image_pe=model.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=se,
                        dense_prompt_embeddings=de,
                        multimask_output=True,
                    )
                else:
                    pred, _ = model.mask_decoder(
                        image_embeddings=imge,
                        image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
                        sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
                        dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
                        multimask_output=True,
                    )
loss = lossfunc(pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_arteriole,
mask_venule)
if args.distributed != 'none':
# tot += reduce_value(lossfunc(pred[:, 0:1, :, :], pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_vessel, mask_arteriole, mask_venule), world_size, average=True)
tot += reduce_value(loss, world_size, average=True)
else:
# tot += lossfunc(pred[:, 0:1, :, :], pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_vessel, mask_arteriole, mask_venule)
tot += loss
            '''vis images'''
            # per-image visualization intentionally disabled here
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
tempvessel = eval_seg(pred[:, 0:1, :, :], mask_vessel, threshold)
mix_resvessel = tuple([sum(a) for a in zip(mix_resvessel, tempvessel)])
metrics_tensorvessel = torch.tensor(mix_resvessel, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensorvessel = reduce_value(metrics_tensorvessel, world_size, average=True)
iou_vessel, dice_vessel = metrics_tensorvessel.tolist()
temparteriole = eval_seg(pred[:, 1:2, :, :], mask_arteriole, threshold)
mix_resarteriole = tuple([sum(a) for a in zip(mix_resarteriole, temparteriole)])
metrics_tensorarteriole = torch.tensor(mix_resarteriole, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensorarteriole = reduce_value(metrics_tensorarteriole, world_size, average=True)
iou_arteriole, dice_arteriole = metrics_tensorarteriole.tolist()
tempvenule = eval_seg(pred[:, 2:3, :, :], mask_venule, threshold)
mix_resvenule = tuple([sum(a) for a in zip(mix_resvenule, tempvenule)])
metrics_tensorvenule = torch.tensor(mix_resvenule, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensorvenule = reduce_value(metrics_tensorvenule, world_size, average=True)
iou_venule, dice_venule = metrics_tensorvenule.tolist()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
return tot / n_val, tuple([iou_vessel / n_val, dice_vessel / n_val]), tuple([iou_arteriole / n_val, dice_arteriole / n_val]), tuple([iou_venule / n_val, dice_venule / n_val])
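
# Sketch of consuming the per-class metrics returned above (fp stands for the
# precomputed fake_prompt dict of numpy arrays used throughout the *_lite paths):
#   val_loss, (iou_v, dice_v), (iou_a, dice_a), (iou_ve, dice_ve) = \
#       validation_mult_sam_lite(args, val_loader, epoch, net, fake_prompt=fp)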
def seg_sam(args, net: nn.Module, train_loader, device=None):
hard = 0
epoch_loss = 0
net.eval()
n_tra = len(train_loader)
    lossfunc = criterion_G  # binary cross-entropy with logits (BCEWithLogitsLoss)
with tqdm(total=n_tra, desc=f'Epoch', unit='img') as pbar:
for pack in train_loader:
imgs = pack['image'].to(dtype=torch.float32, device=device)
masks = pack['label'].to(dtype=torch.float32, device=device)
if 'pt' not in pack:
imgs, pt, masks = generate_click_prompt(imgs, masks)
else:
pt = pack['pt']
point_labels = pack['p_label']
name = pack['image_meta_dict']['filename_or_obj']
mask_type = torch.float32
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
            '''preprocess imgs and true_mask_ave'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
            imgs = imgs.to(dtype=mask_type, device=device)
            with torch.no_grad():
                imge = net.image_encoder(imgs)
se, de = net.prompt_encoder(
points=pt,
boxes=None,
masks=None,
)
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
loss = lossfunc(pred, masks)
pbar.set_postfix(**{'loss (batch)': loss.item()})
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
mask = masks[i]
namecat = name[i].split('/')[-1].split('.')[0]
iter_mask(pred_single, mask,
os.path.join(args.path_helper['sample_path'], namecat + '.png'))
pbar.update()
return
def reduce_value(value, world_size, average=True):
    if world_size < 2:  # single-GPU case: nothing to reduce
        return value
    with torch.no_grad():
        dist.all_reduce(value)  # sum `value` across all devices
        if average:  # divide by world size to get the mean across GPUs
            value /= world_size
        return value
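
# Usage sketch: average a per-step loss across ranks. Assumes the process group
# was initialized elsewhere (e.g. dist.init_process_group('nccl')).
#   loss = reduce_value(loss, world_size=dist.get_world_size(), average=True)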