# NOTE: web-page residue, not part of the module: "Code pull complete; the page will refresh automatically."
"""
Interaction head and its submodules
Fred Zhang <[email protected]>
The Australian National University
Australian Centre for Robotic Vision
"""
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn, Tensor
from typing import List
from collections import OrderedDict
from add_on.net import *
class InteractionHead(nn.Module):
    """
    Interaction head that constructs and classifies box pairs

    For every image, detections are reordered so that humans come first,
    human-object index pairs are enumerated, and each pair is scored by
    multiplying the outputs of a human branch and an object branch.

    Parameters:
    -----------
    box_pair_predictor: nn.Module
        Module that classifies box pairs (stored, not called in this class)
    num_channels: int
        Accepted for interface compatibility; not used in this class
    object_class_to_target_class: List[list]
        The set of valid action classes for each object type
    args: namespace-like object
        Must provide `device`, `hidden_dim`, `repr_dim`, `num_classes`,
        `human_idx`, `nheads` and `dataset` ('hicodet' or 'vcoco')
    num_query: int (500)
        Stored on the module; not used in this class
    """

    def __init__(self,
                 box_pair_predictor: nn.Module,
                 num_channels,
                 object_class_to_target_class: List[list],
                 args, num_query=500,
                 ) -> None:
        super().__init__()
        self.device = args.device
        self.args = args
        self.whole_dec = False
        self.num_query = num_query
        self.box_pair_predictor = box_pair_predictor

        hidden_state_size = args.hidden_dim
        self.hidden_state_size = hidden_state_size
        representation_size = args.repr_dim
        self.representation_size = representation_size

        self.num_classes = args.num_classes
        self.human_idx = args.human_idx
        self.object_class_to_target_class = object_class_to_target_class
        self.nheads = args.nheads

        # NOTE(review): the branch widths are hard-coded (256 in, 512 hidden)
        # rather than derived from args.hidden_dim -- confirm this is intended.
        self.Human_branch = Instance_Centric_Attention(
            input_num=256, hidden_num=512, num_classes=args.num_classes)
        self.Object_branch = Instance_Centric_Attention(
            input_num=256, hidden_num=512, num_classes=args.num_classes)

    def forward(self, images, resnet_features: OrderedDict, srcs,
                detr_memory, image_shapes: Tensor, sample_wh: Tensor,
                region_props: List[dict], masks=None, poses=None):
        """
        Build human-object pairs for each image and score them.

        Parameters:
        -----------
        detr_memory: Tensor
            Feature map indexed as `detr_memory[b]` per image; its last two
            dimensions are used as the spatial size of the attention maps
        region_props: List[dict]
            One dict per image with keys 'boxes', 'scores', 'labels' and
            'hidden_states'
        images, resnet_features, srcs, image_shapes, sample_wh, masks, poses:
            Accepted for interface compatibility; not used here

        Returns:
        --------
        logits: Tensor
            Pair scores of all images concatenated, with a leading axis of 1
        prior_collated: List[Tensor]
            Per image, the (2, num_pairs, num_classes) prior scores
        prior_score: List[Tensor]
            Per image, the detection scores in human-first order
            (empty when the image yields no pair)
        boxes_h_collated, boxes_o_collated: List[Tensor]
            Per image, human / object indices of each pair. The indices refer
            to the human-first ordering built inside this method.
        object_class_collated: List[Tensor]
            Per image, the object label of each pair
        attn_maps_collated: List[Tensor]
            Per image, the human-branch attention reshaped to the spatial
            size of `detr_memory`

        Raises:
        -------
        ValueError
            If `args.dataset` is neither 'hicodet' nor 'vcoco' when pairs
            have to be enumerated
        """
        device = self.device
        map_h, map_w = detr_memory.shape[-2], detr_memory.shape[-1]

        boxes_h_collated = []
        boxes_o_collated = []
        prior_collated = []
        object_class_collated = []
        attn_maps_collated = []
        HOI_tokens_collated = []
        prior_score = []

        # Process one sample (image) at a time
        for b_idx, props in enumerate(region_props):
            n = len(props['boxes'])
            score = props['scores']
            label = props['labels']
            unary_token = props['hidden_states']

            is_human = (label == self.human_idx)
            n_h = int(is_human.sum())

            # Permute human instances to the top
            if not bool(is_human.all()):
                h_idx = torch.nonzero(is_human).squeeze(1)
                o_idx = torch.nonzero(~is_human).squeeze(1)
                perm = torch.cat([h_idx, o_idx])
                score = score[perm]
                label = label[perm]
                unary_token = unary_token[perm]

            # No human or fewer than two detections: emit empty placeholders
            if n_h == 0 or n <= 1:
                boxes_h_collated.append(torch.zeros(0, device=device, dtype=torch.int64))
                boxes_o_collated.append(torch.zeros(0, device=device, dtype=torch.int64))
                object_class_collated.append(torch.zeros(0, device=device, dtype=torch.int64))
                prior_collated.append(torch.zeros(2, 0, self.num_classes, device=device))
                prior_score.append(torch.zeros(0, device=device))
                HOI_tokens_collated.append(
                    torch.zeros((0, self.num_classes), dtype=torch.float, device=device))
                attn_maps_collated.append(
                    torch.zeros(0, 512, map_h, map_w, device=device))
                continue

            # Pairwise index grids (N, N): x[i, j] = i and y[i, j] = j
            idx = torch.arange(n, device=device)
            x = idx.unsqueeze(1).expand(n, n)
            y = idx.unsqueeze(0).expand(n, n)

            if self.args.dataset == 'hicodet':
                # The first element must be a human and differ from the second
                x_keep, y_keep = torch.nonzero((x != y) & (x < n_h)).unbind(1)
            elif self.args.dataset == 'vcoco':
                # Self-pairs are kept here
                x_keep, y_keep = torch.nonzero(x < n_h).unbind(1)
            else:
                raise ValueError(f"Unsupported dataset: {self.args.dataset!r}")

            prior = self.compute_prior_scores(x_keep, y_keep, score, label)
            prior_score.append(score)
            boxes_h_collated.append(x_keep)
            boxes_o_collated.append(y_keep)
            object_class_collated.append(label[y_keep])
            prior_collated.append(prior)

            # Everything below is in per-pair form
            memory = detr_memory[b_idx:b_idx + 1]
            Human_feat, Human_attn = self.Human_branch(unary_token[x_keep], memory)
            Object_feat, Object_attn = self.Object_branch(unary_token[y_keep], memory)

            HOI_tokens_collated.append(Human_feat * Object_feat)
            attn_maps_collated.append(
                Human_attn.reshape(Human_attn.shape[0], Human_attn.shape[1], map_h, map_w))

        logits = torch.cat(HOI_tokens_collated, dim=0).unsqueeze(0)

        return logits, prior_collated, prior_score, \
            boxes_h_collated, boxes_o_collated, \
            object_class_collated, attn_maps_collated

    def compute_prior_scores(self, x_keep: Tensor, y_keep: Tensor, scores: Tensor, object_class: Tensor) -> Tensor:
        """
        Compute detection-score priors for each pair and action class.

        Parameters:
        -----------
        x_keep: Tensor
            Index of the human in each pair
        y_keep: Tensor
            Index of the object in each pair
        scores: Tensor
            Detection score of each instance
        object_class: Tensor
            Class label of each instance

        Returns:
        --------
        Tensor
            (2, num_pairs, num_classes): human scores stacked on object
            scores, zero for action classes that are invalid for the pair
        """
        prior_h = torch.zeros(len(x_keep), self.num_classes, device=self.device)
        prior_o = torch.zeros_like(prior_h)

        # Raise the power of object detection scores during inference
        p = 1.0 if self.training else 2.8
        s_h = scores[x_keep].pow(p)
        s_o = scores[y_keep].pow(p)

        # Map object class index to target class index (this filters out
        # actions that are invalid for the object).
        # Object class index to target class index is a one-to-many mapping.
        # A human paired with itself (only produced for V-COCO) keeps every
        # action class, so actions without an object are not filtered out.
        # Indices are converted to Python lists once instead of per element.
        all_classes = range(self.num_classes)
        target_cls_idx = [
            self.object_class_to_target_class[obj] if h != o else all_classes
            for h, o, obj in zip(x_keep.tolist(), y_keep.tolist(),
                                 object_class[y_keep].tolist())
        ]
        # Duplicate box pair indices for each target class
        pair_idx = [i for i, tar in enumerate(target_cls_idx) for _ in tar]
        # Flatten mapped target indices
        flat_target_idx = [t for tar in target_cls_idx for t in tar]

        prior_h[pair_idx, flat_target_idx] = s_h[pair_idx]  # (num_pairs, num_classes)
        prior_o[pair_idx, flat_target_idx] = s_o[pair_idx]  # (num_pairs, num_classes)

        return torch.stack([prior_h, prior_o])  # (2, num_pairs, num_classes)
class Instance_Centric_Attention(nn.Module):
    """
    Attend from instance tokens to an image feature map and classify.

    Each instance token is projected to a query and combined channel-wise
    with keys computed from the feature map. The attention is normalised
    over spatial positions, used to weight the values, and averaged over
    space. The pooled context is concatenated with a second projection of
    the token and passed to an MLP classifier.

    Parameters:
    -----------
    input_num: int
        Size of the instance tokens and channel count of the feature map
    hidden_num: int
        Channel count of the query / key / value projections
    num_classes: int
        Number of output classes
    """

    def __init__(self, input_num, hidden_num, num_classes
                 ) -> None:
        super().__init__()
        self.hidden_num = hidden_num
        # Token projection fed directly to the classifier
        self.Qfc = nn.Sequential(
            nn.Linear(input_num, 2048),
            nn.ReLU(inplace=True),
        )
        # Query / key / value projections
        self.Qconv = nn.Sequential(
            nn.Linear(input_num, hidden_num),
            nn.ReLU(inplace=True),
        )
        self.Kconv = nn.Sequential(
            nn.Conv2d(input_num, hidden_num, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.Vconv = nn.Sequential(
            nn.Conv2d(input_num, hidden_num, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        # Projection of the pooled context
        self.fc = nn.Sequential(
            nn.Linear(hidden_num, 1024),
            nn.ReLU(inplace=True),
        )
        # Classifier over the concatenated [token (2048), context (1024)]
        self.inst_fc = nn.Sequential(
            nn.Linear(3072, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes),
        )

    def forward(self, stream, detr_memory):
        """
        Parameters:
        -----------
        stream: Tensor
            Instance tokens, (num_q, input_num)
        detr_memory: Tensor
            Feature map of a single image, (1, input_num, H, W). The leading
            axis must be of size 1 because it is squeezed away.

        Returns:
        --------
        feat: Tensor
            Class scores, (num_q, num_classes)
        score: Tensor
            Attention normalised over spatial positions,
            (num_q, hidden_num, H * W)

        NOTE(review): the classifier contains BatchNorm1d layers, which
        reject a single-row batch in training mode -- confirm callers never
        pass exactly one token while training.
        """
        token_feat = self.Qfc(stream)
        query = self.Qconv(stream)

        # Only drop the batch axis; a bare squeeze() would also collapse the
        # spatial axis of a 1x1 feature map and break the einsum below.
        context_K = self.Kconv(detr_memory).flatten(-2).squeeze(0)
        context_V = self.Vconv(detr_memory).flatten(-2).squeeze(0)

        # Channel-wise product of every query with every spatial key
        score = torch.einsum('xy, yz -> xyz', query, context_K)
        score = F.softmax(score, -1)

        # Weight the values and average over the spatial positions
        feat = torch.einsum('xyz, yz -> xyz', score, context_V).mean(dim=-1)
        feat = self.fc(feat)

        feat = torch.cat([token_feat, feat], -1)
        feat = self.inst_fc(feat)
        return feat, score
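# NOTE: web-page residue, not part of the module: "This section may contain content unsuitable for display and is not shown on the page. You can review and modify it using the relevant editing functions."
# "If you confirm that the content does not involve inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, or content that violates national laws and regulations, you can click submit to appeal and we will handle it as soon as possible."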