1 Star 2 Fork 0

Stefan/wrc

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
ssd_mobilenet_ros.py 11.58 KB
一键复制 编辑 原始数据 按行查看 历史
Stefan 提交于 2021-11-19 19:50 . wrc 2021 demo
#!/usr/bin/python
#coding=utf-8
import rospy
import cv2
from cv_bridge import CvBridge
import colorsys
import os
import time
import numpy as np
from keras import backend as K
from keras.applications.imagenet_utils import preprocess_input
from PIL import Image, ImageDraw, ImageFont
from nets.ssd import SSD300
from utils.utils import BBoxUtility, letterbox_image, ssd_correct_boxes
from sensor_msgs.msg import Image as ros_img
from ros_openvino.msg import Object,Objects
# Module-level CvBridge shared by the subscriber callback and the detection
# loop to convert between ROS Image messages and OpenCV/numpy arrays.
cvb=CvBridge()
#--------------------------------------------#
# To run prediction with your own trained model, two parameters
# must be changed: both model_path and classes_path.
# If a shape mismatch is reported,
# double-check the NUM_CLASSES used during training as well as
# the model_path and classes_path settings.
#--------------------------------------------#
class SSD(object):
    """SSD300 detector that runs as a ROS node.

    Constructing an instance loads the class list and the Keras weights,
    initialises the ``object_detection_wrc2021`` node, subscribes to
    ``camera/rgb/image_raw`` and then enters :meth:`detect_image`, which
    loops until ROS shuts down.  The constructor therefore does not return
    while the node is running.

    Any key of ``_defaults`` can be overridden by passing it as a keyword
    argument to the constructor.
    """

    _defaults = {
        "model_path"        : 'model_data/wrc_model.h5',
        "classes_path"      : 'model_data/wrc_classes.txt',
        "input_shape"       : (300, 300, 3),
        "confidence"        : 0.3,
        "nms_iou"           : 0.45,
        'anchors_size'      : [30,60,111,162,213,264,315],
        #---------------------------------------------------------------------#
        #   Controls whether letterbox_image is used to resize the input
        #   without distortion.  After repeated testing, a plain resize
        #   (letterbox_image disabled) was found to give better results.
        #---------------------------------------------------------------------#
        "letterbox_image"   : False,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value for option ``n``.

        For an unknown name a descriptive string is returned (no exception
        is raised).
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise the detector and the ROS interface
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Load the model, set up ROS I/O and start the detection loop.

        ``kwargs`` override the entries of ``_defaults`` (for example
        ``SSD(confidence=0.5)``).
        """
        self.__dict__.update(self._defaults)
        # Apply caller overrides; previously kwargs were accepted but ignored.
        self.__dict__.update(kwargs)
        self.class_names = self._get_class()
        self.sess = K.get_session()
        self.generate()
        self.bbox_util = BBoxUtility(self.num_classes, nms_thresh=self.nms_iou)
        rospy.init_node('object_detection_wrc2021', anonymous=False)
        #self.object_result_pub=rospy.Publisher("object_detection/result",Object,queue_size=1)
        self.object_img_pub = rospy.Publisher("object_detection/output_image", ros_img, queue_size=10)
        self.object_results_pub = rospy.Publisher("object_detection/results", Objects, queue_size=1)
        self.object_dectction_msg = Object()
        self.object_results_msg = Objects()
        # Black placeholder frame processed until the first camera image arrives.
        self.img_input = np.zeros((640,480,3), np.uint8)
        rospy.Subscriber("camera/rgb/image_raw", ros_img, self.read_img)
        # Blocks here until rospy reports shutdown.
        self.detect_image()

    #---------------------------------------------------#
    #   Read all class names
    #---------------------------------------------------#
    def _get_class(self):
        """Return the class names listed in ``classes_path``, one per line,
        with surrounding whitespace stripped."""
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names

    #---------------------------------------------------#
    #   Load the model
    #---------------------------------------------------#
    def generate(self):
        """Build the SSD300 network, load its weights and prepare box colours.

        Sets ``self.num_classes``, ``self.ssd_model`` and ``self.colors``.
        """
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
        #-------------------------------#
        #   Total number of classes: one more than the names in the file
        #-------------------------------#
        self.num_classes = len(self.class_names) + 1
        #-------------------------------#
        #   Build the model and load the weights
        #-------------------------------#
        self.ssd_model = SSD300(self.input_shape, self.num_classes, anchors_size=self.anchors_size)
        # Use the user-expanded path (the raw attribute may contain '~').
        self.ssd_model.load_weights(model_path, by_name=True)
        print('{} model, anchors, and classes loaded.'.format(model_path))

        # One distinct colour per class for the drawn boxes.  float() keeps
        # the hue fractional under Python 2 integer division as well.
        class_count = len(self.class_names)
        hsv_tuples = [(float(x) / class_count, 1., 1.)
                      for x in range(class_count)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))

    #---------------------------------------------------#
    #   Image detection
    #---------------------------------------------------#
    def read_img(self, image):
        """Subscriber callback: keep the most recent camera frame."""
        frame = cvb.imgmsg_to_cv2(image, desired_encoding="passthrough")
        self.img_input = frame

    def detect_image(self):
        """Run detection on the latest frame until ROS shuts down.

        Each iteration publishes an image on ``object_detection/output_image``.
        When at least one detection is found, the boxes are drawn on that
        image and an ``Objects`` message is published on
        ``object_detection/results``; otherwise the unannotated frame is
        published and no results message is sent.
        """
        # Kept across iterations so the displayed value is a running average.
        fps = 0.0
        while not rospy.is_shutdown():
            t1 = time.time()
            frame = cv2.cvtColor(self.img_input, cv2.COLOR_BGR2RGB)
            frame2 = frame.copy()
            # Convert to a PIL image
            frame = Image.fromarray(np.uint8(frame))
            # Channel conversion: drop any alpha channel (4 -> 3 channels)
            image = frame.convert('RGB')
            image_shape = np.array(np.shape(image)[0:2])
            img_h, img_w = int(image_shape[0]), int(image_shape[1])
            #---------------------------------------------------------#
            #   Either pad with grey bars for a distortion-free resize,
            #   or resize directly before recognition
            #---------------------------------------------------------#
            if self.letterbox_image:
                crop_img = np.array(letterbox_image(image, (self.input_shape[1],self.input_shape[0])))
            else:
                crop_img = image.resize((self.input_shape[1],self.input_shape[0]), Image.BICUBIC)
            photo = np.array(crop_img, dtype = np.float64)
            #-----------------------------------------------------------#
            #   Image pre-processing / normalisation
            #-----------------------------------------------------------#
            photo = preprocess_input(np.reshape(photo,[1,self.input_shape[0], self.input_shape[1], 3]))
            preds = self.ssd_model.predict(photo)
            #-----------------------------------------------------------#
            #   Decode the raw predictions
            #-----------------------------------------------------------#
            results = self.bbox_util.detection_out(preds, confidence_threshold=self.confidence)
            #--------------------------------------#
            #   Nothing detected: publish the frame as-is
            #--------------------------------------#
            if len(results[0]) <= 0:
                # NOTE(review): frame2 has been through COLOR_BGR2RGB and is
                # labelled "bgr8" here, while the annotated frame below is
                # labelled "rgb8" after a second swap.  That is consistent
                # only if the camera delivers RGB data -- confirm.
                object_img = cvb.cv2_to_imgmsg(frame2, "bgr8")
                self.object_img_pub.publish(object_img)
                continue
            #-----------------------------------------------------------#
            #   Keep the boxes whose score is at least `confidence`
            #-----------------------------------------------------------#
            det_label = results[0][:, 0]
            det_conf = results[0][:, 1]
            det_xmin, det_ymin, det_xmax, det_ymax = results[0][:, 2], results[0][:, 3], results[0][:, 4], results[0][:, 5]
            top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence]
            top_conf = det_conf[top_indices]
            top_label_indices = det_label[top_indices].tolist()
            top_xmin = np.expand_dims(det_xmin[top_indices], -1)
            top_ymin = np.expand_dims(det_ymin[top_indices], -1)
            top_xmax = np.expand_dims(det_xmax[top_indices], -1)
            top_ymax = np.expand_dims(det_ymax[top_indices], -1)
            #-----------------------------------------------------------#
            #   Remove the grey-bar offset / scale to image coordinates
            #-----------------------------------------------------------#
            if self.letterbox_image:
                boxes = ssd_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.input_shape[0],self.input_shape[1]]),image_shape)
            else:
                top_xmin = top_xmin * image_shape[1]
                top_ymin = top_ymin * image_shape[0]
                top_xmax = top_xmax * image_shape[1]
                top_ymax = top_ymax * image_shape[0]
                boxes = np.concatenate([top_ymin,top_xmin,top_ymax,top_xmax], axis=-1)

            font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * img_w + 0.5).astype('int32'))
            thickness = max((img_h + img_w) // self.input_shape[0], 1)
            self.object_results_msg = Objects()
            for i, c in enumerate(top_label_indices):
                self.object_dectction_msg = Object()
                # Label indices are offset by one relative to class_names.
                class_index = int(c) - 1
                predicted_class = self.class_names[class_index]
                score = top_conf[i]

                # Grow the box by 5 px on every side, round to the nearest
                # pixel and clamp to the image bounds.
                top, left, bottom, right = boxes[i]
                top = max(0, int(np.floor(top - 5 + 0.5)))
                left = max(0, int(np.floor(left - 5 + 0.5)))
                bottom = min(img_h, int(np.floor(bottom + 5 + 0.5)))
                right = min(img_w, int(np.floor(right + 5 + 0.5)))

                # Native Python numbers are used for the message fields.
                self.object_dectction_msg.label = predicted_class
                self.object_dectction_msg.confidence = float(score)
                self.object_dectction_msg.x = left
                self.object_dectction_msg.y = top
                self.object_dectction_msg.width = right - left
                self.object_dectction_msg.height = bottom - top
                self.object_results_msg.objects.append(self.object_dectction_msg)

                # Draw the box and its caption
                label = '{} {:.2f}'.format(predicted_class, score)
                draw = ImageDraw.Draw(image)
                label_size = draw.textsize(label, font)

                # Caption goes above the box when there is room, else inside.
                if top - label_size[1] >= 0:
                    text_origin = np.array([left, top - label_size[1]])
                else:
                    text_origin = np.array([left, top + 1])

                # Nested 1-px rectangles give the requested line thickness.
                for inset in range(thickness):
                    draw.rectangle(
                        [left + inset, top + inset, right - inset, bottom - inset],
                        outline=self.colors[class_index])
                draw.rectangle(
                    [tuple(text_origin), tuple(text_origin + label_size)],
                    fill=self.colors[class_index])
                # Draw the same text that was measured above (no bytes
                # round-trip, which rendered as b'...' on Python 3).
                draw.text(tuple(text_origin), label, fill=(0, 0, 0), font=font)
                del draw

            frame = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
            fps = ( fps + (1./(time.time()-t1)) ) / 2
            #print("fps= %.2f"%(fps))
            frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            object_img = cvb.cv2_to_imgmsg(frame, "rgb8")
            self.object_img_pub.publish(object_img)
            self.object_results_pub.publish(self.object_results_msg)
# Script entry point: constructing SSD starts the node and runs the detection
# loop inside its constructor; rospy.spin() is reached only after that returns.
if __name__ == '__main__':
    try:
        SSD()
        rospy.spin()
    except rospy.ROSInterruptException:
        # The previous message named a different ("Follower") node.
        rospy.loginfo("Object detection node terminated.")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/stefantasy/wrc_2021_demo.git
[email protected]:stefantasy/wrc_2021_demo.git
stefantasy
wrc_2021_demo
wrc
devel

搜索帮助