1 Star 0 Fork 0

desperadoxhy/SAM-Adapt

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
dataset.py 44.60 KB
一键复制 编辑 原始数据 按行查看 历史
desperadoxhy 提交于 2024-05-29 00:59 . Remove .DS_Store files

""" train and test dataset
author jundewu
"""
import os
import sys
import pickle
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
import cfg
from utils import random_click
import random
from monai.transforms import LoadImaged, Randomizable,LoadImage
# Module-level configuration, parsed once at import time via cfg.parse_args()
# (presumably argparse over sys.argv -- confirm in cfg.py). Importing this
# module therefore triggers argument parsing as a side effect.
args = cfg.parse_args()
class DualModal(Dataset):
    """RGB image / vessel-mask pairs, each sample carrying one positive click prompt.

    Images are listed recursively from ``<data_path>/RGB/<mode>`` and masks from
    ``<data_path>/vessel/<mode>``. The i-th image is paired with the i-th mask,
    so both listings are sorted to make that pairing deterministic.

    Click source:
    * ``mode == 'Training'`` and ``iter``: the mask with the same index under
      ``iter_path``; if sampling from it fails, the vessel mask.
    * ``mode == 'Test'`` and ``only_val``: the second line of
      ``label_10_9/<image stem>.txt``.
    * otherwise: the vessel mask.

    ``plane`` is accepted for signature compatibility and is not used.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training', prompt='click', plane=False, iter=False, iter_path=None, only_val=False):
        self.path = os.path.join(data_path, 'RGB', mode)
        self.path_mask = os.path.join(data_path, 'vessel', mode)
        if mode == 'Training' and iter:
            # Masks from the iteration images; when no click can be taken from
            # one of them, __getitem__ falls back to the original vessel mask.
            self.path_mask_2 = iter_path
            self.label_list_2 = self.get_dirs_files(self.path_mask_2)
        self.name_list = self.get_dirs_files(self.path)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.iter = iter
        self.only_val = only_val
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        """Return the sample dict for ``index``.

        Keys: 'image', 'label', 'p_label' (``np.array([1])``), 'pt'
        (``np.array([click])``) and 'image_meta_dict'
        (``{'filename_or_obj': <image file name without .jpg>}``).

        Raises:
            ValueError: if ``self.prompt`` is not ``'click'``, the only prompt
                mode this dataset implements.
        """
        if self.prompt != 'click':
            raise ValueError(f"DualModal only supports prompt='click', got {self.prompt!r}")
        point_label = 1
        img_path = self.name_list[index]
        msk_path = self.label_list[index]
        img = Image.open(img_path).convert('RGB')
        mask = Image.open(msk_path).convert('L')
        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)

        if self.mode == 'Training' and self.iter:
            mask_2 = Image.open(self.label_list_2[index]).convert('L').resize(newsize)
            try:
                pt = random_click(np.array(mask_2) / 255, point_label, inout=1)
            except Exception:
                pt = random_click(np.array(mask) / 255, point_label, inout=1)
        elif self.mode == 'Test' and self.only_val:
            stem = os.path.basename(img_path).split(".png")[0]
            # Kept separate from the image name so image_meta_dict still
            # reports the image, not the label file.
            label_txt = os.path.join('label_10_9', stem + '.txt')
            with open(label_txt, 'r') as f:
                lines = f.readlines()
            # NOTE(review): values on the second line are scaled by a fixed 1024
            # (not self.img_size); presumably normalized coordinates -- confirm
            # against the label_10_9 file format.
            pt = [float(value) * 1024 for value in lines[1].split()]
        else:
            pt = random_click(np.array(mask) / 255, point_label, inout=1)
        pt = np.array([pt])
        point_label = np.array([1])

        if self.transform:
            state = torch.get_rng_state()
            img = self.transform(img)
            # Rewind the torch RNG so transform_msk replays the image's random
            # draws and the mask stays aligned with the image.
            torch.set_rng_state(state)
        if self.transform_msk:
            mask = self.transform_msk(mask)

        name = os.path.basename(img_path).split(".jpg")[0]
        image_meta_dict = {'filename_or_obj': name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }

    def get_dirs_files(self, path):
        """Return every file path under ``path`` (recursively), sorted.

        Sorting matters: ``os.listdir`` order is arbitrary, and ``__getitem__``
        pairs ``name_list[i]`` with ``label_list[i]`` (and ``label_list_2[i]``).
        """
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for filename in os.listdir(current):
                full = os.path.join(current, filename)
                if os.path.isdir(full):
                    pending.append(full)
                else:
                    files.append(full)
        return sorted(files)
class DualModal3D(Dataset):
    """Five-channel samples (RGB + 570nm + 610nm) with one vessel-class mask.

    Masks are listed recursively from ``<data_path>/<type>/<mode>``; for each
    mask, the images with the same file name are read from
    ``<data_path>/RGB/<mode>``, ``<data_path>/570nm/<mode>`` and
    ``<data_path>/610nm/<mode>``. ``plane`` is accepted but unused.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training', type='arteriole', prompt='click', plane=False):
        self.data_path = data_path
        # Mask sub-directory, i.e. which vessel class is segmented.
        self.type = type
        self.path_mask = os.path.join(data_path, self.type, mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return 'image' (RGB, 570nm, 610nm concatenated on dim 0), 'label',
        'p_label', 'pt' and 'image_meta_dict'.

        Raises:
            ValueError: if ``self.prompt`` is not 'click' or 'click_two'.
        """
        msk_path = self.label_list[index]
        name = os.path.basename(msk_path)
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, name)).convert('RGB')
        img_570 = Image.open(os.path.join(self.data_path, '570nm', self.mode, name)).convert('L')
        img_610 = Image.open(os.path.join(self.data_path, '610nm', self.mode, name)).convert('L')
        mask = Image.open(msk_path).convert('L').resize((self.img_size, self.img_size))

        pt, point_label = self._sample_prompt(mask)

        if self.transform:
            # Same random draws for all three images (and, via the restored
            # RNG state, for transform_msk below).
            img, img_570, img_610 = self._apply_synced(self.transform, [img, img_570, img_610])
        if self.transform_msk:
            mask = self.transform_msk(mask)

        img = torch.cat([img, img_570, img_610], dim=0)
        image_meta_dict = {'filename_or_obj': name.split(".jpg")[0]}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }

    def _sample_prompt(self, mask):
        """Return ``(pt, point_label)`` sampled from the resized PIL ``mask``.

        'click': one foreground click (label 1); if random_click raises, one
        background click (label 0) instead. 'click_two': one foreground and
        one background click.

        Raises:
            ValueError: for any other prompt mode.
        """
        if self.prompt not in ('click', 'click_two'):
            raise ValueError(f"unsupported prompt mode: {self.prompt!r}")
        mask_arr = np.array(mask) / 255
        if self.prompt == 'click':
            try:
                return np.array([random_click(mask_arr, 1, inout=1)]), np.array([1])
            except Exception:
                return np.array([random_click(mask_arr, 1, inout=0)]), np.array([0])
        pt_pos = random_click(mask_arr, 1, inout=1)
        pt_neg = random_click(mask_arr, 1, inout=0)
        # NOTE(review): a plain list, unlike the np.array-wrapped 'click'
        # output; kept as-is -- confirm what the consumer expects.
        return [pt_pos, pt_neg], np.array([1, 0])

    @staticmethod
    def _apply_synced(transform, items):
        """Apply ``transform`` to every item starting from the same torch RNG state.

        The state is captured once, restored before each item and restored
        again at the end, so transforms drawing from torch's global RNG make
        identical random choices for every item and for any later call.
        """
        state = torch.get_rng_state()
        out = []
        for item in items:
            torch.set_rng_state(state)
            out.append(transform(item))
        torch.set_rng_state(state)
        return out

    def get_dirs_files(self, path):
        """Return every file path under ``path`` (recursively, os.listdir order)."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for filename in os.listdir(current):
                full = os.path.join(current, filename)
                if os.path.isdir(full):
                    pending.append(full)
                else:
                    files.append(full)
        return files
class DualModal3C(Dataset):
    """Three-channel samples (G channel of RGB + 570nm + 610nm) with one vessel-class mask.

    Masks are listed recursively from ``<data_path>/<type>/<mode>``; for each
    mask, the images with the same file name are read from
    ``<data_path>/RGB/<mode>``, ``<data_path>/570nm/<mode>`` and
    ``<data_path>/610nm/<mode>``. ``plane`` is accepted but unused.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training', type='arteriole', prompt='click', plane=False):
        self.data_path = data_path
        # Mask sub-directory, i.e. which vessel class is segmented.
        self.type = type
        self.path_mask = os.path.join(data_path, self.type, mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return 'image', 'label', 'p_label', 'pt' and 'image_meta_dict'.

        Raises:
            ValueError: if ``self.prompt`` is not 'click', 'click_two' or 'auto'.
        """
        msk_path = self.label_list[index]
        name = os.path.basename(msk_path)
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, name)).convert('RGB')
        img_570 = Image.open(os.path.join(self.data_path, '570nm', self.mode, name)).convert('L')
        img_610 = Image.open(os.path.join(self.data_path, '610nm', self.mode, name)).convert('L')
        mask = Image.open(msk_path).convert('L').resize((self.img_size, self.img_size))

        pt, point_label = self._sample_prompt(mask)

        if self.transform:
            # Same random draws for all three images (and, via the restored
            # RNG state, for transform_msk below).
            img, img_570, img_610 = self._apply_synced(self.transform, [img, img_570, img_610])
        if self.transform_msk:
            mask = self.transform_msk(mask)

        # Keep only channel 1 (G) of the RGB tensor; assumes transform yields
        # (C, H, W) tensors.
        img = torch.cat([img[1, :, :].unsqueeze(0), img_570, img_610], dim=0)
        image_meta_dict = {'filename_or_obj': name.split(".jpg")[0]}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }

    def _sample_prompt(self, mask):
        """Return ``(pt, point_label)`` for the configured prompt mode.

        'auto': a uniformly random point in ``[0, img_size - 1]``, the same grid
        as the resized mask the click modes sample from (previously a fixed
        ``[0, 1023]``). 'click': one foreground click (label 1), or a background
        click (label 0) if random_click raises. 'click_two': one of each.

        Raises:
            ValueError: for any other prompt mode.
        """
        if self.prompt == 'auto':
            x = random.randint(0, self.img_size - 1)
            y = random.randint(0, self.img_size - 1)
            return np.array([[x, y]]), np.array([1])
        if self.prompt not in ('click', 'click_two'):
            raise ValueError(f"unsupported prompt mode: {self.prompt!r}")
        mask_arr = np.array(mask) / 255
        if self.prompt == 'click':
            try:
                return np.array([random_click(mask_arr, 1, inout=1)]), np.array([1])
            except Exception:
                return np.array([random_click(mask_arr, 1, inout=0)]), np.array([0])
        pt_pos = random_click(mask_arr, 1, inout=1)
        pt_neg = random_click(mask_arr, 1, inout=0)
        # NOTE(review): a plain list, unlike the array outputs above; kept as-is.
        return [pt_pos, pt_neg], np.array([1, 0])

    @staticmethod
    def _apply_synced(transform, items):
        """Apply ``transform`` to every item starting from the same torch RNG state.

        The state is captured once, restored before each item and restored
        again at the end, so transforms drawing from torch's global RNG make
        identical random choices for every item and for any later call.
        """
        state = torch.get_rng_state()
        out = []
        for item in items:
            torch.set_rng_state(state)
            out.append(transform(item))
        torch.set_rng_state(state)
        return out

    def get_dirs_files(self, path):
        """Return every file path under ``path`` (recursively, os.listdir order)."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for filename in os.listdir(current):
                full = os.path.join(current, filename)
                if os.path.isdir(full):
                    pending.append(full)
                else:
                    files.append(full)
        return files
class DualModalRGB(Dataset):
    """RGB images with one vessel-class mask and a click/auto prompt.

    Masks are listed recursively from ``<data_path>/<type>/<mode>``; for each
    mask, the RGB image with the same file name is read from
    ``<data_path>/RGB/<mode>``. ``plane`` is accepted but unused.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training', type='arteriole', prompt='click', plane=False):
        self.data_path = data_path
        # Mask sub-directory, i.e. which vessel class is segmented.
        self.type = type
        self.path_mask = os.path.join(data_path, self.type, mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return 'image', 'label', 'p_label', 'pt' and 'image_meta_dict'.

        Raises:
            ValueError: if ``self.prompt`` is not 'click', 'click_two' or 'auto'.
        """
        msk_path = self.label_list[index]
        name = os.path.basename(msk_path)
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, name)).convert('RGB')
        mask = Image.open(msk_path).convert('L').resize((self.img_size, self.img_size))

        pt, point_label = self._sample_prompt(mask)

        if self.transform:
            state = torch.get_rng_state()
            img = self.transform(img)
            # Rewind so transform_msk replays the image's random draws.
            torch.set_rng_state(state)
        if self.transform_msk:
            mask = self.transform_msk(mask)

        image_meta_dict = {'filename_or_obj': name.split(".jpg")[0]}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }

    def _sample_prompt(self, mask):
        """Return ``(pt, point_label)`` for the configured prompt mode.

        'auto': a uniformly random point in ``[0, img_size - 1]``, the same grid
        as the resized mask the click modes sample from (previously a fixed
        ``[0, 1023]``). 'click': one foreground click (label 1), or a background
        click (label 0) if random_click raises. 'click_two': one of each.

        Raises:
            ValueError: for any other prompt mode.
        """
        if self.prompt == 'auto':
            x = random.randint(0, self.img_size - 1)
            y = random.randint(0, self.img_size - 1)
            return np.array([[x, y]]), np.array([1])
        if self.prompt not in ('click', 'click_two'):
            raise ValueError(f"unsupported prompt mode: {self.prompt!r}")
        mask_arr = np.array(mask) / 255
        if self.prompt == 'click':
            try:
                return np.array([random_click(mask_arr, 1, inout=1)]), np.array([1])
            except Exception:
                return np.array([random_click(mask_arr, 1, inout=0)]), np.array([0])
        pt_pos = random_click(mask_arr, 1, inout=1)
        pt_neg = random_click(mask_arr, 1, inout=0)
        # NOTE(review): a plain list, unlike the array outputs above; kept as-is.
        return [pt_pos, pt_neg], np.array([1, 0])

    def get_dirs_files(self, path):
        """Return every file path under ``path`` (recursively, os.listdir order)."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for filename in os.listdir(current):
                full = os.path.join(current, filename)
                if os.path.isdir(full):
                    pending.append(full)
                else:
                    files.append(full)
        return files
class DualModalNfold(Dataset):
    """N-fold split of G + 570nm + 610nm images with one vessel-class mask.

    Samples are the file names in ``data_list``, read from the fixed root
    ``DualModalSlice128fold``: images from ``RGB``, ``570nm`` and ``610nm``,
    the mask from ``<type>``. ``plane`` is accepted but unused.
    """

    def __init__(self, args, data_list, transform=None, transform_msk=None, type='arteriole', prompt='click', plane=False):
        self.data_path = 'DualModalSlice128fold'
        self.data_list = data_list
        # Mask sub-directory, i.e. which vessel class is segmented.
        self.type = type
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return 'image', 'label', 'p_label', ('pt' unless prompt == 'noprompt'),
        and 'image_meta_dict'.

        Raises:
            ValueError: if ``self.prompt`` is not 'click', 'auto' or 'noprompt'.
        """
        img_name = self.data_list[index]
        img = Image.open(os.path.join(self.data_path, 'RGB', img_name)).convert('RGB')
        img_570 = Image.open(os.path.join(self.data_path, '570nm', img_name)).convert('L')
        img_610 = Image.open(os.path.join(self.data_path, '610nm', img_name)).convert('L')
        mask = Image.open(os.path.join(self.data_path, self.type, img_name)).convert('L')

        # NOTE(review): augmentation is enabled only when this split has exactly
        # 980 samples -- presumably the training fold size; confirm.
        if len(self) == 980:
            img, img_570, img_610, mask = self._random_flip_rotate([img, img_570, img_610, mask])

        mask = mask.resize((self.img_size, self.img_size))
        pt, point_label = self._sample_prompt(mask)

        if self.transform:
            # Same random draws for all three images (and, via the restored
            # RNG state, for transform_msk below).
            img, img_570, img_610 = self._apply_synced(self.transform, [img, img_570, img_610])
        if self.transform_msk:
            mask = self.transform_msk(mask)

        # Keep only channel 1 (G) of the RGB tensor; assumes transform yields
        # (C, H, W) tensors.
        img = torch.cat([img[1, :, :].unsqueeze(0), img_570, img_610], dim=0)

        sample = {'image': img, 'label': mask, 'p_label': point_label}
        if self.prompt != 'noprompt':
            sample['pt'] = pt
        sample['image_meta_dict'] = {'filename_or_obj': img_name}
        return sample

    def _sample_prompt(self, mask):
        """Return ``(pt, point_label)`` for the configured prompt mode.

        'noprompt': ``(None, 1)``. 'auto': a uniformly random point in
        ``[0, img_size - 1]``, the grid of the resized mask (previously a fixed
        ``[0, 1023]``). 'click': one foreground click (label 1), or a background
        click (label 0) if random_click raises.

        Raises:
            ValueError: for any other prompt mode.
        """
        if self.prompt == 'noprompt':
            return None, 1
        if self.prompt == 'auto':
            x = random.randint(0, self.img_size - 1)
            y = random.randint(0, self.img_size - 1)
            return np.array([[x, y]]), np.array([1])
        if self.prompt != 'click':
            raise ValueError(f"unsupported prompt mode: {self.prompt!r}")
        mask_arr = np.array(mask) / 255
        try:
            return np.array([random_click(mask_arr, 1, inout=1)]), np.array([1])
        except Exception:
            return np.array([random_click(mask_arr, 1, inout=0)]), np.array([0])

    @staticmethod
    def _random_flip_rotate(images):
        """Apply one shared random rotation / vertical flip / horizontal flip to all images.

        Each step fires with probability 0.5 (Python ``random``); the rotation
        angle is a random integer in [-15, 15] degrees.
        """
        if random.random() < 0.5:
            angle = random.randint(-15, 15)
            images = [im.rotate(angle) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_LEFT_RIGHT) for im in images]
        return images

    @staticmethod
    def _apply_synced(transform, items):
        """Apply ``transform`` to every item starting from the same torch RNG state.

        The state is captured once, restored before each item and restored
        again at the end, so transforms drawing from torch's global RNG make
        identical random choices for every item and for any later call.
        """
        state = torch.get_rng_state()
        out = []
        for item in items:
            torch.set_rng_state(state)
            out.append(transform(item))
        torch.set_rng_state(state)
        return out
class DualModalNfoldRGB(Dataset):
    """N-fold split of RGB images with one vessel-class mask.

    Samples are the file names in ``data_list``, read from the fixed root
    ``DualModalSlice128fold``: the image from ``RGB``, the mask from ``<type>``.
    ``plane`` is accepted but unused.
    """

    def __init__(self, args, data_list, transform=None, transform_msk=None, type='arteriole', prompt='click', plane=False):
        self.data_path = 'DualModalSlice128fold'
        self.data_list = data_list
        # Mask sub-directory, i.e. which vessel class is segmented.
        self.type = type
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return 'image', 'label', 'p_label', ('pt' unless prompt == 'noprompt'),
        and 'image_meta_dict'.

        Raises:
            ValueError: if ``self.prompt`` is not 'click', 'auto' or 'noprompt'.
        """
        img_name = self.data_list[index]
        img = Image.open(os.path.join(self.data_path, 'RGB', img_name)).convert('RGB')
        mask = Image.open(os.path.join(self.data_path, self.type, img_name)).convert('L')

        # NOTE(review): augmentation is enabled only when this split has exactly
        # 980 samples -- presumably the training fold size; confirm.
        if len(self) == 980:
            img, mask = self._random_flip_rotate([img, mask])

        mask = mask.resize((self.img_size, self.img_size))
        pt, point_label = self._sample_prompt(mask)

        if self.transform:
            state = torch.get_rng_state()
            img = self.transform(img)
            # Rewind so transform_msk replays the image's random draws.
            torch.set_rng_state(state)
        if self.transform_msk:
            mask = self.transform_msk(mask)

        sample = {'image': img, 'label': mask, 'p_label': point_label}
        if self.prompt != 'noprompt':
            sample['pt'] = pt
        sample['image_meta_dict'] = {'filename_or_obj': img_name}
        return sample

    def _sample_prompt(self, mask):
        """Return ``(pt, point_label)`` for the configured prompt mode.

        'noprompt': ``(None, 1)``. 'auto': a uniformly random point in
        ``[0, img_size - 1]``, the grid of the resized mask (previously a fixed
        ``[0, 1023]``). 'click': one foreground click (label 1), or a background
        click (label 0) if random_click raises.

        Raises:
            ValueError: for any other prompt mode.
        """
        if self.prompt == 'noprompt':
            return None, 1
        if self.prompt == 'auto':
            x = random.randint(0, self.img_size - 1)
            y = random.randint(0, self.img_size - 1)
            return np.array([[x, y]]), np.array([1])
        if self.prompt != 'click':
            raise ValueError(f"unsupported prompt mode: {self.prompt!r}")
        mask_arr = np.array(mask) / 255
        try:
            return np.array([random_click(mask_arr, 1, inout=1)]), np.array([1])
        except Exception:
            return np.array([random_click(mask_arr, 1, inout=0)]), np.array([0])

    @staticmethod
    def _random_flip_rotate(images):
        """Apply one shared random rotation / vertical flip / horizontal flip to all images.

        Each step fires with probability 0.5 (Python ``random``); the rotation
        angle is a random integer in [-15, 15] degrees.
        """
        if random.random() < 0.5:
            angle = random.randint(-15, 15)
            images = [im.rotate(angle) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_LEFT_RIGHT) for im in images]
        return images
# Produces several masks (vessel, arteriole, venule) for each sample at once.
class DualModalMultNfoldRGB(Dataset):
    """N-fold split of RGB images with vessel, arteriole and venule masks.

    Samples are the file names in ``data_list``, read from the fixed root
    ``DualModalSlice128fold``: the image from ``RGB`` and the masks from
    ``vessel``, ``arteriole`` and ``venule``.
    """

    def __init__(self, args, data_list, transform=None, transform_msk=None):
        self.data_path = 'DualModalSlice128fold'
        self.data_list = data_list
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return 'image', 'mask_vessel', 'mask_arteriole', 'mask_venule', 'image_meta_dict'."""
        img_name = self.data_list[index]
        img = Image.open(os.path.join(self.data_path, 'RGB', img_name)).convert('RGB')
        mask_vessel = Image.open(os.path.join(self.data_path, 'vessel', img_name)).convert('L')
        mask_arteriole = Image.open(os.path.join(self.data_path, 'arteriole', img_name)).convert('L')
        mask_venule = Image.open(os.path.join(self.data_path, 'venule', img_name)).convert('L')

        # NOTE(review): augmentation is enabled only when this split has exactly
        # 980 samples -- presumably the training fold size; confirm.
        if len(self) == 980:
            img, mask_vessel, mask_arteriole, mask_venule = self._random_flip_rotate(
                [img, mask_vessel, mask_arteriole, mask_venule])

        newsize = (self.img_size, self.img_size)
        mask_vessel = mask_vessel.resize(newsize)
        mask_arteriole = mask_arteriole.resize(newsize)
        mask_venule = mask_venule.resize(newsize)

        if self.transform:
            img, = self._apply_synced(self.transform, [img])
        if self.transform_msk:
            # All three masks start from the same RNG state, so random mask
            # transforms stay aligned with each other and with the image.
            mask_vessel, mask_arteriole, mask_venule = self._apply_synced(
                self.transform_msk, [mask_vessel, mask_arteriole, mask_venule])

        return {
            'image': img,
            'mask_vessel': mask_vessel,
            'mask_arteriole': mask_arteriole,
            'mask_venule': mask_venule,
            'image_meta_dict': {'filename_or_obj': img_name},
        }

    @staticmethod
    def _random_flip_rotate(images):
        """Apply one shared random rotation / vertical flip / horizontal flip to all images.

        Each step fires with probability 0.5 (Python ``random``); the rotation
        angle is a random integer in [-15, 15] degrees.
        """
        if random.random() < 0.5:
            angle = random.randint(-15, 15)
            images = [im.rotate(angle) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_LEFT_RIGHT) for im in images]
        return images

    @staticmethod
    def _apply_synced(transform, items):
        """Apply ``transform`` to every item starting from the same torch RNG state.

        The state is captured once, restored before each item and restored
        again at the end, so transforms drawing from torch's global RNG make
        identical random choices for every item and for any later call.
        """
        state = torch.get_rng_state()
        out = []
        for item in items:
            torch.set_rng_state(state)
            out.append(transform(item))
        torch.set_rng_state(state)
        return out
class DualModalMultNfold3C(Dataset):
    """N-fold split of multi-spectral images with vessel, arteriole and venule masks.

    Samples are the file names in ``data_list``, read from the fixed root
    ``DualModalSlice128fold``: images from ``RGB``, ``570nm`` and ``610nm``,
    masks from ``vessel``, ``arteriole`` and ``venule``.

    Behaviour depends on ``args.net`` (captured at construction):
    * ``'sam_self_with_prompt'``: a click prompt is sampled from the vessel
      mask and returned as 'pt' / 'p_label';
    * anything other than ``'sam_self'``: only channel 1 (G) of the RGB image
      is kept before concatenating the 570nm and 610nm channels.
    """

    def __init__(self, args, data_list, transform=None, transform_msk=None):
        self.data_path = 'DualModalSlice128fold'
        self.data_list = data_list
        self.img_size = args.image_size
        # Taken from the args passed in (previously a module-level global was
        # read in __getitem__), so the dataset follows its own configuration.
        self.net = args.net
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return 'image', the three masks, 'image_meta_dict' and, for
        ``net == 'sam_self_with_prompt'``, 'pt' and 'p_label'."""
        img_name = self.data_list[index]
        root = self.data_path
        img = Image.open(os.path.join(root, 'RGB', img_name)).convert('RGB')
        img_570 = Image.open(os.path.join(root, '570nm', img_name)).convert('L')
        img_610 = Image.open(os.path.join(root, '610nm', img_name)).convert('L')
        mask_vessel = Image.open(os.path.join(root, 'vessel', img_name)).convert('L')
        mask_arteriole = Image.open(os.path.join(root, 'arteriole', img_name)).convert('L')
        mask_venule = Image.open(os.path.join(root, 'venule', img_name)).convert('L')

        # NOTE(review): augmentation is enabled only when this split has exactly
        # 980 samples -- presumably the training fold size; confirm.
        if len(self) == 980:
            img, img_570, img_610, mask_vessel, mask_arteriole, mask_venule = self._random_flip_rotate(
                [img, img_570, img_610, mask_vessel, mask_arteriole, mask_venule])

        newsize = (self.img_size, self.img_size)
        mask_vessel = mask_vessel.resize(newsize)
        mask_arteriole = mask_arteriole.resize(newsize)
        mask_venule = mask_venule.resize(newsize)

        with_prompt = self.net == 'sam_self_with_prompt'
        if with_prompt:
            # Ablation setting: sample the click prompt from the vessel mask.
            pt, point_label = self._sample_click(mask_vessel)

        if self.transform:
            img, img_570, img_610 = self._apply_synced(self.transform, [img, img_570, img_610])
        if self.transform_msk:
            mask_vessel, mask_arteriole, mask_venule = self._apply_synced(
                self.transform_msk, [mask_vessel, mask_arteriole, mask_venule])

        if self.net != 'sam_self':
            # Keep only channel 1 (G) of the RGB tensor; assumes transform
            # yields (C, H, W) tensors.
            # NOTE(review): 'sam_self_with_prompt' also takes this branch --
            # confirm the channel count that model expects.
            img = img[1, :, :].unsqueeze(0)
        img = torch.cat([img, img_570, img_610], dim=0)

        sample = {
            'image': img,
            'mask_vessel': mask_vessel,
            'mask_arteriole': mask_arteriole,
            'mask_venule': mask_venule,
            'image_meta_dict': {'filename_or_obj': img_name},
        }
        if with_prompt:
            sample['pt'] = pt
            sample['p_label'] = point_label
        return sample

    @staticmethod
    def _sample_click(mask):
        """Return ``(pt, point_label)``: a foreground click (label 1), or a
        background click (label 0) if random_click raises."""
        mask_arr = np.array(mask) / 255
        try:
            return np.array([random_click(mask_arr, 1, inout=1)]), np.array([1])
        except Exception:
            return np.array([random_click(mask_arr, 1, inout=0)]), np.array([0])

    @staticmethod
    def _random_flip_rotate(images):
        """Apply one shared random rotation / vertical flip / horizontal flip to all images.

        Each step fires with probability 0.5 (Python ``random``); the rotation
        angle is a random integer in [-15, 15] degrees.
        """
        if random.random() < 0.5:
            angle = random.randint(-15, 15)
            images = [im.rotate(angle) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_LEFT_RIGHT) for im in images]
        return images

    @staticmethod
    def _apply_synced(transform, items):
        """Apply ``transform`` to every item starting from the same torch RNG state.

        The state is captured once, restored before each item and restored
        again at the end, so transforms drawing from torch's global RNG make
        identical random choices for every item and for any later call.
        """
        state = torch.get_rng_state()
        out = []
        for item in items:
            torch.set_rng_state(state)
            out.append(transform(item))
        torch.set_rng_state(state)
        return out
class DualModalMultRGB(Dataset):
    """RGB images with vessel, arteriole and venule masks.

    File names are discovered recursively under ``<data_path>/RGB/<mode>``; the
    image and the masks with the same name are read from
    ``<data_path>/{RGB,vessel,arteriole,venule}/<mode>``. ``plane`` is accepted
    but unused.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Test', plane=False):
        self.data_path = data_path
        self.path_mask = os.path.join(data_path, 'RGB', mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return 'image', 'mask_vessel', 'mask_arteriole', 'mask_venule', 'image_meta_dict'."""
        name = os.path.basename(self.label_list[index])
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, name)).convert('RGB')
        newsize = (self.img_size, self.img_size)
        mask_vessel = Image.open(os.path.join(self.data_path, 'vessel', self.mode, name)).convert('L').resize(newsize)
        mask_arteriole = Image.open(os.path.join(self.data_path, 'arteriole', self.mode, name)).convert('L').resize(newsize)
        mask_venule = Image.open(os.path.join(self.data_path, 'venule', self.mode, name)).convert('L').resize(newsize)

        if self.transform:
            img, = self._apply_synced(self.transform, [img])
        if self.transform_msk:
            # All three masks start from the same RNG state, so random mask
            # transforms stay aligned with each other and with the image.
            mask_vessel, mask_arteriole, mask_venule = self._apply_synced(
                self.transform_msk, [mask_vessel, mask_arteriole, mask_venule])

        return {
            'image': img,
            'mask_vessel': mask_vessel,
            'mask_arteriole': mask_arteriole,
            'mask_venule': mask_venule,
            'image_meta_dict': {'filename_or_obj': name.split(".jpg")[0]},
        }

    @staticmethod
    def _apply_synced(transform, items):
        """Apply ``transform`` to every item starting from the same torch RNG state.

        The state is captured once, restored before each item and restored
        again at the end, so transforms drawing from torch's global RNG make
        identical random choices for every item and for any later call.
        """
        state = torch.get_rng_state()
        out = []
        for item in items:
            torch.set_rng_state(state)
            out.append(transform(item))
        torch.set_rng_state(state)
        return out

    def get_dirs_files(self, path):
        """Return every file path under ``path`` (recursively, os.listdir order)."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for filename in os.listdir(current):
                full = os.path.join(current, filename)
                if os.path.isdir(full):
                    pending.append(full)
                else:
                    files.append(full)
        return files
"""
HRF 数据集
"""
class HRFRGB(Dataset):
    """HRF images with vessel, arteriole and venule masks.

    File names come from a recursive listing of ``<data_path>/label``; for each
    name the image is read from ``<data_path>/images`` and the masks from
    ``<data_path>/label`` (vessel), ``<data_path>/a`` (arteriole) and
    ``<data_path>/v`` (venule).
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None):
        self.data_path = data_path
        self.path_mask = os.path.join(data_path, 'label')
        self.label_list = self.get_dirs_files(self.path_mask)
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return 'image', 'mask_vessel', 'mask_arteriole', 'mask_venule', 'image_meta_dict'."""
        name = os.path.basename(self.label_list[index])
        img = Image.open(os.path.join(self.data_path, 'images', name)).convert('RGB')
        mask_vessel = Image.open(os.path.join(self.data_path, 'label', name)).convert('L')
        mask_arteriole = Image.open(os.path.join(self.data_path, 'a', name)).convert('L')
        mask_venule = Image.open(os.path.join(self.data_path, 'v', name)).convert('L')

        # NOTE(review): substring test on the whole path, so any parent
        # directory whose name contains 'train' also enables augmentation.
        if 'train' in self.data_path:
            img, mask_vessel, mask_arteriole, mask_venule = self._random_flip_rotate(
                [img, mask_vessel, mask_arteriole, mask_venule])

        newsize = (self.img_size, self.img_size)
        mask_vessel = mask_vessel.resize(newsize)
        mask_arteriole = mask_arteriole.resize(newsize)
        mask_venule = mask_venule.resize(newsize)

        if self.transform:
            img, = self._apply_synced(self.transform, [img])
        if self.transform_msk:
            # All three masks start from the same RNG state, so random mask
            # transforms stay aligned with each other and with the image.
            mask_vessel, mask_arteriole, mask_venule = self._apply_synced(
                self.transform_msk, [mask_vessel, mask_arteriole, mask_venule])

        return {
            'image': img,
            'mask_vessel': mask_vessel,
            'mask_arteriole': mask_arteriole,
            'mask_venule': mask_venule,
            'image_meta_dict': {'filename_or_obj': name.split(".jpg")[0]},
        }

    @staticmethod
    def _random_flip_rotate(images):
        """Apply one shared random rotation / vertical flip / horizontal flip to all images.

        Each step fires with probability 0.5 (Python ``random``); the rotation
        angle is a random integer in [-15, 15] degrees.
        """
        if random.random() < 0.5:
            angle = random.randint(-15, 15)
            images = [im.rotate(angle) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in images]
        if random.random() < 0.5:
            images = [im.transpose(Image.FLIP_LEFT_RIGHT) for im in images]
        return images

    @staticmethod
    def _apply_synced(transform, items):
        """Apply ``transform`` to every item starting from the same torch RNG state.

        The state is captured once, restored before each item and restored
        again at the end, so transforms drawing from torch's global RNG make
        identical random choices for every item and for any later call.
        """
        state = torch.get_rng_state()
        out = []
        for item in items:
            torch.set_rng_state(state)
            out.append(transform(item))
        torch.set_rng_state(state)
        return out

    def get_dirs_files(self, path):
        """Return every file path under ``path`` (recursively, os.listdir order)."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for filename in os.listdir(current):
                full = os.path.join(current, filename)
                if os.path.isdir(full):
                    pending.append(full)
                else:
                    files.append(full)
        return files
class DualModalMult3C(Dataset):
    """Five-channel (RGB + 570nm + 610nm) images with vessel, arteriole and venule masks.

    File names are discovered recursively under ``<data_path>/RGB/<mode>``;
    images and masks with the same name are read from
    ``<data_path>/{RGB,570nm,610nm,vessel,arteriole,venule}/<mode>``.
    When ``args.net == 'sam_self_with_prompt'`` (captured at construction) a
    click prompt is sampled from the vessel mask and returned as 'pt' /
    'p_label'. ``plane`` is accepted but unused.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Test', plane=False):
        self.data_path = data_path
        self.path_mask = os.path.join(data_path, 'RGB', mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.img_size = args.image_size
        # Taken from the args passed in (previously a module-level global was
        # read in __getitem__), so the dataset follows its own configuration.
        self.net = args.net
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return 'image', the three masks, 'image_meta_dict' and, for
        ``net == 'sam_self_with_prompt'``, 'pt' and 'p_label'."""
        name = os.path.basename(self.label_list[index])
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, name)).convert('RGB')
        img_570 = Image.open(os.path.join(self.data_path, '570nm', self.mode, name)).convert('L')
        img_610 = Image.open(os.path.join(self.data_path, '610nm', self.mode, name)).convert('L')
        newsize = (self.img_size, self.img_size)
        mask_vessel = Image.open(os.path.join(self.data_path, 'vessel', self.mode, name)).convert('L').resize(newsize)
        mask_arteriole = Image.open(os.path.join(self.data_path, 'arteriole', self.mode, name)).convert('L').resize(newsize)
        mask_venule = Image.open(os.path.join(self.data_path, 'venule', self.mode, name)).convert('L').resize(newsize)

        with_prompt = self.net == 'sam_self_with_prompt'
        if with_prompt:
            # Ablation setting: sample the click prompt from the vessel mask.
            pt, point_label = self._sample_click(mask_vessel)

        if self.transform:
            img, img_570, img_610 = self._apply_synced(self.transform, [img, img_570, img_610])
        if self.transform_msk:
            mask_vessel, mask_arteriole, mask_venule = self._apply_synced(
                self.transform_msk, [mask_vessel, mask_arteriole, mask_venule])

        # All channels are kept: RGB followed by 570nm and 610nm.
        img = torch.cat([img, img_570, img_610], dim=0)

        sample = {
            'image': img,
            'mask_vessel': mask_vessel,
            'mask_arteriole': mask_arteriole,
            'mask_venule': mask_venule,
            'image_meta_dict': {'filename_or_obj': name.split(".jpg")[0]},
        }
        if with_prompt:
            sample['pt'] = pt
            sample['p_label'] = point_label
        return sample

    @staticmethod
    def _sample_click(mask):
        """Return ``(pt, point_label)``: a foreground click (label 1), or a
        background click (label 0) if random_click raises."""
        mask_arr = np.array(mask) / 255
        try:
            return np.array([random_click(mask_arr, 1, inout=1)]), np.array([1])
        except Exception:
            return np.array([random_click(mask_arr, 1, inout=0)]), np.array([0])

    @staticmethod
    def _apply_synced(transform, items):
        """Apply ``transform`` to every item starting from the same torch RNG state.

        The state is captured once, restored before each item and restored
        again at the end, so transforms drawing from torch's global RNG make
        identical random choices for every item and for any later call.
        """
        state = torch.get_rng_state()
        out = []
        for item in items:
            torch.set_rng_state(state)
            out.append(transform(item))
        torch.set_rng_state(state)
        return out

    def get_dirs_files(self, path):
        """Return every file path under ``path`` (recursively, os.listdir order)."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for filename in os.listdir(current):
                full = os.path.join(current, filename)
                if os.path.isdir(full):
                    pending.append(full)
                else:
                    files.append(full)
        return files
class RGBTest(Dataset):
    """Unlabelled RGB images, listed recursively from ``data_path``.

    ``transform_msk`` and ``args.image_size`` are stored but not used by
    ``__getitem__``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None):
        self.label_list = self.get_dirs_files(data_path)
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return ``{'image': ..., 'image_meta_dict': {'filename_or_obj': stem}}``."""
        source = self.label_list[index]
        picture = Image.open(source).convert('RGB')
        if self.transform:
            snapshot = torch.get_rng_state()
            picture = self.transform(picture)
            torch.set_rng_state(snapshot)
        stem = os.path.basename(source).partition(".jpg")[0]
        return {
            'image': picture,
            'image_meta_dict': {'filename_or_obj': stem},
        }

    def get_dirs_files(self, path):
        """Collect all file paths below ``path``, walking sub-directories with an explicit stack."""
        stack = [path]
        collected = []
        while stack:
            folder = stack.pop()
            for entry in os.listdir(folder):
                candidate = os.path.join(folder, entry)
                if not os.path.isdir(candidate):
                    collected.append(candidate)
                    continue
                stack.append(candidate)
        return collected
class MultCTest(Dataset):
    """Unlabelled 3-channel test images: G channel of RGB + 570nm + 610nm.

    RGB images are listed recursively from ``data_path``; the 570nm and 610nm
    images with the same file name are read from ``<aux_root>/570nm`` and
    ``<aux_root>/610nm``. ``aux_root`` defaults to ``TestPic/slice`` (the
    location this class originally hard-coded), so existing callers are
    unaffected.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, aux_root=os.path.join('TestPic', 'slice')):
        self.label_list = self.get_dirs_files(data_path)
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk
        self.aux_root = aux_root

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return ``{'image': <3-channel tensor>, 'image_meta_dict': {...}}``."""
        img_name = self.label_list[index]
        name = os.path.basename(img_name)
        img = Image.open(img_name).convert('RGB')
        img_570 = Image.open(os.path.join(self.aux_root, '570nm', name)).convert('L')
        img_610 = Image.open(os.path.join(self.aux_root, '610nm', name)).convert('L')
        if self.transform:
            # Same random draws for all three images so they stay aligned.
            img, img_570, img_610 = self._apply_synced(self.transform, [img, img_570, img_610])
        # Keep only channel 1 (G) of the RGB tensor; assumes transform yields
        # (C, H, W) tensors.
        img = torch.cat([img[1, :, :].unsqueeze(0), img_570, img_610], dim=0)
        image_meta_dict = {'filename_or_obj': name.split(".jpg")[0]}
        return {
            'image': img,
            'image_meta_dict': image_meta_dict,
        }

    @staticmethod
    def _apply_synced(transform, items):
        """Apply ``transform`` to every item starting from the same torch RNG state.

        The state is captured once, restored before each item and restored
        again at the end, so transforms drawing from torch's global RNG make
        identical random choices for every item and for any later call.
        """
        state = torch.get_rng_state()
        out = []
        for item in items:
            torch.set_rng_state(state)
            out.append(transform(item))
        torch.set_rng_state(state)
        return out

    def get_dirs_files(self, path):
        """Return every file path under ``path`` (recursively, os.listdir order)."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for filename in os.listdir(current):
                full = os.path.join(current, filename)
                if os.path.isdir(full):
                    pending.append(full)
                else:
                    files.append(full)
        return files
class RandomDataset(Dataset):
    """Dataset whose items are bare file names (no directory part) found under ``path``.

    The walk is recursive; ``data_path`` holds the collected names in
    discovery order.
    """

    def __init__(self, path):
        self.data_path = self.get_dirs_files(path)

    def __len__(self):
        return len(self.data_path)

    def __getitem__(self, index):
        return self.data_path[index]

    def get_dirs_files(self, path):
        """Return the base names of all files below ``path``, using an explicit directory stack."""
        stack = [path]
        names = []
        while stack:
            folder = stack.pop()
            for entry in os.listdir(folder):
                if os.path.isdir(os.path.join(folder, entry)):
                    stack.append(os.path.join(folder, entry))
                else:
                    # listdir entries contain no separator, so the entry is
                    # already the base name of the joined path.
                    names.append(entry)
        return names
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/xuhengyuplus/SAM-Adapt.git
[email protected]:xuhengyuplus/SAM-Adapt.git
xuhengyuplus
SAM-Adapt
SAM-Adapt
multout

搜索帮助