1 Star 0 Fork 1

Mcflyy/遥感图像

forked from 啊涛涛/遥感图像 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
yolo.py 14.35 KB
一键复制 编辑 原始数据 按行查看 历史
bishe 提交于 2022-03-27 15:43 . 目标检测
import colorsys
import os
import time
import numpy as np
from keras import backend as K
from PIL import ImageDraw, ImageFont
from nets.yolo import yolo_body
from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
resize_image)
from utils.utils_bbox import DecodeBox
class YOLO(object):
_defaults = {
#--------------------------------------------------------------------------#
# 使用自己训练好的模型进行预测一定要修改model_path和classes_path!
# model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
#
# 训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
# 验证集损失较低不代表mAP较高,仅代表该权值在验证集上泛化性能较好。
# 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
#--------------------------------------------------------------------------#
"model_path" : 'model_data/yolov4_mobilenet_v1_voc.h5',
"classes_path" : 'model_data/voc_classes.txt',
#---------------------------------------------------------------------#
# anchors_path代表先验框对应的txt文件,一般不修改。
# anchors_mask用于帮助代码找到对应的先验框,一般不修改。
#---------------------------------------------------------------------#
"anchors_path" : 'model_data/yolo_anchors.txt',
"anchors_mask" : [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
#---------------------------------------------------------------------#
# 输入图片的大小,必须为32的倍数。
#---------------------------------------------------------------------#
"input_shape" : [416, 416],
#---------------------------------------------------------------------#
# 目标检测网络所使用的主干
#---------------------------------------------------------------------#
"backbone" : 'mobilenetv1',
#---------------------------------------------------------------------#
# 通道的缩放系数
#---------------------------------------------------------------------#
"alpha" : 1,
#---------------------------------------------------------------------#
# 只有得分大于置信度的预测框会被保留下来
#---------------------------------------------------------------------#
"confidence" : 0.5,
#---------------------------------------------------------------------#
# 非极大抑制所用到的nms_iou大小
#---------------------------------------------------------------------#
"nms_iou" : 0.3,
"max_boxes" : 100,
#---------------------------------------------------------------------#
# 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
# 在多次测试后,发现关闭letterbox_image直接resize的效果更好
#---------------------------------------------------------------------#
"letterbox_image" : False,
}
@classmethod
def get_defaults(cls, n):
if n in cls._defaults:
return cls._defaults[n]
else:
return "Unrecognized attribute name '" + n + "'"
#---------------------------------------------------#
# 初始化yolo
#---------------------------------------------------#
def __init__(self, **kwargs):
self.__dict__.update(self._defaults)
for name, value in kwargs.items():
setattr(self, name, value)
#---------------------------------------------------#
# 获得种类和先验框的数量
#---------------------------------------------------#
self.class_names, self.num_classes = get_classes(self.classes_path)
self.anchors, self.num_anchors = get_anchors(self.anchors_path)
#---------------------------------------------------#
# 画框设置不同的颜色
#---------------------------------------------------#
hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
self.input_image_shape = K.placeholder(shape=(2, ))
self.sess = K.get_session()
self.boxes, self.scores, self.classes = self.generate()
#---------------------------------------------------#
# 载入模型
#---------------------------------------------------#
def generate(self):
model_path = os.path.expanduser(self.model_path)
assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
self.yolo_model = yolo_body([None, None, 3], self.anchors_mask, self.num_classes, self.backbone, self.alpha)
self.yolo_model.load_weights(self.model_path)
print('{} model, anchors, and classes loaded.'.format(model_path))
#---------------------------------------------------------#
# 在yolo_eval函数中,我们会对预测结果进行后处理
# 后处理的内容包括,解码、非极大抑制、门限筛选等
#---------------------------------------------------------#
boxes, scores, classes = DecodeBox(
self.yolo_model.output,
self.anchors,
self.num_classes,
self.input_image_shape,
self.input_shape,
anchor_mask = self.anchors_mask,
max_boxes = self.max_boxes,
confidence = self.confidence,
nms_iou = self.nms_iou,
letterbox_image = self.letterbox_image
)
return boxes, scores, classes
#---------------------------------------------------#
# 检测图片
#---------------------------------------------------#
def detect_image(self, image, crop = False):
#---------------------------------------------------------#
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
image = cvtColor(image)
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
#---------------------------------------------------------#
# 添加上batch_size维度,并进行归一化
#---------------------------------------------------------#
image_data = np.expand_dims(preprocess_input(np.array(image_data, dtype='float32')), 0)
#---------------------------------------------------------#
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
out_boxes, out_scores, out_classes = self.sess.run(
[self.boxes, self.scores, self.classes],
feed_dict={
self.yolo_model.input: image_data,
self.input_image_shape: [image.size[1], image.size[0]],
K.learning_phase(): 0})
print('Found {} boxes for {}'.format(len(out_boxes), 'img'))
#---------------------------------------------------------#
# 设置字体与边框厚度
#---------------------------------------------------------#
font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))
#---------------------------------------------------------#
# 是否进行目标的裁剪
#---------------------------------------------------------#
if crop:
for i, c in list(enumerate(out_classes)):
top, left, bottom, right = out_boxes[i]
top = max(0, np.floor(top).astype('int32'))
left = max(0, np.floor(left).astype('int32'))
bottom = min(image.size[1], np.floor(bottom).astype('int32'))
right = min(image.size[0], np.floor(right).astype('int32'))
dir_save_path = "img_crop"
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
crop_image = image.crop([left, top, right, bottom])
crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
print("save crop_" + str(i) + ".png to " + dir_save_path)
#---------------------------------------------------------#
# 图像绘制
#---------------------------------------------------------#
for i, c in list(enumerate(out_classes)):
predicted_class = self.class_names[int(c)]
box = out_boxes[i]
score = out_scores[i]
top, left, bottom, right = box
top = max(0, np.floor(top).astype('int32'))
left = max(0, np.floor(left).astype('int32'))
bottom = min(image.size[1], np.floor(bottom).astype('int32'))
right = min(image.size[0], np.floor(right).astype('int32'))
label = '{} {:.2f}'.format(predicted_class, score)
draw = ImageDraw.Draw(image)
label_size = draw.textsize(label, font)
label = label.encode('utf-8')
print(label, top, left, bottom, right)
if top - label_size[1] >= 0:
text_origin = np.array([left, top - label_size[1]])
else:
text_origin = np.array([left, top + 1])
for i in range(thickness):
draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
del draw
return image
def get_FPS(self, image, test_interval):
#---------------------------------------------------------#
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
image = cvtColor(image)
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
#---------------------------------------------------------#
# 添加上batch_size维度,并进行归一化
#---------------------------------------------------------#
image_data = np.expand_dims(preprocess_input(np.array(image_data, dtype='float32')), 0)
#---------------------------------------------------------#
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
out_boxes, out_scores, out_classes = self.sess.run(
[self.boxes, self.scores, self.classes],
feed_dict={
self.yolo_model.input: image_data,
self.input_image_shape: [image.size[1], image.size[0]],
K.learning_phase(): 0})
t1 = time.time()
for _ in range(test_interval):
out_boxes, out_scores, out_classes = self.sess.run(
[self.boxes, self.scores, self.classes],
feed_dict={
self.yolo_model.input: image_data,
self.input_image_shape: [image.size[1], image.size[0]],
K.learning_phase(): 0})
t2 = time.time()
tact_time = (t2 - t1) / test_interval
return tact_time
def get_map_txt(self, image_id, image, class_names, map_out_path):
f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
#---------------------------------------------------------#
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
#---------------------------------------------------------#
image = cvtColor(image)
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
#---------------------------------------------------------#
# 添加上batch_size维度,并进行归一化
#---------------------------------------------------------#
image_data = np.expand_dims(preprocess_input(np.array(image_data, dtype='float32')), 0)
out_boxes, out_scores, out_classes = self.sess.run(
[self.boxes, self.scores, self.classes],
feed_dict={
self.yolo_model.input: image_data,
self.input_image_shape: [image.size[1], image.size[0]],
K.learning_phase(): 0
})
for i, c in enumerate(out_classes):
predicted_class = self.class_names[int(c)]
score = str(out_scores[i])
top, left, bottom, right = out_boxes[i]
if predicted_class not in class_names:
continue
f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
f.close()
return
def close_session(self):
self.sess.close()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mcflyy/remote-sensing-image.git
[email protected]:mcflyy/remote-sensing-image.git
mcflyy
remote-sensing-image
遥感图像
master

搜索帮助