代码拉取完成,页面将自动刷新
"""Train and test datasets.

Author: jundewu
"""
import os
import sys
import pickle
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
import cfg
from utils import random_click
import random
from monai.transforms import LoadImaged, Randomizable,LoadImage
# Module-level configuration object, built once when this module is imported.
# NOTE(review): calling cfg.parse_args() at import time gives importing this
# module a side effect (presumably it parses the process command line -- confirm in cfg).
args = cfg.parse_args()
class DualModal(Dataset):
    """RGB image / vessel-mask dataset that also produces a click prompt.

    Images are listed recursively under ``<data_path>/RGB/<mode>`` and masks
    under ``<data_path>/vessel/<mode>``. The i-th image is paired with the i-th
    mask, so both listings are sorted to make that pairing deterministic.
    Corresponding files must therefore sort into the same order (e.g. by
    sharing file names).

    Args:
        args: configuration object; ``args.image_size`` sets the mask size.
        data_path: dataset root directory.
        transform: optional callable applied to the RGB PIL image.
        transform_msk: optional callable applied to the resized mask; only
            used when ``transform`` is also set.
        mode: split sub-directory, e.g. ``'Training'`` or ``'Test'``.
        prompt: prompt type; only ``'click'`` is supported.
        plane: unused; kept for interface compatibility.
        iter: in ``'Training'`` mode, sample the click from the masks under
            ``iter_path`` first, falling back to the ground-truth mask.
        iter_path: directory of alternative masks used when ``iter`` is set.
        only_val: in ``'Test'`` mode, read the click from
            ``label_10_9/<image stem>.txt`` instead of sampling it.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training',
                 prompt='click', plane=False, iter=False, iter_path=None, only_val=False):
        self.path = os.path.join(data_path, 'RGB', mode)
        self.path_mask = os.path.join(data_path, 'vessel', mode)
        if mode == 'Training' and iter:
            # If no click can be sampled from an iteration mask, __getitem__
            # falls back to the original ground-truth mask.
            self.path_mask_2 = iter_path
            self.label_list_2 = self.get_dirs_files(self.path_mask_2)
        self.name_list = self.get_dirs_files(self.path)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.iter = iter
        self.only_val = only_val
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.name_list)

    def __getitem__(self, index):
        """Return a dict with keys image, label, p_label, pt and image_meta_dict.

        Raises:
            ValueError: if ``self.prompt`` is not ``'click'``.
        """
        if self.prompt != 'click':
            raise ValueError(f"unsupported prompt type: {self.prompt!r}")
        point_label = 1
        img_path = self.name_list[index]
        msk_path = self.label_list[index]
        img = Image.open(img_path).convert('RGB')
        mask = Image.open(msk_path).convert('L')
        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)

        if self.mode == 'Training' and self.iter:
            mask_2 = Image.open(self.label_list_2[index]).convert('L')
            mask_2 = mask_2.resize(newsize)
            try:
                pt = random_click(np.array(mask_2) / 255, point_label, inout=1)
            except Exception:
                # No usable click from the iteration mask: use the ground truth.
                pt = random_click(np.array(mask) / 255, point_label, inout=1)
        elif self.mode == 'Test' and self.only_val:
            # Use a separate variable so the sample's file name is not clobbered.
            stem = os.path.basename(img_path).split(".png")[0]
            label_file = os.path.join('label_10_9', stem + '.txt')
            with open(label_file, 'r') as f:
                lines = f.readlines()
            # NOTE(review): the second line presumably holds normalised
            # coordinates; they are scaled by a fixed 1024, not img_size -- confirm.
            pt = [float(value) * 1024 for value in lines[1].split()]
        else:
            pt = random_click(np.array(mask) / 255, point_label, inout=1)
        pt = np.array([pt])
        point_label = np.array([1])

        if self.transform:
            # Restore the RNG so the mask transform sees the same random draws.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)
            if self.transform_msk:
                mask = self.transform_msk(mask)

        image_meta_dict = {'filename_or_obj': os.path.basename(img_path).split(".jpg")[0]}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }

    def get_dirs_files(self, path):
        """Return every file path below ``path`` (recursively), sorted.

        ``os.listdir`` returns entries in arbitrary order. Sorting keeps
        the position-based image/mask pairing in ``__getitem__`` consistent.
        """
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for entry in os.listdir(current):
                full_path = os.path.join(current, entry)
                if os.path.isdir(full_path):
                    pending.append(full_path)
                else:
                    files.append(full_path)
        return sorted(files)
class DualModal3D(Dataset):
    """Multi-modal (RGB + 570 nm + 610 nm) vessel dataset with click prompts.

    One sample per mask file found recursively under ``<data_path>/<type>/<mode>``.
    For each mask, the image with the same file name is loaded from the
    ``RGB``, ``570nm`` and ``610nm`` sub-directories (each with a ``<mode>``
    level). The returned ``image`` is the transformed RGB, 570 nm and 610 nm
    images concatenated along dim 0.

    Args:
        args: configuration object; ``args.image_size`` sets the mask size.
        data_path: dataset root directory.
        transform: callable applied to every image modality. Its results are
            concatenated with ``torch.cat``, so it is expected to return tensors.
        transform_msk: optional callable applied to the resized mask; only
            used when ``transform`` is also set.
        mode: split sub-directory, e.g. ``'Training'``.
        type: mask sub-directory selecting the target vessels, e.g. ``'arteriole'``.
        prompt: ``'click'`` (one point) or ``'click_two'`` (positive + negative).
        plane: unused; kept for interface compatibility.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training',
                 type='arteriole', prompt='click', plane=False):
        self.data_path = data_path
        self.type = type
        self.path_mask = os.path.join(data_path, self.type, mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def _apply_transform_synced(self, *images):
        """Apply ``self.transform`` to each image, starting from the same torch RNG state.

        Every modality gets identical random augmentation parameters, so they
        stay spatially aligned. The RNG state is restored on exit, so the mask
        transform that follows starts from the same state as well.
        """
        state = torch.get_rng_state()
        outputs = []
        for image in images:
            torch.set_rng_state(state)
            outputs.append(self.transform(image))
        torch.set_rng_state(state)
        return outputs

    def __getitem__(self, index):
        """Return a dict with keys image, label, p_label, pt and image_meta_dict.

        Raises:
            ValueError: if ``self.prompt`` is neither ``'click'`` nor ``'click_two'``.
        """
        if self.prompt not in ('click', 'click_two'):
            raise ValueError(f"unsupported prompt type: {self.prompt!r}")
        point_label = 1
        msk_path = self.label_list[index]
        name = os.path.basename(msk_path)
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, name)).convert('RGB')
        img_570 = Image.open(os.path.join(self.data_path, '570nm', self.mode, name)).convert('L')
        img_610 = Image.open(os.path.join(self.data_path, '610nm', self.mode, name)).convert('L')
        mask = Image.open(msk_path).convert('L')
        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)

        if self.prompt == 'click':
            try:
                pt = random_click(np.array(mask) / 255, point_label, inout=1)
                pt = np.array([pt])
                point_label = np.array([1])
            except Exception:
                # NOTE(review): presumably random_click fails when the mask has no
                # foreground; a negative (label 0) click is sampled instead.
                pt = random_click(np.array(mask) / 255, point_label, inout=0)
                pt = np.array([pt])
                point_label = np.array([0])
        else:  # 'click_two'
            pt_pos = random_click(np.array(mask) / 255, point_label, inout=1)
            pt_nav = random_click(np.array(mask) / 255, point_label, inout=0)
            pt = [pt_pos, pt_nav]
            point_label = np.array([1, 0])

        if self.transform:
            img, img_570, img_610 = self._apply_transform_synced(img, img_570, img_610)
            if self.transform_msk:
                mask = self.transform_msk(mask)
        img = torch.cat([img, img_570, img_610], dim=0)

        # ``name`` is already a base name, so only the extension is stripped.
        image_meta_dict = {'filename_or_obj': name.split(".jpg")[0]}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }

    def get_dirs_files(self, path):
        """Return every file path below ``path``, descending into sub-directories."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for entry in os.listdir(current):
                full_path = os.path.join(current, entry)
                if os.path.isdir(full_path):
                    pending.append(full_path)
                else:
                    files.append(full_path)
        return files
class DualModal3C(Dataset):
    """Three-channel multi-modal vessel dataset (RGB channel 1 + 570 nm + 610 nm).

    One sample per mask file found recursively under ``<data_path>/<type>/<mode>``.
    For each mask, the image with the same file name is loaded from the
    ``RGB``, ``570nm`` and ``610nm`` sub-directories (each with a ``<mode>``
    level). The returned ``image`` stacks channel 1 of the transformed RGB
    image with the transformed 570 nm and 610 nm images.

    Args:
        args: configuration object; ``args.image_size`` sets the mask size.
        data_path: dataset root directory.
        transform: callable applied to every image modality. It is expected to
            return tensors, because the results are indexed and concatenated.
        transform_msk: optional callable applied to the resized mask; only
            used when ``transform`` is also set.
        mode: split sub-directory, e.g. ``'Training'``.
        type: mask sub-directory selecting the target vessels, e.g. ``'arteriole'``.
        prompt: ``'click'``, ``'click_two'`` or ``'auto'`` (random point).
        plane: unused; kept for interface compatibility.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training',
                 type='arteriole', prompt='click', plane=False):
        self.data_path = data_path
        self.type = type
        self.path_mask = os.path.join(data_path, self.type, mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def _apply_transform_synced(self, *images):
        """Apply ``self.transform`` to each image, starting from the same torch RNG state.

        Every modality gets identical random augmentation parameters. The RNG
        state is restored on exit, so the mask transform starts from it too.
        """
        state = torch.get_rng_state()
        outputs = []
        for image in images:
            torch.set_rng_state(state)
            outputs.append(self.transform(image))
        torch.set_rng_state(state)
        return outputs

    def __getitem__(self, index):
        """Return a dict with keys image, label, p_label, pt and image_meta_dict.

        Raises:
            ValueError: if ``self.prompt`` is not one of the supported types.
        """
        if self.prompt not in ('click', 'click_two', 'auto'):
            raise ValueError(f"unsupported prompt type: {self.prompt!r}")
        point_label = 1
        msk_path = self.label_list[index]
        name = os.path.basename(msk_path)
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, name)).convert('RGB')
        img_570 = Image.open(os.path.join(self.data_path, '570nm', self.mode, name)).convert('L')
        img_610 = Image.open(os.path.join(self.data_path, '610nm', self.mode, name)).convert('L')
        mask = Image.open(msk_path).convert('L')
        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)

        if self.prompt == 'click':
            try:
                pt = random_click(np.array(mask) / 255, point_label, inout=1)
                pt = np.array([pt])
                point_label = np.array([1])
            except Exception:
                # NOTE(review): presumably random_click fails when the mask has no
                # foreground; a negative (label 0) click is sampled instead.
                pt = random_click(np.array(mask) / 255, point_label, inout=0)
                pt = np.array([pt])
                point_label = np.array([0])
        elif self.prompt == 'click_two':
            pt_pos = random_click(np.array(mask) / 255, point_label, inout=1)
            pt_nav = random_click(np.array(mask) / 255, point_label, inout=0)
            pt = [pt_pos, pt_nav]
            point_label = np.array([1, 0])
        else:  # 'auto'
            # NOTE(review): the range is fixed to 0..1023, not img_size -- confirm.
            x = random.randint(0, 1023)
            y = random.randint(0, 1023)
            pt = np.array([[x, y]])
            point_label = np.array([1])

        if self.transform:
            img, img_570, img_610 = self._apply_transform_synced(img, img_570, img_610)
            if self.transform_msk:
                mask = self.transform_msk(mask)
        # Keep only channel 1 of the RGB tensor (G, assuming the transform keeps
        # RGB channel order), then stack it with the two single-band images.
        img = img[1, :, :].unsqueeze(0)
        img = torch.cat([img, img_570, img_610], dim=0)

        image_meta_dict = {'filename_or_obj': name.split(".jpg")[0]}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }

    def get_dirs_files(self, path):
        """Return every file path below ``path``, descending into sub-directories."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for entry in os.listdir(current):
                full_path = os.path.join(current, entry)
                if os.path.isdir(full_path):
                    pending.append(full_path)
                else:
                    files.append(full_path)
        return files
class DualModalRGB(Dataset):
    """RGB-only vessel dataset with click prompts.

    One sample per mask file found recursively under ``<data_path>/<type>/<mode>``.
    The RGB image with the same file name is read from ``<data_path>/RGB/<mode>``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Training',
                 type='arteriole', prompt='click', plane=False):
        self.data_path = data_path
        # Mask family to segment, e.g. 'arteriole' (default) or 'venule'.
        self.type = type
        self.path_mask = os.path.join(data_path, self.type, mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return a dict with keys image, label, p_label, pt and image_meta_dict."""
        point_label = 1
        mask_file = self.label_list[index]
        base = os.path.basename(mask_file)
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, base)).convert('RGB')
        side = self.img_size
        mask = Image.open(mask_file).convert('L').resize((side, side))

        if self.prompt == 'click':
            try:
                pt = np.array([random_click(np.array(mask) / 255, point_label, inout=1)])
                point_label = np.array([1])
            except Exception:
                # Fall back to a negative (label 0) click.
                pt = np.array([random_click(np.array(mask) / 255, point_label, inout=0)])
                point_label = np.array([0])
        elif self.prompt == 'click_two':
            positive = random_click(np.array(mask) / 255, point_label, inout=1)
            negative = random_click(np.array(mask) / 255, point_label, inout=0)
            pt = [positive, negative]
            point_label = np.array([1, 0])
        elif self.prompt == 'auto':
            pt = np.array([[random.randint(0, 1023), random.randint(0, 1023)]])
            point_label = np.array([1])

        if self.transform:
            saved_state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(saved_state)
            if self.transform_msk:
                mask = self.transform_msk(mask)

        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': {'filename_or_obj': base.partition(".jpg")[0]},
        }

    def get_dirs_files(self, path):
        """Collect all file paths under ``path``, walking sub-directories depth-first."""
        found = []
        stack = [path]
        while stack:
            folder = stack.pop()
            for entry in os.listdir(folder):
                candidate = os.path.join(folder, entry)
                (stack if os.path.isdir(candidate) else found).append(candidate)
        return found
class DualModalNfold(Dataset):
    """Multi-modal fold dataset read from the fixed ``DualModalSlice128fold`` root.

    ``data_list`` holds file names. Each sample loads ``RGB/<name>``,
    ``570nm/<name>``, ``610nm/<name>`` and the ``<type>/<name>`` mask under
    that root. The returned image stacks channel 1 of the transformed RGB
    image with the transformed 570 nm and 610 nm images.

    Args:
        args: configuration object; ``args.image_size`` sets the mask size.
        data_list: file names making up this fold.
        transform: callable applied to every image modality; expected to return tensors.
        transform_msk: optional callable applied to the resized mask; only
            used when ``transform`` is also set.
        type: mask sub-directory, e.g. ``'arteriole'``.
        prompt: ``'click'``, ``'auto'`` (random point) or ``'noprompt'``.
        plane: unused; kept for interface compatibility.
    """

    def __init__(self, args, data_list, transform=None, transform_msk=None, type='arteriole',
                 prompt='click', plane=False):
        self.data_path = 'DualModalSlice128fold'
        self.data_list = data_list
        self.type = type
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.data_list)

    def _apply_transform_synced(self, *images):
        """Apply ``self.transform`` to each image, starting from the same torch RNG state.

        Every modality gets identical random augmentation parameters. The RNG
        state is restored on exit, so the mask transform starts from it too.
        """
        state = torch.get_rng_state()
        outputs = []
        for image in images:
            torch.set_rng_state(state)
            outputs.append(self.transform(image))
        torch.set_rng_state(state)
        return outputs

    def __getitem__(self, index):
        """Return a sample dict; ``pt`` is omitted when ``prompt == 'noprompt'``.

        Raises:
            ValueError: if ``self.prompt`` is not one of the supported types.
        """
        if self.prompt not in ('click', 'auto', 'noprompt'):
            raise ValueError(f"unsupported prompt type: {self.prompt!r}")
        point_label = 1
        img_name = self.data_list[index]
        img = Image.open(os.path.join(self.data_path, 'RGB', img_name)).convert('RGB')
        img_570 = Image.open(os.path.join(self.data_path, '570nm', img_name)).convert('L')
        img_610 = Image.open(os.path.join(self.data_path, '610nm', img_name)).convert('L')
        mask = Image.open(os.path.join(self.data_path, self.type, img_name)).convert('L')

        # NOTE(review): augmentation is enabled only when the fold holds exactly
        # 980 items -- presumably the training fold size; fragile, confirm.
        if len(self) == 980:
            # Rotate by a random angle in [-15, 15] degrees with probability 0.5.
            if random.random() < 0.5:
                angle = random.randint(-15, 15)
                img = img.rotate(angle)
                img_570 = img_570.rotate(angle)
                img_610 = img_610.rotate(angle)
                mask = mask.rotate(angle)
            # Flip top-bottom with probability 0.5.
            if random.random() < 0.5:
                img = img.transpose(Image.FLIP_TOP_BOTTOM)
                img_570 = img_570.transpose(Image.FLIP_TOP_BOTTOM)
                img_610 = img_610.transpose(Image.FLIP_TOP_BOTTOM)
                mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
            # Flip left-right with probability 0.5.
            if random.random() < 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                img_570 = img_570.transpose(Image.FLIP_LEFT_RIGHT)
                img_610 = img_610.transpose(Image.FLIP_LEFT_RIGHT)
                mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        newsize = (self.img_size, self.img_size)
        mask = mask.resize(newsize)

        if self.prompt == 'click':
            try:
                pt = random_click(np.array(mask) / 255, point_label, inout=1)
                pt = np.array([pt])
                point_label = np.array([1])
            except Exception:
                # Fall back to a negative (label 0) click.
                pt = random_click(np.array(mask) / 255, point_label, inout=0)
                pt = np.array([pt])
                point_label = np.array([0])
        elif self.prompt == 'auto':
            # NOTE(review): the range is fixed to 0..1023, not img_size -- confirm.
            x = random.randint(0, 1023)
            y = random.randint(0, 1023)
            pt = np.array([[x, y]])
            point_label = np.array([1])

        if self.transform:
            img, img_570, img_610 = self._apply_transform_synced(img, img_570, img_610)
            if self.transform_msk:
                mask = self.transform_msk(mask)
        img = img[1, :, :].unsqueeze(0)
        img = torch.cat([img, img_570, img_610], dim=0)

        image_meta_dict = {'filename_or_obj': img_name}
        if self.prompt == 'noprompt':
            return {
                'image': img,
                'label': mask,
                'p_label': point_label,
                'image_meta_dict': image_meta_dict,
            }
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }
class DualModalNfoldRGB(Dataset):
    """RGB-only fold dataset read from the fixed ``DualModalSlice128fold`` root.

    ``data_list`` holds file names. Each sample loads ``RGB/<name>`` and the
    ``<type>/<name>`` mask beneath that root.
    """

    def __init__(self, args, data_list, transform=None, transform_msk=None, type='arteriole',
                 prompt='click', plane=False):
        self.data_path = 'DualModalSlice128fold'
        self.data_list = data_list
        # Which vessel mask to segment, e.g. 'arteriole' or 'venule'.
        self.type = type
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return a sample dict; ``pt`` is included unless ``prompt == 'noprompt'``."""
        point_label = 1
        img_name = self.data_list[index]
        root = self.data_path
        img = Image.open(os.path.join(root, 'RGB', img_name)).convert('RGB')
        mask = Image.open(os.path.join(root, self.type, img_name)).convert('L')

        # NOTE(review): augmentation runs only when the fold holds exactly 980
        # items -- presumably the training fold size; confirm with the caller.
        if len(self) == 980:
            if random.random() < 0.5:
                angle = random.randint(-15, 15)
                img, mask = img.rotate(angle), mask.rotate(angle)
            # Top-bottom, then left-right flip, each with probability 0.5.
            for flip in (Image.FLIP_TOP_BOTTOM, Image.FLIP_LEFT_RIGHT):
                if random.random() < 0.5:
                    img, mask = img.transpose(flip), mask.transpose(flip)

        mask = mask.resize((self.img_size, self.img_size))

        if self.prompt == 'click':
            try:
                pt = np.array([random_click(np.array(mask) / 255, point_label, inout=1)])
                point_label = np.array([1])
            except Exception:
                pt = np.array([random_click(np.array(mask) / 255, point_label, inout=0)])
                point_label = np.array([0])
        elif self.prompt == 'auto':
            pt = np.array([[random.randint(0, 1023), random.randint(0, 1023)]])
            point_label = np.array([1])

        if self.transform:
            saved_state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(saved_state)
            if self.transform_msk:
                mask = self.transform_msk(mask)

        sample = {'image': img, 'label': mask, 'p_label': point_label}
        if self.prompt != 'noprompt':
            sample['pt'] = pt
        sample['image_meta_dict'] = {'filename_or_obj': img_name}
        return sample
# Produces several masks (vessel, arteriole, venule) per sample.
class DualModalMultNfoldRGB(Dataset):
    """RGB fold dataset returning the vessel, arteriole and venule masks together.

    Files are read from the fixed ``DualModalSlice128fold`` root. ``data_list``
    holds file names, and each mask is looked up under its own sub-directory.
    """

    def __init__(self, args, data_list, transform=None, transform_msk=None):
        self.data_path = 'DualModalSlice128fold'
        self.data_list = data_list
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return a dict with keys image, mask_vessel, mask_arteriole, mask_venule, image_meta_dict."""
        img_name = self.data_list[index]
        root = self.data_path
        img = Image.open(os.path.join(root, 'RGB', img_name)).convert('RGB')
        mask_sources = (('mask_vessel', 'vessel'), ('mask_arteriole', 'arteriole'), ('mask_venule', 'venule'))
        masks = {
            key: Image.open(os.path.join(root, folder, img_name)).convert('L')
            for key, folder in mask_sources
        }

        # NOTE(review): augmentation runs only for a 980-item fold -- presumably
        # the training fold; confirm with the caller.
        if len(self) == 980:
            if random.random() < 0.5:
                angle = random.randint(-15, 15)
                img = img.rotate(angle)
                masks = {key: m.rotate(angle) for key, m in masks.items()}
            # Top-bottom, then left-right flip, each with probability 0.5.
            for flip in (Image.FLIP_TOP_BOTTOM, Image.FLIP_LEFT_RIGHT):
                if random.random() < 0.5:
                    img = img.transpose(flip)
                    masks = {key: m.transpose(flip) for key, m in masks.items()}

        side = self.img_size
        masks = {key: m.resize((side, side)) for key, m in masks.items()}

        if self.transform:
            saved_state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(saved_state)
            if self.transform_msk:
                masks = {key: self.transform_msk(m) for key, m in masks.items()}

        return {'image': img, **masks, 'image_meta_dict': {'filename_or_obj': img_name}}
class DualModalMultNfold3C(Dataset):
    """Multi-modal fold dataset returning vessel, arteriole and venule masks together.

    Files are read from the fixed ``DualModalSlice128fold`` root. ``data_list``
    holds file names.

    Behaviour depends on ``args.net``:

    - ``'sam_self_with_prompt'`` (ablation setting): a click prompt is sampled
      from the vessel mask and returned as ``pt`` / ``p_label``.
    - any value other than ``'sam_self'``: only channel 1 of the transformed
      RGB image is kept before stacking with the 570 nm / 610 nm images.
    """

    def __init__(self, args, data_list, transform=None, transform_msk=None):
        self.data_path = 'DualModalSlice128fold'
        self.data_list = data_list
        self.img_size = args.image_size
        # Network name from the ``args`` passed in, not from module-level state.
        self.net = getattr(args, 'net', None)
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.data_list)

    def _apply_transform_synced(self, *images):
        """Apply ``self.transform`` to each image, starting from the same torch RNG state.

        Every modality gets identical random augmentation parameters. The RNG
        state is restored on exit, so the mask transforms start from it too.
        """
        state = torch.get_rng_state()
        outputs = []
        for image in images:
            torch.set_rng_state(state)
            outputs.append(self.transform(image))
        torch.set_rng_state(state)
        return outputs

    def __getitem__(self, index):
        """Return the sample dict; ``pt``/``p_label`` only in the prompt ablation."""
        img_name = self.data_list[index]
        img = Image.open(os.path.join(self.data_path, 'RGB', img_name)).convert('RGB')
        img_570 = Image.open(os.path.join(self.data_path, '570nm', img_name)).convert('L')
        img_610 = Image.open(os.path.join(self.data_path, '610nm', img_name)).convert('L')
        mask_vessel = Image.open(os.path.join(self.data_path, 'vessel', img_name)).convert('L')
        mask_arteriole = Image.open(os.path.join(self.data_path, 'arteriole', img_name)).convert('L')
        mask_venule = Image.open(os.path.join(self.data_path, 'venule', img_name)).convert('L')

        # NOTE(review): augmentation runs only for a 980-item fold -- presumably
        # the training fold; fragile, confirm with the caller.
        if len(self) == 980:
            # Rotate by a random angle in [-15, 15] degrees with probability 0.5.
            if random.random() < 0.5:
                angle = random.randint(-15, 15)
                img = img.rotate(angle)
                img_570 = img_570.rotate(angle)
                img_610 = img_610.rotate(angle)
                mask_vessel = mask_vessel.rotate(angle)
                mask_arteriole = mask_arteriole.rotate(angle)
                mask_venule = mask_venule.rotate(angle)
            # Flip top-bottom with probability 0.5.
            if random.random() < 0.5:
                img = img.transpose(Image.FLIP_TOP_BOTTOM)
                img_570 = img_570.transpose(Image.FLIP_TOP_BOTTOM)
                img_610 = img_610.transpose(Image.FLIP_TOP_BOTTOM)
                mask_vessel = mask_vessel.transpose(Image.FLIP_TOP_BOTTOM)
                mask_arteriole = mask_arteriole.transpose(Image.FLIP_TOP_BOTTOM)
                mask_venule = mask_venule.transpose(Image.FLIP_TOP_BOTTOM)
            # Flip left-right with probability 0.5.
            if random.random() < 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                img_570 = img_570.transpose(Image.FLIP_LEFT_RIGHT)
                img_610 = img_610.transpose(Image.FLIP_LEFT_RIGHT)
                mask_vessel = mask_vessel.transpose(Image.FLIP_LEFT_RIGHT)
                mask_arteriole = mask_arteriole.transpose(Image.FLIP_LEFT_RIGHT)
                mask_venule = mask_venule.transpose(Image.FLIP_LEFT_RIGHT)

        newsize = (self.img_size, self.img_size)
        mask_vessel = mask_vessel.resize(newsize)
        mask_arteriole = mask_arteriole.resize(newsize)
        mask_venule = mask_venule.resize(newsize)

        with_prompt = self.net == 'sam_self_with_prompt'
        point_label = 1
        if with_prompt:
            # Prompt for the ablation study, sampled from the vessel mask.
            try:
                pt = random_click(np.array(mask_vessel) / 255, point_label, inout=1)
                pt = np.array([pt])
                point_label = np.array([1])
            except Exception:
                pt = random_click(np.array(mask_vessel) / 255, point_label, inout=0)
                pt = np.array([pt])
                point_label = np.array([0])

        if self.transform:
            img, img_570, img_610 = self._apply_transform_synced(img, img_570, img_610)
            if self.transform_msk:
                mask_vessel = self.transform_msk(mask_vessel)
                mask_arteriole = self.transform_msk(mask_arteriole)
                mask_venule = self.transform_msk(mask_venule)

        image_meta_dict = {'filename_or_obj': img_name}
        if self.net != 'sam_self':
            img = img[1, :, :].unsqueeze(0)
        # NOTE(review): concatenation is unconditional, so 'sam_self' receives the
        # full RGB plus both bands -- matches DualModalMult3C; confirm intent.
        img = torch.cat([img, img_570, img_610], dim=0)

        sample = {
            'image': img,
            'mask_vessel': mask_vessel,
            'mask_arteriole': mask_arteriole,
            'mask_venule': mask_venule,
            'image_meta_dict': image_meta_dict,
        }
        if with_prompt:
            sample['pt'] = pt
            sample['p_label'] = point_label
        return sample
class DualModalMultRGB(Dataset):
    """RGB dataset returning vessel, arteriole and venule masks for each image.

    Samples are enumerated from the RGB images found recursively under
    ``<data_path>/RGB/<mode>``. The three masks are looked up by the same file
    name in ``<data_path>/{vessel,arteriole,venule}/<mode>``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Test', plane=False):
        self.data_path = data_path
        # Despite the name, this lists the RGB images; masks are matched by file name.
        self.path_mask = os.path.join(data_path, 'RGB', mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return a dict with keys image, mask_vessel, mask_arteriole, mask_venule, image_meta_dict."""
        base = os.path.basename(self.label_list[index])
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, base)).convert('RGB')
        side = self.img_size
        masks = {}
        for key, folder in (('mask_vessel', 'vessel'), ('mask_arteriole', 'arteriole'), ('mask_venule', 'venule')):
            opened = Image.open(os.path.join(self.data_path, folder, self.mode, base)).convert('L')
            masks[key] = opened.resize((side, side))

        if self.transform:
            saved_state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(saved_state)
            if self.transform_msk:
                masks = {key: self.transform_msk(m) for key, m in masks.items()}

        return {'image': img, **masks, 'image_meta_dict': {'filename_or_obj': base.partition(".jpg")[0]}}

    def get_dirs_files(self, path):
        """Collect all file paths under ``path``, walking sub-directories depth-first."""
        found = []
        stack = [path]
        while stack:
            folder = stack.pop()
            for entry in os.listdir(folder):
                candidate = os.path.join(folder, entry)
                (stack if os.path.isdir(candidate) else found).append(candidate)
        return found
# HRF dataset
class HRFRGB(Dataset):
    """RGB dataset laid out as ``images/``, ``label/`` (vessel), ``a/`` and ``v/`` masks.

    One sample per file in ``<data_path>/label``. The image and the arteriole
    (``a``) / venule (``v``) masks share that file name. Random rotation and
    flips are applied when ``'train'`` occurs in ``data_path``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None):
        self.data_path = data_path
        # Samples are enumerated from the vessel label directory.
        self.path_mask = os.path.join(data_path, 'label')
        self.label_list = self.get_dirs_files(self.path_mask)
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return a dict with keys image, mask_vessel, mask_arteriole, mask_venule, image_meta_dict."""
        base = os.path.basename(self.label_list[index])
        img = Image.open(os.path.join(self.data_path, 'images', base)).convert('RGB')
        masks = {
            key: Image.open(os.path.join(self.data_path, folder, base)).convert('L')
            for key, folder in (('mask_vessel', 'label'), ('mask_arteriole', 'a'), ('mask_venule', 'v'))
        }

        if 'train' in self.data_path:
            if random.random() < 0.5:
                angle = random.randint(-15, 15)
                img = img.rotate(angle)
                masks = {key: m.rotate(angle) for key, m in masks.items()}
            # Top-bottom, then left-right flip, each with probability 0.5.
            for flip in (Image.FLIP_TOP_BOTTOM, Image.FLIP_LEFT_RIGHT):
                if random.random() < 0.5:
                    img = img.transpose(flip)
                    masks = {key: m.transpose(flip) for key, m in masks.items()}

        side = self.img_size
        masks = {key: m.resize((side, side)) for key, m in masks.items()}

        if self.transform:
            saved_state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(saved_state)
            if self.transform_msk:
                masks = {key: self.transform_msk(m) for key, m in masks.items()}

        return {'image': img, **masks, 'image_meta_dict': {'filename_or_obj': base.partition(".jpg")[0]}}

    def get_dirs_files(self, path):
        """Collect all file paths under ``path``, walking sub-directories depth-first."""
        found = []
        stack = [path]
        while stack:
            folder = stack.pop()
            for entry in os.listdir(folder):
                candidate = os.path.join(folder, entry)
                (stack if os.path.isdir(candidate) else found).append(candidate)
        return found
class DualModalMult3C(Dataset):
    """Multi-modal dataset returning vessel, arteriole and venule masks together.

    Samples are enumerated from the RGB images found recursively under
    ``<data_path>/RGB/<mode>``. The 570 nm / 610 nm images and the three masks
    are loaded by the same file name from ``<data_path>/<folder>/<mode>``. The
    returned image is the transformed RGB, 570 nm and 610 nm images
    concatenated along dim 0.

    When ``args.net == 'sam_self_with_prompt'`` (ablation setting), a click
    prompt sampled from the vessel mask is returned as ``pt`` / ``p_label``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None, mode='Test', plane=False):
        self.data_path = data_path
        # Lists the RGB images; masks and other bands are matched by file name.
        self.path_mask = os.path.join(data_path, 'RGB', mode)
        self.label_list = self.get_dirs_files(self.path_mask)
        self.mode = mode
        self.img_size = args.image_size
        # Network name from the ``args`` passed in, not from module-level state.
        self.net = getattr(args, 'net', None)
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def _apply_transform_synced(self, *images):
        """Apply ``self.transform`` to each image, starting from the same torch RNG state.

        Every modality gets identical random augmentation parameters. The RNG
        state is restored on exit, so the mask transforms start from it too.
        """
        state = torch.get_rng_state()
        outputs = []
        for image in images:
            torch.set_rng_state(state)
            outputs.append(self.transform(image))
        torch.set_rng_state(state)
        return outputs

    def __getitem__(self, index):
        """Return the sample dict; ``pt``/``p_label`` only in the prompt ablation."""
        name = os.path.basename(self.label_list[index])
        img = Image.open(os.path.join(self.data_path, 'RGB', self.mode, name)).convert('RGB')
        img_570 = Image.open(os.path.join(self.data_path, '570nm', self.mode, name)).convert('L')
        img_610 = Image.open(os.path.join(self.data_path, '610nm', self.mode, name)).convert('L')
        mask_vessel = Image.open(os.path.join(self.data_path, 'vessel', self.mode, name)).convert('L')
        mask_arteriole = Image.open(os.path.join(self.data_path, 'arteriole', self.mode, name)).convert('L')
        mask_venule = Image.open(os.path.join(self.data_path, 'venule', self.mode, name)).convert('L')
        newsize = (self.img_size, self.img_size)
        mask_vessel = mask_vessel.resize(newsize)
        mask_arteriole = mask_arteriole.resize(newsize)
        mask_venule = mask_venule.resize(newsize)

        with_prompt = self.net == 'sam_self_with_prompt'
        point_label = 1
        if with_prompt:
            # Prompt for the ablation study, sampled from the vessel mask.
            try:
                pt = random_click(np.array(mask_vessel) / 255, point_label, inout=1)
                pt = np.array([pt])
                point_label = np.array([1])
            except Exception:
                pt = random_click(np.array(mask_vessel) / 255, point_label, inout=0)
                pt = np.array([pt])
                point_label = np.array([0])

        if self.transform:
            img, img_570, img_610 = self._apply_transform_synced(img, img_570, img_610)
            if self.transform_msk:
                mask_vessel = self.transform_msk(mask_vessel)
                mask_arteriole = self.transform_msk(mask_arteriole)
                mask_venule = self.transform_msk(mask_venule)
        img = torch.cat([img, img_570, img_610], dim=0)

        image_meta_dict = {'filename_or_obj': name.split(".jpg")[0]}
        sample = {
            'image': img,
            'mask_vessel': mask_vessel,
            'mask_arteriole': mask_arteriole,
            'mask_venule': mask_venule,
            'image_meta_dict': image_meta_dict,
        }
        if with_prompt:
            sample['pt'] = pt
            sample['p_label'] = point_label
        return sample

    def get_dirs_files(self, path):
        """Return every file path below ``path``, descending into sub-directories."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for entry in os.listdir(current):
                full_path = os.path.join(current, entry)
                if os.path.isdir(full_path):
                    pending.append(full_path)
                else:
                    files.append(full_path)
        return files
class RGBTest(Dataset):
    """Unlabelled RGB images found recursively under ``data_path`` (for inference)."""

    def __init__(self, args, data_path, transform=None, transform_msk=None):
        self.label_list = self.get_dirs_files(data_path)
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, index):
        """Return a dict with keys image and image_meta_dict."""
        source = self.label_list[index]
        img = Image.open(source).convert('RGB')
        if self.transform:
            saved_state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(saved_state)
        stem = os.path.basename(source).partition(".jpg")[0]
        return {'image': img, 'image_meta_dict': {'filename_or_obj': stem}}

    def get_dirs_files(self, path):
        """Collect all file paths under ``path``, walking sub-directories depth-first."""
        found = []
        stack = [path]
        while stack:
            folder = stack.pop()
            for entry in os.listdir(folder):
                candidate = os.path.join(folder, entry)
                (stack if os.path.isdir(candidate) else found).append(candidate)
        return found
class MultCTest(Dataset):
    """Unlabelled three-channel test images (RGB channel 1 + 570 nm + 610 nm).

    RGB images are listed recursively under ``data_path``. The 570 nm and
    610 nm images with the same file name are read from
    ``<modal_root>/570nm`` and ``<modal_root>/610nm``.

    Args:
        args: configuration object; ``args.image_size`` is stored.
        data_path: directory of RGB test images.
        transform: callable applied to every modality; expected to return tensors.
        transform_msk: unused; kept for interface compatibility.
        modal_root: directory holding the ``570nm`` and ``610nm`` folders
            (defaults to the previously hard-coded ``TestPic/slice``).
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 modal_root=os.path.join('TestPic', 'slice')):
        self.label_list = self.get_dirs_files(data_path)
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk
        self.modal_root = modal_root

    def __len__(self):
        return len(self.label_list)

    def _apply_transform_synced(self, *images):
        """Apply ``self.transform`` to each image, starting from the same torch RNG state.

        Every modality gets identical random augmentation parameters. The RNG
        state is restored on exit.
        """
        state = torch.get_rng_state()
        outputs = []
        for image in images:
            torch.set_rng_state(state)
            outputs.append(self.transform(image))
        torch.set_rng_state(state)
        return outputs

    def __getitem__(self, index):
        """Return a dict with keys image and image_meta_dict."""
        img_name = self.label_list[index]
        name = os.path.basename(img_name)
        img = Image.open(img_name).convert('RGB')
        img_570 = Image.open(os.path.join(self.modal_root, '570nm', name)).convert('L')
        img_610 = Image.open(os.path.join(self.modal_root, '610nm', name)).convert('L')
        if self.transform:
            img, img_570, img_610 = self._apply_transform_synced(img, img_570, img_610)
        # Keep only channel 1 of the RGB tensor, then stack the two bands after it.
        img = img[1, :, :].unsqueeze(0)
        img = torch.cat([img, img_570, img_610], dim=0)
        image_meta_dict = {'filename_or_obj': name.split(".jpg")[0]}
        return {
            'image': img,
            'image_meta_dict': image_meta_dict,
        }

    def get_dirs_files(self, path):
        """Return every file path below ``path``, descending into sub-directories."""
        files = []
        pending = [path]
        while pending:
            current = pending.pop()
            for entry in os.listdir(current):
                full_path = os.path.join(current, entry)
                if os.path.isdir(full_path):
                    pending.append(full_path)
                else:
                    files.append(full_path)
        return files
class RandomDataset(Dataset):
    """Dataset whose items are the bare file names found recursively under ``path``."""

    def __init__(self, path):
        self.data_path = self.get_dirs_files(path)

    def __len__(self):
        return len(self.data_path)

    def __getitem__(self, index):
        return self.data_path[index]

    def get_dirs_files(self, path):
        """Return the base names of all files below ``path``; directories are descended, not listed."""
        names = []
        stack = [path]
        while stack:
            folder = stack.pop()
            for entry in os.listdir(folder):
                if os.path.isdir(os.path.join(folder, entry)):
                    stack.append(os.path.join(folder, entry))
                else:
                    names.append(entry)
        return names
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。