代码拉取完成,页面将自动刷新
同步操作将从 夏志新/My-Fatigue-Driven-Detection-Based-on-SSD 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import torch
from torch.autograd import Variable
from detection import *
from ssd_net_vgg import *
from voc0712 import *
import torch.nn as nn
import numpy as np
import cv2
import utils
if torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
colors_tableau = [(255, 255, 255), (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
(44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
(148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
(227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
(188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),(158, 218, 229),(158, 218, 229)]
net = SSD() # initialize SSD
net = torch.nn.DataParallel(net)
net.train(mode=False)
net.load_state_dict(torch.load('./weights/ssd_voc_5000_plus.pth',map_location=lambda storage, loc: storage))
img_id = 60
name='dnf_test'
image = cv2.imread('./'+name+'.jpg', cv2.IMREAD_COLOR)
x = cv2.resize(image, (300, 300)).astype(np.float32)
x -= (104.0, 117.0, 123.0)
x = x.astype(np.float32)
x = x[:, :, ::-1].copy()
# plt.imshow(x)
x = torch.from_numpy(x).permute(2, 0, 1)
xx = Variable(x.unsqueeze(0)) # wrap tensor in Variable
if torch.cuda.is_available():
xx = xx.cuda()
y = net(xx)
softmax = nn.Softmax(dim=-1)
detect = Detect(config.class_num, 0, 200, 0.01, 0.45)
priors = utils.default_prior_box()
loc,conf = y
loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
detections = detect(
loc.view(loc.size(0), -1, 4),
softmax(conf.view(conf.size(0), -1,config.class_num)),
torch.cat([o.view(-1, 4) for o in priors], 0)
).data
labels = VOC_CLASSES
top_k=10
# plt.imshow(rgb_image) # plot the image for matplotlib
# scale each detection back up to the image
scale = torch.Tensor(image.shape[1::-1]).repeat(2)
for i in range(detections.size(1)):
j = 0
while detections[0,i,j,0] >= 0.4:
score = detections[0,i,j,0]
label_name = labels[i-1]
display_txt = '%s: %.2f'%(label_name, score)
pt = (detections[0,i,j,1:]*scale).cpu().numpy()
coords = (pt[0], pt[1]), pt[2]-pt[0]+1, pt[3]-pt[1]+1
color = colors_tableau[i]
cv2.rectangle(image,(pt[0],pt[1]), (pt[2],pt[3]), color, 2)
cv2.putText(image, display_txt, (int(pt[0]), int(pt[1]) + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255,255,255), 1, 8)
j+=1
#cv2.imshow('test',image)
#cv2.waitKey(100000)
print("------end-------")
cv2.imwrite(name+'_done.jpg',image)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。