desperadoxhy/SAM-Adapt — function.py (48.58 KB)
Last commit: desperadoxhy, 2023-11-29 16:01, "add mean"
Note: this repository declares no open-source license (LICENSE); check the project description and its upstream dependencies before use.
import os
import sys
import argparse
import time
import shutil
import tempfile
from datetime import datetime
from collections import OrderedDict
from collections.abc import Callable, Sequence
from typing import Optional

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torch.autograd import Variable
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
from einops import rearrange
from PIL import Image
from skimage import io
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from tensorboardX import SummaryWriter
from tqdm import tqdm
from monai.utils import pytorch_after, look_up_option
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.transforms import AsDiscrete

import pytorch_ssim
import models.sam.utils.transforms as samtrans
import cfg
from conf import settings
from utils import *
# from dataset import *
# from models.discriminatorlayer import discriminator
# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1
class MultBCEWithLogitsLoss(_Loss):
    """BCE-with-logits loss over artery/vein channels, plus consistency terms
    for the union (whole vessel) and the artery-vein intersection."""

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None,
                 reduction: str = 'mean', pos_weight: Optional[Tensor] = None) -> None:
        super(MultBCEWithLogitsLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)
        self.weight: Optional[Tensor]
        self.pos_weight: Optional[Tensor]

    def forward(self, input2: Tensor, input3: Tensor, target2: Tensor, target3: Tensor) -> Tensor:
        # input2/target2: arteriole logits/mask; input3/target3: venule logits/mask.
        loss_art = F.binary_cross_entropy_with_logits(input2, target2, self.weight,
                                                      pos_weight=self.pos_weight,
                                                      reduction=self.reduction)
        loss_ven = F.binary_cross_entropy_with_logits(input3, target3, self.weight,
                                                      pos_weight=self.pos_weight,
                                                      reduction=self.reduction)
        # Whole-vessel term: the union (elementwise max) of the two predicted
        # channels should match the union of the ground-truth masks.
        vessel_gt = torch.max(target2, target3)
        vessel_pred = torch.max(input2, input3)
        loss_vess = F.binary_cross_entropy_with_logits(vessel_pred, vessel_gt, self.weight,
                                                       pos_weight=self.pos_weight,
                                                       reduction=self.reduction)
        # Intersection terms: each prediction, clipped by the other channel's
        # ground truth (elementwise min), should match the GT intersection.
        inter_gt = torch.min(target2, target3)
        inter_art = torch.min(input2, target3)
        inter_ven = torch.min(target2, input3)
        loss_inter_art = F.binary_cross_entropy_with_logits(inter_art, inter_gt, self.weight,
                                                            pos_weight=self.pos_weight,
                                                            reduction=self.reduction)
        loss_inter_ven = F.binary_cross_entropy_with_logits(inter_ven, inter_gt, self.weight,
                                                            pos_weight=self.pos_weight,
                                                            reduction=self.reduction)
        return loss_art + loss_ven + 0.5 * loss_inter_art + 0.5 * loss_inter_ven + loss_vess
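# A minimal usage sketch of the composite loss above. The helper name and the
# shapes are illustrative assumptions, not repository code; it is defined but
# never called here.
def _example_mult_loss():
    mult_loss = MultBCEWithLogitsLoss(pos_weight=torch.ones([1]) * 2)
    logits = torch.randn(2, 2, 64, 64)                 # (B, [artery, vein], H, W) raw logits
    gt_art = (torch.rand(2, 1, 64, 64) > 0.5).float()  # arteriole ground-truth mask
    gt_ven = (torch.rand(2, 1, 64, 64) > 0.5).float()  # venule ground-truth mask
    # argument order matches forward(): artery logits, vein logits, artery GT, vein GT
    return mult_loss(logits[:, 0:1], logits[:, 1:2], gt_art, gt_ven)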
# Module-level training state. NOTE: `args`, `device`, `world_size`, and `rank`
# are not defined in this file; they must be provided by the training entry
# point before this module's code runs.
pos_weight = torch.ones([1]).to(device) * 2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
mult_criterion = MultBCEWithLogitsLoss(pos_weight=pos_weight)
seed = torch.randint(1, 11, (args.b, 7))
torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
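# A minimal sketch of how a DDP entry point could supply the globals above
# (an assumption, not repository code; `cfg.parse_args` is hypothetical here):
#
#     args = cfg.parse_args()
#     dist.init_process_group(backend='nccl')
#     rank, world_size = dist.get_rank(), dist.get_world_size()
#     device = torch.device(f'cuda:{rank}')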
def train_sam(args, net: nn.Module, optimizer, train_loader,
epoch, writer, schedulers=None, vis = 2, device=None):
hard = 0
epoch_loss = 0
ind = 0
vis = vis * 10
# train mode
net.train()
optimizer.zero_grad()
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
n_tra = len(train_loader)
if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')  # combines Dice loss (a segmentation criterion) with cross-entropy
else:
        lossfunc = criterion_G  # binary cross-entropy with logits (BCEWithLogitsLoss)
with tqdm(total=n_tra, desc=f'Epoch {epoch}', unit='img') as pbar:
for pack in train_loader:
imgs = pack['image'].to(dtype = torch.float32, device = device)
masks = pack['label'].to(dtype = torch.float32, device = device)
if 'pt' not in pack:
imgs, pt, masks = generate_click_prompt(imgs, masks)
else:
pt = pack['pt']
point_labels = pack['p_label']
name = pack['image_meta_dict']['filename_or_obj']
if args.thd:
print('Creating-----------------')
# pt = rearrange(pt, 'b n d -> (b d) n')
imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
# print(imgs.shape)
# print(pt.shape)
imgs = imgs.repeat(1, 3, 1, 1)
# point_labels = torch.ones(imgs.size(0))
imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)
showp = pt
mask_type = torch.float32
ind += 1
b_size, c, w, h = imgs.size()
# c, w, h = imgs.size()
longsize = w if w >= h else h
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
            '''Preprocess the input imgs and true_mask_ave'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgs = imgs.to(dtype=mask_type, device=device)
'''Train'''
# for n, value in net.image_encoder.named_parameters():
# if "Adapter" not in n:
# value.requires_grad = False
imge = net.module.image_encoder(imgs)
with torch.no_grad():
se, de = net.module.prompt_encoder(
points=pt,
boxes=None,
masks=None,
)
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=net.module.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
loss = lossfunc(pred, masks)
pbar.set_postfix(**{'loss (batch)': loss.item()})
loss.backward()
            # average the loss across all GPUs
loss = reduce_value(loss, world_size, average=True)
epoch_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
if rank == 0:
vis = vis * 5
if ind % vis == 0:
namecat = 'Train'
for na in name:
namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
# vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + args.dataset + 'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
            # compute IoU and Dice, summing across GPUs and averaging
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
iou, dice = metrics_tensor.tolist()
pbar.update()
# print('mix_res', mix_res)
# print('iou and dice', tuple([a/n_tra for a in mix_res]))
return epoch_loss/n_tra, tuple([iou/n_tra, dice/n_tra])
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False):
# eval mode
net.eval()
mask_type = torch.float32
n_val = len(val_loader) # the number of batch
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
rater_res = [(0, 0, 0, 0) for _ in range(6)]
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
if args.thd:
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
else:
lossfunc = criterion_G
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype=torch.float32, device=device)
masksw = pack['label'].to(dtype=torch.float32, device=device)
if 'pt' not in pack:
imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
else:
ptw = pack['pt']
point_labels = pack['p_label']
name = pack['image_meta_dict']['filename_or_obj']
buoy = 0
if args.evl_chunk:
evl_ch = int(args.evl_chunk)
else:
evl_ch = int(imgsw.size(-1))
while (buoy + evl_ch) <= imgsw.size(-1):
if args.thd:
pt = ptw[:, :, buoy: buoy + evl_ch]
else:
pt = ptw
imgs = imgsw[..., buoy:buoy + evl_ch]
masks = masksw[..., buoy:buoy + evl_ch]
buoy += evl_ch
if args.thd:
pt = rearrange(pt, 'b n d -> (b d) n')
imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
imgs = imgs.repeat(1, 3, 1, 1)
point_labels = torch.ones(imgs.size(0))
imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)
showp = pt
mask_type = torch.float32
ind += 1
b_size, c, w, h = imgs.size()
# c, w, h = imgs.size()
longsize = w if w >= h else h
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
'''init'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
# true_mask_ave = cons_tensor(true_mask_ave)
# imgs = imgs.to(dtype = mask_type,device = GPUdevice)
imgs = imgs.to(dtype=mask_type, device=device)
'''test'''
with torch.no_grad():
if args.distributed != 'none':
imge = net.module.image_encoder(imgs)
else:
imge = net.image_encoder(imgs)
if args.distributed != 'none':
se, de = net.module.prompt_encoder(
points=pt,
boxes=None,
masks=None,
)
else:
se, de = net.prompt_encoder(
points=pt,
boxes=None,
masks=None,
)
if args.distributed != 'none':
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=net.module.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
else:
# print('se', se)
# print('de', de)
# print('image_pe', net.prompt_encoder.get_dense_pe())
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
if args.distributed != 'none':
tot += reduce_value(lossfunc(pred, masks), world_size, average=True)
else:
tot += lossfunc(pred, masks)
'''vis images'''
# if rank == 0 and ind % args.vis == 0:
if ind % args.vis == 0:
namecat = 'Test'
for na in name:
img_name = na.split('/')[-1].split('.')[0]
namecat = namecat + img_name + '+'
# vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + args.dataset + 'epoch+' + str(epoch) + '.jpg'), reverse=False)
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
iou, dice = metrics_tensor.tolist()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
if args.evl_chunk:
n_val = n_val * (imgsw.size(-1) // evl_ch)
return tot / n_val, tuple([iou / n_val, dice / n_val]),
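# Sketch of the sliding-chunk pattern used in the validation loop above: split
# a (B, C, H, W, D) volume into windows of `evl_ch` slices along the depth
# axis. The generator is an illustrative helper, not repository code.
def _example_depth_chunks(vol: torch.Tensor, evl_ch: int):
    buoy = 0
    while (buoy + evl_ch) <= vol.size(-1):
        yield vol[..., buoy:buoy + evl_ch]  # e.g. (1, 1, 256, 256, evl_ch)
        buoy += evl_ch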
def train_sam_lite(args, net: nn.Module, optimizer, train_loader,
epoch, writer, schedulers=None, vis=2, device=None, fake_prompt=None):
hard = 0
ind = 0
vis = vis * 10
# train mode
net.train()
optimizer.zero_grad()
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
epoch_loss = 0
n_tra = len(train_loader)
if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True,
                              reduction='mean')  # combines Dice loss (a segmentation criterion) with cross-entropy
    else:
        lossfunc = criterion_G  # binary cross-entropy with logits (BCEWithLogitsLoss)
with tqdm(total=n_tra, desc=f'Epoch {epoch}', unit='img') as pbar:
for pack in train_loader:
imgs = pack['image'].to(dtype=torch.float32, device=device)
masks = pack['label'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
mask_type = torch.float32
ind += 1
            '''Preprocess the input imgs and true_mask_ave'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgs = imgs.to(dtype=mask_type, device=device)
'''Train'''
# for n, value in net.image_encoder.named_parameters():
# if "Adapter" not in n:
# value.requires_grad = False
imge = net.module.image_encoder(imgs)
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=False,
)
loss = lossfunc(pred, masks)
pbar.set_postfix(**{'loss (batch)': loss.item()})
loss.backward()
            # average the loss across all GPUs
loss = reduce_value(loss, world_size, average=True)
epoch_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
if rank == 0:
vis = vis * 5
if ind % vis == 0:
namecat = 'Train'
for na in name:
namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
# vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'],
namecat + args.dataset + 'epoch+' + str(
epoch) + '.jpg'), reverse=False)
            # compute IoU and Dice, summing across GPUs and averaging
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
iou, dice = metrics_tensor.tolist()
pbar.update()
# print('mix_res', mix_res)
# print('iou and dice', tuple([a/n_tra for a in mix_res]))
return epoch_loss / n_tra, tuple([iou / n_tra, dice / n_tra])
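# `fake_prompt` is a dict of precomputed prompt-encoder outputs stored as numpy
# arrays under the keys 'image_pe', 'se', and 'de'. A sketch (an assumption
# about how it could be built, not repository code) of precomputing it once by
# running the prompt encoder with no prompts:
def _example_build_fake_prompt(net):
    with torch.no_grad():
        se, de = net.prompt_encoder(points=None, boxes=None, masks=None)
        return {
            'image_pe': net.prompt_encoder.get_dense_pe().cpu().numpy(),
            'se': se.cpu().numpy(),
            'de': de.cpu().numpy(),
        }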
def validation_sam_lite(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False, fake_prompt=None):
# eval mode
net.eval()
n_val = len(val_loader) # the number of batch
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
if args.thd:
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
else:
lossfunc = criterion_G
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype=torch.float32, device=device)
masksw = pack['label'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
buoy = 0
if args.evl_chunk:
evl_ch = int(args.evl_chunk)
else:
evl_ch = int(imgsw.size(-1))
while (buoy + evl_ch) <= imgsw.size(-1):
imgs = imgsw[..., buoy:buoy + evl_ch]
masks = masksw[..., buoy:buoy + evl_ch]
buoy += evl_ch
mask_type = torch.float32
ind += 1
'''init'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgs = imgs.to(dtype=mask_type, device=device)
'''test'''
with torch.no_grad():
if args.distributed != 'none':
imge = net.module.image_encoder(imgs)
else:
imge = net.image_encoder(imgs)
if args.distributed != 'none':
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=False,
)
else:
# print('se', se)
# print('de', de)
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=False,
)
if args.distributed != 'none':
tot += reduce_value(lossfunc(pred, masks), world_size, average=True)
else:
tot += lossfunc(pred, masks)
'''vis images'''
# if rank == 0 and ind % args.vis == 0:
if ind % args.vis == 0:
namecat = 'Test'
for na in name:
img_name = na.split('/')[-1].split('.')[0]
namecat = namecat + img_name + '+'
# vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + args.dataset + 'epoch+' + str(epoch) + '.jpg'), reverse=False)
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
iou, dice = metrics_tensor.tolist()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
if args.evl_chunk:
n_val = n_val * (imgsw.size(-1) // evl_ch)
return tot / n_val, tuple([iou / n_val, dice / n_val]),
def seg_sam_lite(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False, fake_prompt=None):
# eval mode
net.eval()
n_val = len(val_loader) # the number of batch
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
buoy = 0
if args.evl_chunk:
evl_ch = int(args.evl_chunk)
else:
evl_ch = int(imgsw.size(-1))
while (buoy + evl_ch) <= imgsw.size(-1):
imgs = imgsw[..., buoy:buoy + evl_ch]
buoy += evl_ch
mask_type = torch.float32
ind += 1
imgs = imgs.to(dtype=mask_type, device=device)
'''test'''
with torch.no_grad():
if args.distributed != 'none':
imge = net.module.image_encoder(imgs)
else:
imge = net.image_encoder(imgs)
if args.distributed != 'none':
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=False,
)
else:
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=False,
)
'''vis images'''
# if rank == 0 and ind % args.vis == 0:
# if ind % args.vis == 0:
# namecat = 'Test'
# for na in name:
# img_name = na.split('/')[-1].split('.')[0]
# namecat = namecat + img_name + '+'
# # vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
# vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + args.dataset + 'epoch+' + str(epoch) + '.jpg'), reverse=False)
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
    return 'ok',
def seg_mult_sam_lite(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False, fake_prompt=None):
# eval mode
net.eval()
n_val = len(val_loader) # the number of batch
ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgs = pack['image'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
mask_type = torch.float32
ind += 1
imgs = imgs.to(dtype=mask_type, device=device)
'''test'''
with torch.no_grad():
if args.distributed != 'none':
imge = net.module.image_encoder(imgs)
else:
imge = net.image_encoder(imgs)
if args.distributed != 'none':
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=True,
)
else:
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=True,
)
'''vis images'''
# if rank == 0 and ind % args.vis == 0:
# if ind % args.vis == 0:
# namecat = 'Test'
# for na in name:
# img_name = na.split('/')[-1].split('.')[0]
# namecat = namecat + img_name + '+'
# # vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
# vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + args.dataset + 'epoch+' + str(epoch) + '.jpg'), reverse=False)
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
    return 'ok',
def train_mult_sam_lite(args, net: nn.Module, optimizer, train_loader,
epoch, writer, schedulers=None, vis=2, device=None, fake_prompt=None):
hard = 0
ind = 0
# vis = vis * 10
# train mode
net.train()
optimizer.zero_grad()
# ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
# threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
epoch_loss = 0
n_tra = len(train_loader)
    lossfunc = mult_criterion  # composite BCE-with-logits loss over artery/vein channels (see MultBCEWithLogitsLoss)
with tqdm(total=n_tra, desc=f'Epoch {epoch}', unit='img') as pbar:
for pack in train_loader:
imgs = pack['image'].to(dtype=torch.float32, device=device)
mask_vessel = pack['mask_vessel'].to(dtype=torch.float32, device=device)
mask_arteriole = pack['mask_arteriole'].to(dtype=torch.float32, device=device)
mask_venule = pack['mask_venule'].to(dtype=torch.float32, device=device)
if args.net == 'sam_self_with_prompt':
pt = pack['pt']
point_labels = pack['p_label']
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
# pt = (coords_torch, labels_torch)
# print('pt', pt)
name = pack['image_meta_dict']['filename_or_obj']
mask_type = torch.float32
ind += 1
            '''Preprocess the input imgs and true_mask_ave'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgs = imgs.to(dtype=mask_type, device=device)
'''Train'''
# for n, value in net.image_encoder.named_parameters():
# if "Adapter" not in n:
# value.requires_grad = False
if args.distributed != 'none':
if args.net == 'sam_self':
imgs = net.module.image_fusion(imgs)
imge = net.module.image_encoder(imgs)
else:
if args.net == 'sam_self':
imgs = net.image_fusion(imgs)
imge = net.image_encoder(imgs)
if args.distributed != 'none':
if args.net == 'sam_self_with_prompt':
with torch.no_grad():
se, de = net.module.prompt_encoder(
# points=pt,
points=None,
boxes=None,
masks=None,
)
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=net.module.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=True,
)
else:
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=True,
)
else:
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=True,
)
# loss = lossfunc(pred[:, 0:1, :, :], pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_vessel, mask_arteriole, mask_venule)
loss = lossfunc(pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_arteriole,
mask_venule)
pbar.set_postfix(**{'loss (batch)': loss.item()})
loss.backward()
            # average the loss across all GPUs
if args.distributed != 'none':
loss = reduce_value(loss, world_size, average=True)
epoch_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
# if rank == 0:
# vis = vis * 5
# if ind % vis == 0:
# namecat = 'Train'
# for na in name:
# namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
# # vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
# vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'],
# namecat + args.dataset + 'epoch+' + str(
# epoch) + '.jpg'), reverse=False)
            # compute IoU and Dice, summing across GPUs and averaging
# tempvessel = eval_seg(pred[0:1, :, :], mask_vessel, threshold)
# mix_resvessel = tuple([sum(a) for a in zip(mix_resvessel, tempvessel)])
# metrics_tensorvessel = torch.tensor(mix_resvessel, dtype=torch.float32, device=device)
# metrics_tensorvessel = reduce_value(metrics_tensorvessel, world_size, average=True)
# iou_vessel, dice_vessel = metrics_tensorvessel.tolist()
#
# temparteriole = eval_seg(pred[1:2, :, :], mask_arteriole, threshold)
# mix_resarteriole = tuple([sum(a) for a in zip(mix_resarteriole, temparteriole)])
# metrics_tensorarteriole = torch.tensor(mix_resarteriole, dtype=torch.float32, device=device)
# metrics_tensorarteriole = reduce_value(metrics_tensorarteriole, world_size, average=True)
# iou_arteriole, dice_arteriole = metrics_tensorarteriole.tolist()
#
# tempvenule = eval_seg(pred[2:3, :, :], mask_venule, threshold)
# mix_resvenule = tuple([sum(a) for a in zip(mix_res, tempvenule)])
# metrics_tensorvenule = torch.tensor(mix_resvenule, dtype=torch.float32, device=device)
# metrics_tensorvenule = reduce_value(metrics_tensorvenule, world_size, average=True)
# iou_venule, dice_venule = metrics_tensorvenule.tolist()
pbar.update()
# print('mix_res', mix_res)
# print('iou and dice', tuple([a/n_tra for a in mix_res]))
return epoch_loss / n_tra
def validation_mult_sam_lite(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False, fake_prompt=None):
# eval mode
net.eval()
n_val = len(val_loader) # the number of batch
mix_resvessel = (0, 0, 0, 0)
mix_resvenule = (0, 0, 0, 0)
mix_resarteriole = (0, 0, 0, 0)
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
lossfunc = mult_criterion
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype=torch.float32, device=device)
mask_vessel = pack['mask_vessel'].to(dtype=torch.float32, device=device)
mask_arteriole = pack['mask_arteriole'].to(dtype=torch.float32, device=device)
mask_venule = pack['mask_venule'].to(dtype=torch.float32, device=device)
name = pack['image_meta_dict']['filename_or_obj']
if args.net == 'sam_self_with_prompt':
pt = pack['pt']
point_labels = pack['p_label']
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
buoy = 0
mask_type = torch.float32
ind += 1
'''init'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgsw = imgsw.to(dtype=mask_type, device=device)
'''test'''
with torch.no_grad():
if args.distributed != 'none':
                    if args.net == 'sam_self' or args.net == 'sam_self_with_prompt':
imgsw = net.module.image_fusion(imgsw)
imge = net.module.image_encoder(imgsw)
else:
                    if args.net == 'sam_self' or args.net == 'sam_self_with_prompt':
# print(imgsw.shape)
imgsw = net.image_fusion(imgsw)
# print('testing')
imge = net.image_encoder(imgsw)
if args.distributed != 'none':
if args.net == 'sam_self_with_prompt':
se, de = net.module.prompt_encoder(
# points=pt,
points=None,
boxes=None,
masks=None,
)
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=net.module.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=True,
)
else:
pred, _ = net.module.mask_decoder(
image_embeddings=imge,
image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
multimask_output=True,
)
else:
if args.net == 'sam_self_with_prompt':
se, de = net.prompt_encoder(
# points=pt,
points=None,
boxes=None,
masks=None,
)
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=True,
)
                    else:
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=torch.from_numpy(fake_prompt['image_pe']).to(device),
                            sparse_prompt_embeddings=torch.from_numpy(fake_prompt['se']).to(device),
                            dense_prompt_embeddings=torch.from_numpy(fake_prompt['de']).to(device),
                            multimask_output=True,
                        )
loss = lossfunc(pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_arteriole,
mask_venule)
if args.distributed != 'none':
# tot += reduce_value(lossfunc(pred[:, 0:1, :, :], pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_vessel, mask_arteriole, mask_venule), world_size, average=True)
tot += reduce_value(loss, world_size, average=True)
else:
# tot += lossfunc(pred[:, 0:1, :, :], pred[:, 1:2, :, :], pred[:, 2:3, :, :], mask_vessel, mask_arteriole, mask_venule)
tot += loss
'''vis images'''
# if rank == 0 and ind % args.vis == 0:
# if ind % args.vis == 0:
# namecat = 'Test'
# for na in name:
# img_name = na.split('/')[-1].split('.')[0]
# namecat = namecat + img_name + '+'
# # vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
# vis_image_self(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + args.dataset + 'epoch+' + str(epoch) + '.jpg'), reverse=False)
if get_mask:
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
namecat = name[i].split('/')[-1].split('.')[0]
save_mask(pred_single, os.path.join(args.path_helper['mask_path'], namecat + '.png'))
tempvessel = eval_seg(pred[:, 0:1, :, :], mask_vessel, threshold)
mix_resvessel = tuple([sum(a) for a in zip(mix_resvessel, tempvessel)])
metrics_tensorvessel = torch.tensor(mix_resvessel, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensorvessel = reduce_value(metrics_tensorvessel, world_size, average=True)
iou_vessel, dice_vessel = metrics_tensorvessel.tolist()
temparteriole = eval_seg(pred[:, 1:2, :, :], mask_arteriole, threshold)
mix_resarteriole = tuple([sum(a) for a in zip(mix_resarteriole, temparteriole)])
metrics_tensorarteriole = torch.tensor(mix_resarteriole, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensorarteriole = reduce_value(metrics_tensorarteriole, world_size, average=True)
iou_arteriole, dice_arteriole = metrics_tensorarteriole.tolist()
tempvenule = eval_seg(pred[:, 2:3, :, :], mask_venule, threshold)
mix_resvenule = tuple([sum(a) for a in zip(mix_resvenule, tempvenule)])
metrics_tensorvenule = torch.tensor(mix_resvenule, dtype=torch.float32, device=device)
if args.distributed != 'none':
metrics_tensorvenule = reduce_value(metrics_tensorvenule, world_size, average=True)
iou_venule, dice_venule = metrics_tensorvenule.tolist()
if device != torch.device("cpu"):
torch.cuda.synchronize(device)
pbar_val.update()
return tot / n_val, tuple([iou_vessel / n_val, dice_vessel / n_val]), tuple([iou_arteriole / n_val, dice_arteriole / n_val]), tuple([iou_venule / n_val, dice_venule / n_val])
def seg_sam(args, net: nn.Module, train_loader, device=None):
hard = 0
epoch_loss = 0
net.eval()
n_tra = len(train_loader)
    lossfunc = criterion_G  # binary cross-entropy with logits (BCEWithLogitsLoss)
with tqdm(total=n_tra, desc=f'Epoch', unit='img') as pbar:
for pack in train_loader:
imgs = pack['image'].to(dtype=torch.float32, device=device)
masks = pack['label'].to(dtype=torch.float32, device=device)
if 'pt' not in pack:
imgs, pt, masks = generate_click_prompt(imgs, masks)
else:
pt = pack['pt']
point_labels = pack['p_label']
name = pack['image_meta_dict']['filename_or_obj']
mask_type = torch.float32
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
pt = (coords_torch, labels_torch)
            '''Preprocess the input imgs and true_mask_ave'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
imgs = imgs.to(dtype=mask_type, device=device)
with torch.no_grad():
imge = net.image_encoder(imgs)
# imge= net.image_encoder(imgs)
se, de = net.prompt_encoder(
points=pt,
boxes=None,
masks=None,
)
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
loss = lossfunc(pred, masks)
pbar.set_postfix(**{'loss (batch)': loss.item()})
b, c, h, w = pred.size()
for i in range(b):
pred_single = pred[i]
mask = masks[i]
namecat = name[i].split('/')[-1].split('.')[0]
iter_mask(pred_single, mask,
os.path.join(args.path_helper['sample_path'], namecat + '.png'))
pbar.update()
return
def reduce_value(value, world_size, average=True):
    if world_size < 2:  # single-GPU case: nothing to reduce
        return value
    with torch.no_grad():
        dist.all_reduce(value)  # sum `value` across devices
        if average:  # if averaging, take the mean of the per-GPU values
            value /= world_size
        return value
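# A minimal sketch (assumption, not repository code) of reduce_value under DDP:
# each rank contributes its local loss and all ranks receive the mean.
#
#     # launched via: torchrun --nproc_per_node=2 train.py
#     dist.init_process_group(backend='nccl')
#     world_size = dist.get_world_size()
#     local_loss = torch.tensor([dist.get_rank() + 1.0], device='cuda')
#     mean_loss = reduce_value(local_loss, world_size, average=True)
#     # with two ranks: (1.0 + 2.0) / 2 == 1.5 on every rank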