1 Star 0 Fork 0

jmc12138/state_siamese

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
utils.py 8.74 KB
一键复制 编辑 原始数据 按行查看 历史
jmc12138 提交于 2024-12-26 10:26 +08:00 . i forget too many change
from itertools import combinations
import struct
import numpy as np
import torch
from PIL import Image
def pdist(vectors):
    """
    Return the matrix of pairwise squared Euclidean distances between rows.

    Uses the expansion ||x_i - x_j||^2 = ||x_i||^2 - 2 x_i.x_j + ||x_j||^2, which
    needs a single matrix product instead of an explicit (n, n, d) difference.

    Args:
        vectors: 2-D tensor of shape (n, d) (``Tensor.mm`` requires 2-D input).

    Returns:
        Tensor of shape (n, n) whose entry [i, j] is the squared distance
        between ``vectors[i]`` and ``vectors[j]``.
    """
    squared_norms = vectors.pow(2).sum(dim=1)
    distance_matrix = (
        -2 * vectors.mm(torch.t(vectors))
        + squared_norms.view(1, -1)
        + squared_norms.view(-1, 1)
    )
    # The expansion suffers from floating-point cancellation: entries that should
    # be exactly 0 (the diagonal, duplicate rows) can come out slightly negative.
    # A squared distance is never negative, so clamp.
    return distance_matrix.clamp(min=0)
class PairSelector:
    """
    Interface for strategies that pick index pairs for a contrastive loss.

    A concrete selector overrides ``get_pairs`` and returns the tuple
    ``(positive_pairs, negative_pairs)``: the indices of positive pairs and of
    negative pairs that will be fed to the contrastive loss.
    """

    def __init__(self):
        """No state to set up; kept so subclasses can call the base initialiser."""

    def get_pairs(self, embeddings, labels):
        """Return ``(positive_pairs, negative_pairs)``; subclasses must override this."""
        raise NotImplementedError
class AllPositivePairSelector(PairSelector):
    """
    Generate pairs from labels alone; the embeddings are ignored.

    Every index pair (i, j) with i < j whose labels match is a positive pair, and
    every pair whose labels differ is a negative pair. If ``balance`` is True,
    the negative pairs are a random subset (via ``torch.randperm``) no larger
    than the number of positive pairs.

    Args:
        balance: Whether to subsample negative pairs to match the positives.
    """

    def __init__(self, balance=True):
        super().__init__()
        self.balance = balance

    def get_pairs(self, embeddings, labels):
        """
        Return ``(positive_pairs, negative_pairs)`` as LongTensors of shape (k, 2).

        Args:
            embeddings: Unused; accepted for interface compatibility.
            labels: 1-D tensor of class labels, one per sample.
        """
        labels = labels.detach().cpu().numpy()
        # All index pairs (i, j) with i < j. The reshape keeps a (0, 2) shape for
        # batches of fewer than two samples, where np.array([]) would be 1-D and
        # the column indexing below would fail.
        pairs = np.array(
            list(combinations(range(len(labels)), 2)), dtype=np.int64
        ).reshape(-1, 2)
        same_label = labels[pairs[:, 0]] == labels[pairs[:, 1]]

        all_pairs = torch.as_tensor(pairs, dtype=torch.long)
        positive_pairs = all_pairs[torch.as_tensor(same_label)]
        negative_pairs = all_pairs[torch.as_tensor(~same_label)]

        if self.balance:
            negative_pairs = negative_pairs[torch.randperm(len(negative_pairs))[:len(positive_pairs)]]
        return positive_pairs, negative_pairs
class HardNegativePairSelector(PairSelector):
    """
    Build all positive pairs plus an equal number of the hardest negative pairs.

    Every same-label index pair (i, j), i < j, is returned as a positive pair.
    Among the different-label pairs, those with the smallest squared distance
    (as computed by ``pdist``) are kept, as many as there are positive pairs.
    If there are no more negative pairs than positive pairs, all negative pairs
    are returned.

    Args:
        cpu: If True, embeddings are moved to the CPU before distances are computed.
    """

    def __init__(self, cpu=True):
        super().__init__()
        self.cpu = cpu

    def get_pairs(self, embeddings, labels):
        """
        Return ``(positive_pairs, negative_pairs)`` as LongTensors of shape (k, 2).

        Args:
            embeddings: Embedding tensor; row ``i`` belongs to ``labels[i]``
                (assumed 2-D (batch, dim), as ``pdist`` requires).
            labels: 1-D tensor of class labels.
        """
        labels = labels.detach().cpu().numpy()
        # All index pairs (i, j) with i < j; the reshape keeps a (0, 2) shape for
        # batches of fewer than two samples.
        pairs = np.array(
            list(combinations(range(len(labels)), 2)), dtype=np.int64
        ).reshape(-1, 2)
        same_label = labels[pairs[:, 0]] == labels[pairs[:, 1]]

        all_pairs = torch.as_tensor(pairs, dtype=torch.long)
        positive_pairs = all_pairs[torch.as_tensor(same_label)]
        negative_pairs = all_pairs[torch.as_tensor(~same_label)]

        n_keep = len(positive_pairs)
        if n_keep >= len(negative_pairs):
            # Nothing to discard, and np.argpartition would reject kth >= len(array).
            return positive_pairs, negative_pairs

        if self.cpu:
            embeddings = embeddings.cpu()
        distance_matrix = pdist(embeddings)
        negative_distances = distance_matrix[negative_pairs[:, 0], negative_pairs[:, 1]]
        negative_distances = negative_distances.detach().cpu().numpy()
        # Positions of the n_keep smallest distances (returned in no particular order).
        top_negatives = np.argpartition(negative_distances, n_keep)[:n_keep]
        top_negative_pairs = negative_pairs[torch.as_tensor(top_negatives, dtype=torch.long)]
        return positive_pairs, top_negative_pairs
class TripletSelector:
    """
    Interface for strategies that pick (anchor, positive, negative) triplets.

    A concrete selector overrides ``get_triplets`` and returns an integer array
    of shape [N_triplets x 3] holding anchor, positive and negative sample
    indices (the selectors in this module return a ``torch.LongTensor``).
    """

    def __init__(self):
        """No state to set up; kept so subclasses can call the base initialiser."""

    def get_triplets(self, embeddings, labels):
        """Return the selected triplets; subclasses must override this."""
        raise NotImplementedError
class AllTripletSelector(TripletSelector):
    """
    Return every possible (anchor, positive, negative) triplet.

    For each label with at least two samples, every pair (anchor, positive) of
    those samples, taken with anchor index < positive index, is combined with
    every sample of a different label. The count grows cubically with the batch
    size, so this may be impractical in most cases.
    """

    def __init__(self):
        super().__init__()

    def get_triplets(self, embeddings, labels):
        """
        Return a LongTensor of shape (N_triplets, 3); (0, 3) if no triplet exists.

        Args:
            embeddings: Unused; accepted for interface compatibility.
            labels: 1-D tensor of class labels.
        """
        labels = labels.detach().cpu().numpy()
        triplets = []
        for label in set(labels):
            label_mask = labels == label
            label_indices = np.where(label_mask)[0]
            if len(label_indices) < 2:
                continue  # an anchor needs a positive with the same label
            negative_indices = np.where(np.logical_not(label_mask))[0]
            # Every anchor-positive pair combined with every negative.
            triplets.extend(
                [anchor, positive, negative]
                for anchor, positive in combinations(label_indices, 2)
                for negative in negative_indices
            )
        # dtype + reshape keep an integer (0, 3) result when no triplet exists
        # (np.array([]) alone would be a 1-D float array).
        return torch.as_tensor(
            np.array(triplets, dtype=np.int64).reshape(-1, 3), dtype=torch.long
        )
def hardest_negative(loss_values):
    """
    Pick the negative with the largest triplet-loss value.

    Args:
        loss_values: 1-D numpy array of loss values, one per candidate negative.

    Returns:
        The position of the maximum (first one on ties, as ``np.argmax``) if that
        loss is strictly positive, otherwise ``None``. An empty array (no
        candidate negatives) also yields ``None`` instead of raising.
    """
    if np.size(loss_values) == 0:
        return None
    hard_negative = np.argmax(loss_values)
    return hard_negative if loss_values[hard_negative] > 0 else None
def random_hard_negative(loss_values):
    """
    Pick a random negative among those with a strictly positive loss value.

    Args:
        loss_values: Numpy array of loss values, one per candidate negative.

    Returns:
        A position drawn with ``np.random.choice`` from the entries whose loss is
        greater than 0, or ``None`` when no entry qualifies.
    """
    positive_loss_positions = np.where(loss_values > 0)[0]
    if len(positive_loss_positions) == 0:
        return None
    return np.random.choice(positive_loss_positions)
def semihard_negative(loss_values, margin):
    """
    Pick a random semi-hard negative: one whose loss lies strictly in (0, margin).

    With loss = d(a, p) - d(a, n) + margin (as built by the triplet selector),
    this band means the negative is farther from the anchor than the positive,
    but by less than ``margin``.

    Args:
        loss_values: Numpy array of loss values, one per candidate negative.
        margin: Upper (exclusive) bound of the accepted loss band.

    Returns:
        A position drawn with ``np.random.choice`` from the qualifying entries,
        or ``None`` when there are none.
    """
    in_band = np.logical_and(loss_values > 0, loss_values < margin)
    candidates = np.where(in_band)[0]
    if len(candidates) == 0:
        return None
    return np.random.choice(candidates)
class FunctionNegativeTripletSelector(TripletSelector):
    """
    Mine one negative per anchor-positive pair using a pluggable selection rule.

    For every same-label pair (anchor, positive), the triplet loss
    ``d(a, p) - d(a, n) + margin`` is evaluated against every sample ``n`` of a
    different label, where ``d`` is the squared distance returned by ``pdist``.
    ``negative_selection_fn`` receives that array of loss values (numpy) for one
    pair and returns a position into it, or ``None`` to skip the pair.

    The margin should match the margin used in the triplet loss.

    Args:
        margin: Margin added to every loss value.
        negative_selection_fn: Callable ``loss_values -> index or None``.
        cpu: If True, embeddings are moved to the CPU before distances are computed.
    """

    def __init__(self, margin, negative_selection_fn, cpu=True):
        super().__init__()
        self.cpu = cpu
        self.margin = margin
        self.negative_selection_fn = negative_selection_fn

    def get_triplets(self, embeddings, labels):
        """
        Return a LongTensor of shape (N_triplets, 3) of (anchor, positive, negative) indices.

        If the selection function rejects every candidate, one fallback triplet is
        returned: the last anchor-positive pair visited together with the first
        negative of its label. If no triplet can be formed at all (no label has
        two samples, or the batch holds a single label), an empty (0, 3) tensor
        is returned; NOTE(review): callers computing a mean loss over it should
        handle that case.

        Args:
            embeddings: Embedding tensor; row ``i`` belongs to ``labels[i]``
                (assumed 2-D (batch, dim), as ``pdist`` requires).
            labels: 1-D tensor of class labels.
        """
        if self.cpu:
            embeddings = embeddings.cpu()
        # Negative selection runs in numpy, so distances always end up on the CPU.
        distance_matrix = pdist(embeddings).cpu()
        labels = labels.detach().cpu().numpy()

        triplets = []
        fallback = None
        for label in set(labels):
            label_mask = labels == label
            label_indices = np.where(label_mask)[0]
            if len(label_indices) < 2:
                continue  # an anchor needs a positive with the same label
            negative_indices = np.where(np.logical_not(label_mask))[0]
            if len(negative_indices) == 0:
                continue  # single-label batch: nothing to pick a negative from

            anchor_positives = np.array(list(combinations(label_indices, 2)))  # all anchor-positive pairs
            negatives = torch.as_tensor(negative_indices, dtype=torch.long)
            ap_distances = distance_matrix[
                torch.as_tensor(anchor_positives[:, 0], dtype=torch.long),
                torch.as_tensor(anchor_positives[:, 1], dtype=torch.long),
            ]
            for (anchor, positive), ap_distance in zip(anchor_positives, ap_distances):
                # Triplet loss of (anchor, positive, n) for every candidate negative n.
                loss_values = ap_distance - distance_matrix[int(anchor), negatives] + self.margin
                hard_negative = self.negative_selection_fn(loss_values.detach().cpu().numpy())
                if hard_negative is not None:
                    triplets.append([anchor, positive, negative_indices[hard_negative]])
            fallback = [anchor_positives[-1][0], anchor_positives[-1][1], negative_indices[0]]

        if not triplets and fallback is not None:
            triplets.append(fallback)

        # dtype + reshape keep an integer (0, 3) result when the list is empty.
        return torch.as_tensor(
            np.array(triplets, dtype=np.int64).reshape(-1, 3), dtype=torch.long
        )
def HardestNegativeTripletSelector(margin, cpu=False):
    """
    Build a FunctionNegativeTripletSelector that uses ``hardest_negative``.

    Args:
        margin: Triplet-loss margin.
        cpu: Forwarded to the selector; note this factory defaults to False,
            whereas FunctionNegativeTripletSelector itself defaults to True.
    """
    selector = FunctionNegativeTripletSelector(
        margin=margin,
        negative_selection_fn=hardest_negative,
        cpu=cpu,
    )
    return selector
def RandomNegativeTripletSelector(margin, cpu=False):
    """
    Build a FunctionNegativeTripletSelector that uses ``random_hard_negative``.

    Args:
        margin: Triplet-loss margin.
        cpu: Forwarded to the selector; note this factory defaults to False,
            whereas FunctionNegativeTripletSelector itself defaults to True.
    """
    selector = FunctionNegativeTripletSelector(
        margin=margin,
        negative_selection_fn=random_hard_negative,
        cpu=cpu,
    )
    return selector
def SemihardNegativeTripletSelector(margin, cpu=False):
    """
    Build a FunctionNegativeTripletSelector that uses ``semihard_negative``.

    The same ``margin`` is used both for the loss values and as the upper bound
    of the semi-hard band.

    Args:
        margin: Triplet-loss margin.
        cpu: Forwarded to the selector; note this factory defaults to False,
            whereas FunctionNegativeTripletSelector itself defaults to True.
    """
    def select_semihard(loss_values):
        return semihard_negative(loss_values, margin)

    selector = FunctionNegativeTripletSelector(
        margin=margin,
        negative_selection_fn=select_semihard,
        cpu=cpu,
    )
    return selector
def data2fig(data, width=300, length=300):
    """
    Render a sequence of binary chunks as a grayscale image array.

    Chunks are concatenated in order and each byte becomes one pixel, filled row
    by row from the top-left corner. Bytes beyond ``width * length`` are dropped.
    A shortfall is padded with 0 (black).

    Args:
        data: Iterable of chunks, each passed to ``bytearray()`` (so each must be
            bytes-like or an iterable of ints in 0..255). A single bytes-like
            object is treated as one chunk.
        width: Image width in pixels (default 300).
        length: Image height in pixels (default 300).

    Returns:
        A writable ``numpy.uint8`` array of shape ``(length, width)``.
    """
    total_pixels = width * length
    # Iterating a bare bytes object yields ints, and bytearray(int) would create
    # that many zero bytes instead of one pixel, so wrap it as a single chunk.
    if isinstance(data, (bytes, bytearray, memoryview)):
        data = [data]

    pixel_data = bytearray()
    for item in data:
        pixel_data.extend(bytearray(item))
        # Stop reading as soon as enough pixels have been collected.
        if len(pixel_data) >= total_pixels:
            break

    # Zero-initialised buffer provides the padding when the data is too short.
    pixels = np.zeros(total_pixels, dtype=np.uint8)
    used = min(len(pixel_data), total_pixels)
    if used:
        pixels[:used] = np.frombuffer(pixel_data, dtype=np.uint8, count=used)
    # Row-major reshape matches filling a (width x length) 'L' image row by row.
    return pixels.reshape(length, width)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhangph12138/state_siamese.git
git@gitee.com:zhangph12138/state_siamese.git
zhangph12138
state_siamese
state_siamese
master

搜索帮助