1 Star 0 Fork 6

hxf233333/My-Fatigue-Driven-Detection-Based-on-SSD

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
detection.py 2.65 KB
一键复制 编辑 原始数据 按行查看 历史
Revan-github 提交于 2019-06-07 14:11 +08:00 . Add files via upload
import torch
from torch.autograd import Function
from utils import decode, nms
class Detect(Function):
    """At test time, Detect is the final layer of SSD.

    Decodes the location predictions into boxes, keeps the candidates whose
    confidence is strictly greater than ``conf_thresh`` and applies
    non-maximum suppression per class, keeping at most ``top_k`` boxes for
    each class.

    NOTE(review): this subclasses ``torch.autograd.Function`` with an
    instance-level ``forward``. Recent PyTorch releases appear to reject
    invoking such legacy functions through ``__call__`` -- confirm against
    the PyTorch version in use.
    """

    def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
        """
        Args:
            num_classes: (int) Number of classes. ``forward`` iterates class
                indices 1..num_classes-1, so index 0 is never scored.
            bkg_label: (int) Stored as ``background_label``; not read by
                ``forward``.
            top_k: (int) Passed to nms as the per-class limit; also the size
                of the third dimension of the output tensor.
            conf_thresh: (float) Candidates must score strictly above this.
            nms_thresh: (float) Overlap threshold passed to nms. Must be > 0.
        Raises:
            ValueError: if ``nms_thresh`` is zero or negative.
        """
        # Zero is rejected too, so the requirement is "positive".
        if nms_thresh <= 0:
            raise ValueError('nms_threshold must be positive.')
        self.num_classes = num_classes
        self.background_label = bkg_label
        self.top_k = top_k
        # Parameters used in nms.
        self.nms_thresh = nms_thresh
        self.conf_thresh = conf_thresh
        # Handed to decode() together with the priors.
        self.variance = (0.1, 0.2)

    def forward(self, loc_data, conf_data, prior_data):
        """
        Args:
            loc_data: (tensor) Loc preds from loc layers, indexed along
                dim 0 by batch element.
                Presumably shape [batch,num_priors,4] -- confirm against
                decode().
            conf_data: (tensor) Conf preds from conf layers. Must be
                viewable as [batch,num_priors,num_classes].
            prior_data: (tensor) Prior boxes from priorbox layers. Its first
                dimension is taken as num_priors.
                Presumably shape [num_priors,4] -- confirm against decode().
        Return:
            (tensor) Shape: [batch,num_classes,top_k,5]. Every filled row is
            (score, 4 box values); rows that were not filled stay zero.
        """
        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)
        output = torch.zeros(num, self.num_classes, self.top_k, 5)
        # Reorder to [batch, num_classes, num_priors] so one class can be
        # picked with a single index.
        conf_preds = conf_data.view(num, num_priors,
                                    self.num_classes).transpose(2, 1)

        # Decode predictions into bboxes.
        for i in range(num):
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
            conf_scores = conf_preds[i].clone()

            # For each class, perform nms. Class index 0 is skipped.
            for cl in range(1, self.num_classes):
                c_mask = conf_scores[cl].gt(self.conf_thresh)
                scores = conf_scores[cl][c_mask]
                # Mask indexing yields a 1-D tensor even when nothing is
                # selected, so test the element count instead of dim().
                if scores.numel() == 0:
                    continue
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                boxes = decoded_boxes[l_mask].view(-1, 4)
                # idx of highest scoring and non-overlapping boxes per class
                ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                if count == 0:
                    continue
                kept = ids[:count]
                output[i, cl, :count] = torch.cat(
                    (scores[kept].unsqueeze(1), boxes[kept]), 1)

        # NOTE: earlier revisions ended with a sort over all classes followed
        # by ``flt[mask].fill_(0)``. Mask indexing returns a copy, so that
        # statement never modified ``output``; it is omitted here and the
        # returned tensor is unchanged. No cap across classes is applied.
        return output
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/hxf233333/Fatigue-Driven-Detection-Based-on-CNN.git
[email protected]:hxf233333/Fatigue-Driven-Detection-Based-on-CNN.git
hxf233333
Fatigue-Driven-Detection-Based-on-CNN
My-Fatigue-Driven-Detection-Based-on-SSD
master

搜索帮助