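# NOTE(review): web-page residue from the code host, not source code: "Code pull complete; the page will refresh automatically."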
import os
import sys
import argparse
from datetime import datetime
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
import torchvision
import torchvision.transforms as transforms
from skimage import io
from torch.utils.data import DataLoader
# from dataset import *
from torch.autograd import Variable
from PIL import Image
from tensorboardX import SummaryWriter
# from models.discriminatorlayer import discriminator
from conf import settings
import time
import cfg
from conf import settings
from tqdm import tqdm
from utils import *
import torch.nn.functional as F
import torch
from einops import rearrange
import pytorch_ssim
import models.sam.utils.transforms as samtrans
# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1
import shutil
import tempfile
import matplotlib.pyplot as plt
from tqdm import tqdm
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
AsDiscrete,
)
import torch
# ---------------------------------------------------------------------------
# Module-level losses, metrics and bookkeeping shared by the functions below.
# NOTE(review): `device` and `args` are not defined anywhere in the visible part
# of this file; presumably they come from `from utils import *` or are defined
# elsewhere. Confirm this, otherwise importing the module raises NameError.
# NOTE(review): `DiceMetric` (monai.metrics) is likewise not imported here;
# verify that it is re-exported by `utils`.
# ---------------------------------------------------------------------------
# Positive-class weight of 2 for the binary BCE-with-logits criterion below.
pos_weight = torch.ones([1]).to(device) * 2
# Binary cross-entropy on logits; the non-3-D loss used by train_sam, seg_sam and validation_sam.
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Random integers in [1, 10] with shape (args.b, 7); not referenced in the visible code.
seed = torch.randint(1, 11, (args.b, 7))
# Let cuDNN benchmark convolution algorithms and pick the fastest for fixed input shapes.
torch.backends.cudnn.benchmark = True
# Multi-class Dice + cross-entropy (one-hot targets, softmax outputs); not referenced in the visible code.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
# Mixed-precision gradient scaler; not referenced in the visible code.
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
# Post-processing transforms that one-hot encode into 14 classes (predictions are argmax-ed first).
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
# Best-validation bookkeeping; not updated by any function visible in this file.
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
def train_sam(args, net: nn.Module, optimizer, train_loader,
              epoch, writer, schedulers=None, vis=2, device=None):
    """Train the SAM model for one epoch.

    The image encoder and mask decoder run with autograd enabled. The prompt
    encoder runs under ``torch.no_grad()``, so this loss sends no gradient
    back through it.

    Args:
        args: run configuration. Reads ``thd``, ``image_size``, ``out_size``,
            ``dataset`` and ``path_helper['sample_path']``.
        net: SAM model, either bare or wrapped so that the model is reachable
            as ``net.module`` (as with DistributedDataParallel).
        optimizer: stepped and zeroed once per batch.
        train_loader: sized iterable of dict batches providing ``'image'``,
            ``'label'``, ``'pt'``, ``'p_label'`` and
            ``['image_meta_dict']['filename_or_obj']``.
        epoch: epoch index, shown in the progress bar and used in sample file names.
        writer: not used by this function; kept for interface compatibility.
        schedulers: not used by this function; kept for interface compatibility.
        vis: on rank 0, sample images are written every ``vis * 10`` batches;
            ``0`` disables writing.
        device: device the batch tensors are moved to.

    Returns:
        ``(mean_loss, (mean_iou, mean_dice))``. Each value is passed through
        ``reduce_value(..., world_size, average=True)`` and divided by the
        number of batches in ``train_loader``.

    Raises:
        ValueError: if ``train_loader`` is empty.

    NOTE(review): relies on module-level ``rank`` and ``world_size``, which are
    not defined in the visible part of this file.
    """
    n_tra = len(train_loader)
    if n_tra == 0:
        # Nothing to average over; fail loudly instead of dividing by zero.
        raise ValueError('train_loader is empty')

    # DDP / DataParallel expose the wrapped model as `.module`; accept a bare model too.
    model = net.module if hasattr(net, 'module') else net
    vis_every = vis * 10
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    mix_res = (0, 0, 0, 0)  # running sum of eval_seg metrics over all batches
    epoch_loss = 0
    ind = 0

    net.train()
    optimizer.zero_grad()

    if args.thd:
        # Dice loss combined with cross-entropy, applied to sigmoid outputs.
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        # Binary cross-entropy with logits (module-level criterion_G).
        lossfunc = criterion_G

    with tqdm(total=n_tra, desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            imgs = pack['image'].to(dtype=torch.float32, device=device)
            masks = pack['label'].to(dtype=torch.float32, device=device)
            pt = pack['pt']
            point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            if args.thd:
                # Fold the depth axis into the batch so each slice becomes a 2-D
                # sample, replicate the single channel to 3 channels and resize.
                # NOTE(review): unlike validation_sam, the prompt points are not
                # rearranged to the (b d) batch here. Confirm that the dataset
                # already yields per-slice points in 3-D mode.
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
                imgs = imgs.repeat(1, 3, 1, 1)
                imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

            showp = pt  # raw prompt points, kept for visualisation
            ind += 1
            coords_torch = torch.as_tensor(pt, dtype=torch.float, device=device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
            pt = (coords_torch, labels_torch)
            imgs = imgs.to(dtype=torch.float32, device=device)

            imge = model.image_encoder(imgs)
            with torch.no_grad():
                se, de = model.prompt_encoder(
                    points=pt,
                    boxes=None,
                    masks=None,
                )
            pred, _ = model.mask_decoder(
                image_embeddings=imge,
                image_pe=model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=se,
                dense_prompt_embeddings=de,
                multimask_output=False,
            )

            loss = lossfunc(pred, masks)
            pbar.set_postfix(**{'loss (batch)': loss.item()})
            loss.backward()
            # Average the loss over all GPUs (for logging only; gradients are already computed).
            loss = reduce_value(loss, world_size, average=True)
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()

            # Only CUDA devices have a stream to synchronize with.
            if device is not None and torch.device(device).type == 'cuda':
                torch.cuda.synchronize(device)

            if rank == 0 and vis_every and ind % vis_every == 0:
                namecat = 'Train' + ''.join(
                    na.split('/')[-1].split('.')[0] + '+' for na in name
                )
                vis_image_self(imgs, pred, masks,
                               os.path.join(args.path_helper['sample_path'],
                                            namecat + args.dataset + 'epoch+' + str(epoch) + '.jpg'),
                               reverse=False, points=showp)

            temp = eval_seg(pred, masks, threshold)
            mix_res = tuple(sum(a) for a in zip(mix_res, temp))
            pbar.update()

    # One cross-rank reduction of the accumulated metrics, after all batches.
    metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
    metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
    iou, dice = metrics_tensor.tolist()
    return epoch_loss / n_tra, (iou / n_tra, dice / n_tra)
def seg_sam(args, net: nn.Module, train_loader, device=None):
    """Run SAM inference over a loader and write one mask file per sample.

    For every sample, ``iter_mask`` is called with the single-sample
    prediction, its ground-truth mask and ``<sample_path>/<file stem>.png``.

    Args:
        args: run configuration; reads ``args.path_helper['sample_path']``.
        net: SAM model, either bare or wrapped so that it is reachable as
            ``net.module``.
        train_loader: sized iterable of dict batches providing ``'image'``,
            ``'label'``, ``['image_meta_dict']['filename_or_obj']`` and,
            optionally, ``'pt'``/``'p_label'``. When ``'pt'`` is absent, the
            prompt is produced by ``generate_click_prompt``.
        device: device the batch tensors are moved to.

    Returns:
        None.
    """
    # DDP / DataParallel expose the wrapped model as `.module`; accept a bare model too.
    model = net.module if hasattr(net, 'module') else net
    net.eval()
    lossfunc = criterion_G  # BCE with logits; only used for the progress-bar readout

    with tqdm(total=len(train_loader), desc='Epoch', unit='img') as pbar:
        for pack in train_loader:
            imgs = pack['image'].to(dtype=torch.float32, device=device)
            masks = pack['label'].to(dtype=torch.float32, device=device)
            if 'pt' in pack:
                pt = pack['pt']
                point_labels = pack['p_label']
            else:
                imgs, pt, masks = generate_click_prompt(imgs, masks)
                # Generated clicks come without labels: mark one positive click per
                # image, as validation_sam does on its 3-D path.
                # NOTE(review): assumes generate_click_prompt yields one point per
                # image; confirm against its implementation.
                point_labels = torch.ones(imgs.size(0))
            name = pack['image_meta_dict']['filename_or_obj']

            coords_torch = torch.as_tensor(pt, dtype=torch.float, device=device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
            prompt = (coords_torch, labels_torch)
            imgs = imgs.to(dtype=torch.float32, device=device)

            # Pure inference: no stage needs an autograd graph.
            with torch.no_grad():
                imge = model.image_encoder(imgs)
                se, de = model.prompt_encoder(
                    points=prompt,
                    boxes=None,
                    masks=None,
                )
                pred, _ = model.mask_decoder(
                    image_embeddings=imge,
                    image_pe=model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )
                loss = lossfunc(pred, masks)
            pbar.set_postfix(**{'loss (batch)': loss.item()})

            for i in range(pred.size(0)):
                namecat = name[i].split('/')[-1].split('.')[0]
                iter_mask(pred[i], masks[i],
                          os.path.join(args.path_helper['sample_path'], namecat + '.png'))
            pbar.update()
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True, get_mask=False):
    """Evaluate the SAM model on ``val_loader``.

    Each batch may be evaluated in chunks of ``args.evl_chunk`` slices along
    its last axis. With ``args.thd``, the depth axis of each chunk is folded
    into the batch before it is fed to the model.

    Args:
        args: run configuration. Reads ``thd``, ``evl_chunk``, ``distributed``,
            ``image_size``, ``out_size``, ``vis``, ``dataset`` and
            ``path_helper`` (``'sample_path'``, plus ``'mask_path'`` when
            ``get_mask`` is true).
        val_loader: sized iterable of dict batches providing ``'image'``,
            ``'label'``, ``['image_meta_dict']['filename_or_obj']`` and,
            optionally, ``'pt'``/``'p_label'``.
        epoch: epoch index used in sample file names.
        net: SAM model; accessed through ``net.module`` when
            ``args.distributed != 'none'``.
        clean_dir: not used by this function; kept for interface compatibility.
        get_mask: if true, save every predicted mask via ``save_mask``.

    Returns:
        ``(mean_loss, (mean_iou, mean_dice))``, each divided by the number of
        batches (multiplied by the chunks per batch when ``args.evl_chunk`` is set).

    Raises:
        ValueError: if ``val_loader`` is empty or the chunk size is not positive.

    NOTE(review): relies on module-level ``device`` and ``world_size``, which are
    not defined in the visible part of this file.
    """
    n_val = len(val_loader)  # number of batches
    if n_val == 0:
        raise ValueError('val_loader is empty')

    distributed = args.distributed != 'none'
    model = net.module if distributed else net
    net.eval()
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    mix_res = (0, 0, 0, 0)  # running sum of eval_seg metrics
    tot = 0  # running sum of per-chunk losses
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar_val:
        for ind, pack in enumerate(val_loader):
            imgsw = pack['image'].to(dtype=torch.float32, device=device)
            masksw = pack['label'].to(dtype=torch.float32, device=device)
            if 'pt' in pack:
                ptw = pack['pt']
                batch_labels = pack['p_label']
            else:
                imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
                batch_labels = None  # no labels supplied; filled in per chunk below
            name = pack['image_meta_dict']['filename_or_obj']

            if args.evl_chunk:
                evl_ch = int(args.evl_chunk)
            else:
                evl_ch = int(imgsw.size(-1))
            if evl_ch < 1:
                # A non-positive chunk never advances `buoy` and would loop forever.
                raise ValueError(f'evaluation chunk size must be positive, got {evl_ch}')

            vis_step = ind  # counts chunks, starting from the batch index
            buoy = 0
            while buoy + evl_ch <= imgsw.size(-1):
                pt = ptw[:, :, buoy:buoy + evl_ch] if args.thd else ptw
                imgs = imgsw[..., buoy:buoy + evl_ch]
                masks = masksw[..., buoy:buoy + evl_ch]
                buoy += evl_ch
                point_labels = batch_labels

                if args.thd:
                    # Fold depth into the batch; one positive click per resulting slice.
                    pt = rearrange(pt, 'b n d -> (b d) n')
                    imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                    masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
                    imgs = imgs.repeat(1, 3, 1, 1)
                    point_labels = torch.ones(imgs.size(0))
                    imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                    masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)
                elif point_labels is None:
                    # Generated clicks without labels: treat each as one positive click per image.
                    # NOTE(review): assumes one generated point per image; confirm.
                    point_labels = torch.ones(imgs.size(0))

                showp = pt  # raw prompt points, kept for visualisation
                vis_step += 1
                coords_torch = torch.as_tensor(pt, dtype=torch.float, device=device)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
                prompt = (coords_torch, labels_torch)
                imgs = imgs.to(dtype=torch.float32, device=device)

                with torch.no_grad():
                    imge = model.image_encoder(imgs)
                    se, de = model.prompt_encoder(
                        points=prompt,
                        boxes=None,
                        masks=None,
                    )
                    pred, _ = model.mask_decoder(
                        image_embeddings=imge,
                        image_pe=model.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=se,
                        dense_prompt_embeddings=de,
                        multimask_output=False,
                    )
                    chunk_loss = lossfunc(pred, masks)
                    if distributed:
                        chunk_loss = reduce_value(chunk_loss, world_size, average=True)
                    tot += chunk_loss

                if args.vis and vis_step % args.vis == 0:
                    namecat = 'Test' + ''.join(
                        na.split('/')[-1].split('.')[0] + '+' for na in name
                    )
                    vis_image_self(imgs, pred, masks,
                                   os.path.join(args.path_helper['sample_path'],
                                                namecat + args.dataset + 'epoch+' + str(epoch) + '.jpg'),
                                   reverse=False, points=showp)

                if get_mask:
                    for i in range(pred.size(0)):
                        namecat = name[i].split('/')[-1].split('.')[0]
                        save_mask(pred[i], os.path.join(args.path_helper['mask_path'], namecat + '.png'))

                temp = eval_seg(pred, masks, threshold)
                mix_res = tuple(sum(a) for a in zip(mix_res, temp))

            # Only CUDA devices have a stream to synchronize with.
            if device is not None and torch.device(device).type == 'cuda':
                torch.cuda.synchronize(device)
            pbar_val.update()

    # One cross-rank reduction of the accumulated metrics, after all batches.
    metrics_tensor = torch.tensor(mix_res, dtype=torch.float32, device=device)
    if distributed:
        metrics_tensor = reduce_value(metrics_tensor, world_size, average=True)
    iou, dice = metrics_tensor.tolist()

    if args.evl_chunk:
        # Average per chunk rather than per batch (chunk count taken from the last batch).
        n_val = n_val * (imgsw.size(-1) // evl_ch)
    return tot / n_val, (iou / n_val, dice / n_val)
def reduce_value(value, world_size, average=True):
    """All-reduce ``value`` across distributed ranks.

    Args:
        value: tensor to reduce. When ``world_size >= 2``, it is reduced in
            place with ``torch.distributed.all_reduce`` (sum), and the same
            tensor object is returned.
        world_size: number of participating processes. Below 2 means a single
            process: ``value`` is returned unchanged and no collective is issued.
        average: if true, divide the summed tensor by ``world_size`` to get the
            cross-rank mean; otherwise return the sum.

    Returns:
        The (possibly reduced) tensor.
    """
    if world_size < 2:  # single GPU / single process: nothing to reduce
        return value
    # Local import so the function does not depend on a `dist` name being
    # star-imported into this module's namespace.
    import torch.distributed as dist

    with torch.no_grad():
        dist.all_reduce(value)  # in-place sum of `value` over all ranks
        if average:
            # Mean across GPUs, e.g. of the per-GPU loss.
            value /= world_size
    return value
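# NOTE(review): web-page residue from the code host, not source code. It reads:
# "Content here may be inappropriate to display, so the page does not show it; you can review and edit it with the editing tools.
# If you confirm it contains no inappropriate language / pure ad traffic / violence / vulgar or pornographic material / infringement /
# piracy / false or worthless content or anything violating national laws and regulations, click submit to appeal and we will handle it promptly."
# The remainder of the original file was hidden by the host, so this file may be truncated here.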