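# NOTE: web-page residue from the code host, not part of the module:
# "Code fetch complete; the page will refresh automatically."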
import numpy as np
from PIL import Image
import h5py
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
import torch
class SiameseMNIST(Dataset):
    """Pair dataset built on top of an MNIST-like dataset.

    Train: for each sample, randomly creates a positive or a negative pair
    on every access.
    Test: creates a fixed list of pairs once, at construction time.

    The wrapped dataset must expose ``train``, ``transform`` and either
    ``train_labels``/``train_data`` or ``test_labels``/``test_data``.
    ``.numpy()`` and ``.item()`` are called on the labels, and ``.numpy()``
    on each selected image.

    ``__getitem__`` returns ``((img1, img2), target)`` where ``target`` is 1
    for a same-label pair and 0 for a different-label pair.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            labels_np = self.train_labels.numpy()
            self.labels_set = set(labels_np)
            self.label_to_indices = {label: np.where(labels_np == label)[0]
                                     for label in self.labels_set}
        else:
            # generate fixed pairs for testing
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            labels_np = self.test_labels.numpy()
            self.labels_set = set(labels_np)
            self.label_to_indices = {label: np.where(labels_np == label)[0]
                                     for label in self.labels_set}

            # Every random draw below goes through this seeded generator so
            # the pairs are identical from run to run. Drawing the negative
            # label from the global ``np.random`` would make them vary.
            random_state = np.random.RandomState(29)

            # Even positions become positive pairs (same label).
            positive_pairs = [
                [i,
                 random_state.choice(self.label_to_indices[self.test_labels[i].item()]),
                 1]
                for i in range(0, len(self.test_data), 2)
            ]

            # Odd positions become negative pairs (different label).
            negative_pairs = []
            for i in range(1, len(self.test_data), 2):
                own_label = self.test_labels[i].item()
                # Sorted so the candidate order does not depend on set
                # iteration order.
                other_label = random_state.choice(sorted(self.labels_set - {own_label}))
                negative_pairs.append(
                    [i, random_state.choice(self.label_to_indices[other_label]), 0]
                )

            self.test_pairs = positive_pairs + negative_pairs

    def __getitem__(self, index):
        """Return ``((img1, img2), target)`` for position ``index``."""
        if self.train:
            target = np.random.randint(0, 2)
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            if target == 1:
                candidates = self.label_to_indices[label1]
                siamese_index = index
                # Re-draw until the partner differs from the anchor. When the
                # class holds a single sample there is no other partner, so
                # fall back to the anchor itself instead of looping forever.
                while siamese_index == index and len(candidates) > 1:
                    siamese_index = np.random.choice(candidates)
            else:
                siamese_label = np.random.choice(list(self.labels_set - {label1}))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            first_index, second_index, target = self.test_pairs[index]
            img1 = self.test_data[first_index]
            img2 = self.test_data[second_index]

        # Convert to single-channel ('L' mode) PIL images before applying
        # the wrapped dataset's transform.
        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1, img2), target

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.mnist_dataset)
class TripletMNIST(Dataset):
    """Triplet dataset built on top of an MNIST-like dataset.

    Train: for each sample (anchor), randomly chooses a positive and a
    negative sample on every access.
    Test: creates a fixed list of triplets once, at construction time.

    The wrapped dataset must expose ``train``, ``transform`` and either
    ``train_labels``/``train_data`` or ``test_labels``/``test_data``.
    ``.numpy()`` and ``.item()`` are called on the labels, and ``.numpy()``
    on each selected image.

    ``__getitem__`` returns ``((anchor, positive, negative), [])``; the
    second element is always an empty list.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            labels_np = self.train_labels.numpy()
            self.labels_set = set(labels_np)
            self.label_to_indices = {label: np.where(labels_np == label)[0]
                                     for label in self.labels_set}
        else:
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            # generate fixed triplets for testing
            labels_np = self.test_labels.numpy()
            self.labels_set = set(labels_np)
            self.label_to_indices = {label: np.where(labels_np == label)[0]
                                     for label in self.labels_set}

            # Every random draw below goes through this seeded generator so
            # the triplets are identical from run to run. Drawing the
            # negative label from the global ``np.random`` would make them
            # vary.
            random_state = np.random.RandomState(29)

            triplets = []
            for i in range(len(self.test_data)):
                anchor_label = self.test_labels[i].item()
                positive_index = random_state.choice(self.label_to_indices[anchor_label])
                # Sorted so the candidate order does not depend on set
                # iteration order.
                negative_label = random_state.choice(sorted(self.labels_set - {anchor_label}))
                negative_index = random_state.choice(self.label_to_indices[negative_label])
                triplets.append([i, positive_index, negative_index])
            self.test_triplets = triplets

    def __getitem__(self, index):
        """Return ``((anchor, positive, negative), [])`` for ``index``."""
        if self.train:
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            candidates = self.label_to_indices[label1]
            positive_index = index
            # Re-draw until the positive differs from the anchor. When the
            # class holds a single sample there is no other choice, so fall
            # back to the anchor itself instead of looping forever.
            while positive_index == index and len(candidates) > 1:
                positive_index = np.random.choice(candidates)
            negative_label = np.random.choice(list(self.labels_set - {label1}))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.train_data[positive_index]
            img3 = self.train_data[negative_index]
        else:
            anchor_index, positive_index, negative_index = self.test_triplets[index]
            img1 = self.test_data[anchor_index]
            img2 = self.test_data[positive_index]
            img3 = self.test_data[negative_index]

        # Convert to single-channel ('L' mode) PIL images before applying
        # the wrapped dataset's transform.
        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        img3 = Image.fromarray(img3.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        return (img1, img2, img3), []

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.mnist_dataset)
class BalancedBatchSampler(BatchSampler):
    """Batch sampler that draws class-balanced batches.

    From a MNIST-like label vector, samples ``n_classes`` distinct classes
    and, within each of them, takes ``n_samples`` indices. Yields batches
    (lists of indices) of size ``n_classes * n_samples``.

    Args:
        labels: per-sample labels. A tensor (``.numpy()`` is used) or
            anything ``np.asarray`` can convert.
        n_classes: number of distinct classes per batch.
        n_samples: number of indices taken per chosen class.
    """

    def __init__(self, labels, n_classes, n_samples):
        # NOTE: BatchSampler.__init__ is not called here; this class only
        # provides its own __iter__/__len__.
        self.labels = labels
        labels_np = labels.numpy() if hasattr(labels, "numpy") else np.asarray(labels)
        self.labels_set = list(set(labels_np))
        self.label_to_indices = {label: np.where(labels_np == label)[0]
                                 for label in self.labels_set}
        for label in self.labels_set:
            np.random.shuffle(self.label_to_indices[label])
        # Per-class cursor into the shuffled index array.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        """Yield lists of indices, ``len(self)`` batches per pass."""
        self.count = 0
        # ``<=`` keeps the number of yielded batches equal to ``__len__``
        # (n_dataset // batch_size). A strict ``<`` drops the last batch
        # when the dataset size is an exact multiple of the batch size.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # Not enough unused indices left for another draw from this
                # class: reshuffle it and restart from the beginning.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        """Return the number of batches produced per pass."""
        return self.n_dataset // self.batch_size
class stateMNIST(Dataset):
    """Dataset backed by an HDF5 file, loaded fully into memory.

    The file is read once in ``__init__``. The ``train_data`` and
    ``train_labels`` datasets are used when ``dataset_name == "train"``;
    any other value selects ``test_data`` and ``test_labels``.

    Args:
        filename: path of the HDF5 file.
        dataset_name: ``"train"`` for the training split; anything else
            selects the test split.
        transform: optional callable applied to the sample tensor returned
            by ``__getitem__``.
    """

    def __init__(self, filename, dataset_name, transform = None):
        self.filename = filename
        self.transform = transform
        is_train = dataset_name == "train"
        data_name = 'train_data' if is_train else 'test_data'
        label_name = 'train_labels' if is_train else 'test_labels'
        with h5py.File(filename, 'r') as f:
            # ``[:]`` copies the datasets into memory, so the file can be
            # closed when the ``with`` block exits.
            self.data = f[data_name][:]
            self.labels = f[label_name][:]

    def __len__(self):
        """Return the number of samples (first dimension of the data)."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(sample, label)``; the label is a ``torch.long`` tensor."""
        sample = torch.from_numpy(self.data[index])
        # The transform is applied here; storing it without using it would
        # silently ignore the caller's argument.
        if self.transform is not None:
            sample = self.transform(sample)
        label = torch.tensor(self.labels[index], dtype=torch.long)
        return sample, label
if __name__ == "__main__":
    # Smoke test: load each split and print its size, the shape of the first
    # sample and its label.
    import os

    # Built with os.path.join instead of a literal containing backslashes:
    # '\s' and '\d' are invalid escape sequences, and a backslash-separated
    # path only resolves on Windows.
    dataset_path = os.path.join('data', 'strongswan_v1', 'data.h5')
    for split in ('train', 'test'):
        state_dataset = stateMNIST(dataset_path, split)
        print(len(state_dataset))
        x, y = state_dataset[0]
        print(x.shape)
        print(y)
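# NOTE: web-page residue from the code host, not part of the module:
# "This section may contain content unsuitable for display, so the page does not show it. You can review and modify it yourself through the relevant editing functions."
# "If you confirm the content does not involve inappropriate language, pure advertising, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or content that violates national laws and regulations, you can click submit to appeal and we will handle it as soon as possible."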