代码拉取完成,页面将自动刷新
import cv2
import numpy as np
import torch
from faster_rcnn import network
from faster_rcnn.faster_rcnn import FasterRCNN
from faster_rcnn.utils.timer import Timer
def test():
import os
im_file = 'demo/004545.jpg'
# im_file = 'data/VOCdevkit2007/VOC2007/JPEGImages/009036.jpg'
# im_file = '/media/longc/Data/data/2DMOT2015/test/ETH-Crossing/img1/000100.jpg'
image = cv2.imread(im_file)
model_file = 'models/saved_model3/faster_rcnn_200000.pth'
# model_file = '/media/longc/Data/models/faster_rcnn_pytorch3/faster_rcnn_100000.h5'
# model_file = '/media/longc/Data/models/faster_rcnn_pytorch2/faster_rcnn_2000.h5'
detector = FasterRCNN()
detector.load_state_dict(torch.load(model_file))
detector.cuda()
detector.eval()
print('load model successfully!')
# network.save_net(r'/media/longc/Data/models/VGGnet_fast_rcnn_iter_70000.h5', detector)
# print('save model succ')
t = Timer()
t.tic()
# image = np.zeros(shape=[600, 800, 3], dtype=np.uint8) + 255
dets, scores, classes = detector.detect(image, 0.7)
runtime = t.toc()
print('total spend: {}s'.format(runtime))
im2show = np.copy(image)
for i, det in enumerate(dets):
det = tuple(int(x) for x in det)
cv2.rectangle(im2show, det[0:2], det[2:4], (255, 205, 51), 2)
cv2.putText(im2show, '%s: %.3f' % (classes[i], scores[i]), (det[0], det[1] + 15), cv2.FONT_HERSHEY_PLAIN,
1.0, (0, 0, 255), thickness=1)
cv2.imwrite(os.path.join('demo', 'out.jpg'), im2show)
cv2.imshow('demo', im2show)
cv2.waitKey(0)
if __name__ == '__main__':
test()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。