1 Star 0 Fork 1

Tomcat/my_mask

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
run.py 17.49 KB
一键复制 编辑 原始数据 按行查看 历史
tom 提交于 8个月前 . 1.训练过程的代码完成
import json
import os
import cv2
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch import optim
from torchvision import transforms
from torchvision.models import ResNet50_Weights
from torchvision.models.detection import mask_rcnn
from torchvision.models.detection.faster_rcnn import FastRCNNConvFCHead
from torchvision.ops import box_iou
from config import Config
import my_transforms
import utils
from coco_utils import ConvertCocoPolysToMask, MyCocoDetection
from engine import evaluate
from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
# Per-channel normalization statistics (mean and standard deviation). They are
# handed to the Mask R-CNN constructor as image_mean / image_std in
# Run.create_model.
# NOTE(review): presumably in RGB channel order, for pixel values scaled to
# [0, 1], and produced by Run.calc_mean_std -- confirm against the dataset.
mean = torch.tensor([0.32648515701293945, 0.29732969403266907, 0.21654412150382996])
std = torch.tensor([0.4628050923347473, 0.4247659146785736, 0.3108766973018646])
class Run:
    """Train, evaluate and run inference with a torchvision Mask R-CNN model.

    Constructing an instance loads the configuration, builds the model with
    its optimizer and learning-rate scheduler, creates the data transforms
    and data loaders, and reads the label-id -> category-name mapping from
    ``config.LABEL_JSON_PATH``.
    """

    def __init__(self):
        super().__init__()
        # 0. Load the default configuration.
        self.config = Config()
        self.config.print()
        # 1. Build the model, the optimizer and the learning-rate scheduler.
        self.model = None
        self.optimizer = None
        self.lr_scheduler = None
        self.create_model()
        # 2. Create the transforms that turn dataset samples into model input.
        #    The training/validation Compose is the project's own version
        #    (my_transforms.Compose), which processes two arguments, not one.
        self.val_transforms = None
        self.train_transforms = None
        self.predict_transforms = None
        self.create_transforms()
        # 3. Create the data loaders.
        self.data_loader = None
        self.data_loader_test = None
        self.create_data_loader()
        # 4. Load the JSON file that maps label ids to category names.
        assert os.path.exists(self.config.LABEL_JSON_PATH), "json file {} dose not exist.".format(
            self.config.LABEL_JSON_PATH)
        with open(self.config.LABEL_JSON_PATH, 'r') as json_file:
            self.category_index = json.load(json_file)

    def create_model(self):
        """Build the Mask R-CNN model, optionally resume it, and create the optimizer.

        When ``config.RESUME_PTH`` is set, the model weights are restored from
        that checkpoint and ``self.config`` is replaced by the configuration
        stored in it. The model itself has already been constructed from the
        configuration that was active before the replacement.
        """
        box_head = FastRCNNConvFCHead(
            (256, 7, 7), [256, 256, 256, 256], [1024],
            norm_layer=torch.nn.BatchNorm2d
        )
        # Build the model.
        self.model = mask_rcnn.maskrcnn_resnet50_fpn(
            num_classes=self.config.NUM_CLASSES,
            weights=None,
            backbone_weights=ResNet50_Weights.IMAGENET1K_V2,
            image_mean=mean,
            image_std=std,
            trainable_backbone_layers=self.config.TRAINABLE_BACKBONE_LAYERS,
            rpn_anchor_generator=self.config.RPN_ANCHOR_GENERATOR,
            box_head=box_head,
            box_nms_thresh=0.2
        )
        if self.config.RESUME_PTH is not None:
            # The checkpoint stores a pickled Config object next to the weights
            # (see train_model), so it cannot be read with weights_only=True.
            # SECURITY: unpickling executes arbitrary code -- only resume from
            # checkpoints you trust.
            dict_all = torch.load(self.config.RESUME_PTH, map_location='cpu', weights_only=False)
            self.model.load_state_dict(dict_all["model"])
            # Continue with the configuration of the resumed run.
            print("!!change config form resume pth!!")
            self.config = dict_all["config"]
            self.config.print()
        self.model.to(self.config.DEVICE)
        # Only optimize parameters that are trainable (not frozen).
        parameters = [p for p in self.model.parameters() if p.requires_grad]
        # AdamW is used here; SGD with config.MOMENTUM would be the alternative.
        self.optimizer = optim.AdamW(parameters, lr=self.config.LR, weight_decay=self.config.WEIGHT_DECAY)
        self.lr_scheduler = optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=self.config.LR_SCHEDULER_STEP_SIZE,
            gamma=self.config.LR_SCHEDULER_GAMMA,
        )

    def create_transforms(self):
        """Create the training, validation and prediction transform pipelines."""
        self.train_transforms = my_transforms.Compose([
            # Process the target: convert the polygons to masks and pack the
            # annotation fields into tensors.
            ConvertCocoPolysToMask(),
            my_transforms.PILToTensor(),
            my_transforms.RandomHorizontalFlip(0.5),
            # Convert the image to float, scaled into the 0-1 range.
            my_transforms.ToDtype(torch.float, True)])
        self.val_transforms = my_transforms.Compose([
            ConvertCocoPolysToMask(),
            my_transforms.PILToTensor(),
            my_transforms.ToDtype(torch.float, True)])
        self.predict_transforms = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def collate_fn(batch):
        """Transpose a batch of samples into a tuple of per-field tuples."""
        return tuple(zip(*batch))

    def create_data_loader(self):
        """Create the training and validation data loaders."""
        dataset = MyCocoDetection(self.config.TRAIN_IMG_FOLDER, self.config.TRAIN_ANN_FILE,
                                  transforms=self.train_transforms)
        dataset_val = MyCocoDetection(self.config.VAL_IMG_FOLDER, self.config.VAL_ANN_FILE,
                                      transforms=self.val_transforms)
        # Create the samplers; optionally batch images of similar aspect ratio together.
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_val)
        if self.config.ASPECT_RATIO_GROUP_FACTOR >= 0:
            group_ids = create_aspect_ratio_groups(dataset, k=self.config.ASPECT_RATIO_GROUP_FACTOR)
            train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, self.config.BATCH_SIZE)
        else:
            train_batch_sampler = torch.utils.data.BatchSampler(
                train_sampler, self.config.BATCH_SIZE, drop_last=False)
        # Number of workers. os.cpu_count() may return None, which min() cannot compare.
        nw = min([os.cpu_count() or 1, self.config.BATCH_SIZE if self.config.BATCH_SIZE > 1 else 0, 8])
        print('Using %g dataloader workers' % nw)
        self.data_loader = torch.utils.data.DataLoader(
            dataset, batch_sampler=train_batch_sampler, num_workers=nw, collate_fn=self.collate_fn
        )
        self.data_loader_test = torch.utils.data.DataLoader(
            dataset_val, batch_size=1, sampler=test_sampler, num_workers=nw, collate_fn=self.collate_fn
        )

    def train_model(self):
        """Train for ``config.EPOCHS_NUM`` epochs, checkpointing and evaluating after each one."""
        if self.config.OUT_DIR:
            # torch.save does not create missing directories.
            os.makedirs(self.config.OUT_DIR, exist_ok=True)
        for epoch in range(self.config.EPOCHS_NUM):
            # Train.
            self.model.train()
            self.train_one_epoch(epoch)
            self.lr_scheduler.step()
            # Store the training state.
            checkpoint = {
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
                "epoch": epoch,
                "config": self.config
            }
            torch.save(checkpoint, self.config.OUT_DIR + "/model_{}.pth".format(epoch))
            # Validate.
            evaluate(self.model, self.data_loader_test, self.config.DEVICE)

    def train_one_epoch(self, epoch):
        """Run one pass of mini-batch training over ``self.data_loader``.

        Args:
            epoch: Zero-based epoch index; epoch 0 additionally applies a
                linear learning-rate warm-up that is stepped once per batch.
        """
        metric_logger = utils.MetricLogger(delimiter=" ")
        metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
        header = f"Epoch: [{epoch}]"
        # Warm-up: ramp the learning rate during the first epoch only.
        warmup_scheduler = None
        if epoch == 0:
            warmup_factor = 1.0 / 1000
            warmup_iters = min(1000, len(self.data_loader) - 1)
            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
                self.optimizer, start_factor=warmup_factor, total_iters=warmup_iters
            )
        device = self.config.DEVICE
        # Mini-batch gradient descent.
        for images, targets in metric_logger.log_every(self.data_loader, self.config.PRINT_FREQ, header):
            # Move every image and every tensor-valued target field to the device.
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
                       for t in targets]
            with torch.set_grad_enabled(True):
                # Forward pass; in training mode the model returns a dict of losses.
                loss_dict = self.model(images, targets)
                losses = sum(loss_dict.values())
            # Clear old gradients, back-propagate, update the weights.
            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            # Per-batch warm-up step (first epoch only).
            if warmup_scheduler is not None:
                warmup_scheduler.step()
            metric_logger.update(loss=losses)
            metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])

    def calc_mean_std(self):
        """Compute and print the per-channel mean and std of the training images.

        Statistics are accumulated over every pixel of every image yielded by
        ``self.data_loader`` (pixel-weighted, in float64), so images of
        different sizes contribute proportionally.

        Returns:
            A ``(mean, std)`` tuple of float32 tensors with 3 elements each.
        """
        pixel_sum = torch.zeros(3, dtype=torch.float64)
        pixel_sq_sum = torch.zeros(3, dtype=torch.float64)
        pixel_count = 0
        for images, _ in self.data_loader:
            for image in images:
                # Flatten (C, H, W) to (C, H*W) and accumulate in double precision.
                flat = image.reshape(3, -1).to(torch.float64)
                pixel_sum += flat.sum(dim=1)
                pixel_sq_sum += (flat ** 2).sum(dim=1)
                # Count pixels per channel (not channels) so the division below
                # yields a true mean.
                pixel_count += flat.size(1)
        channel_mean = pixel_sum / pixel_count
        # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
        channel_var = (pixel_sq_sum / pixel_count - channel_mean ** 2).clamp(min=0)
        channel_std = torch.sqrt(channel_var)
        print(f"Mean: {channel_mean.tolist()}")
        print(f"Std: {channel_std.tolist()}")
        return channel_mean.float(), channel_std.float()

    def _load_weights(self, weights_path):
        """Load model weights from ``weights_path`` and switch the model to eval mode.

        Accepts either a bare state dict or a training checkpoint whose
        ``"model"`` entry holds the state dict.
        """
        assert os.path.exists(weights_path), "{} file dose not exist.".format(weights_path)
        # Training checkpoints also contain a pickled Config object, so
        # weights_only=True would reject them.
        # SECURITY: unpickling executes arbitrary code -- only load trusted files.
        weights_dict = torch.load(weights_path, map_location='cpu', weights_only=False)
        weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
        self.model.load_state_dict(weights_dict)
        self.model.to(self.config.DEVICE)
        self.model.eval()

    def _annotate_image(self, img_path, threshold_box, threshold_score):
        """Run inference on one image file and return a BGR image with the detections drawn."""
        original_img = Image.open(img_path).convert('RGB')
        # Add the batch dimension expected by the model.
        img = torch.unsqueeze(self.predict_transforms(original_img), dim=0)
        t_start = utils.time_synchronized()
        predictions = self.model(img.to(self.config.DEVICE))[0]
        # Run NMS once, with the requested thresholds, and reuse the result.
        detections = list(self.nms_without_classification(predictions, threshold_box, threshold_score))
        t_end = utils.time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))
        numpy_image = np.moveaxis(img.cpu().numpy().squeeze(0), 0, -1)
        opencv_image = cv2.cvtColor((numpy_image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        for box, label, score, masks in detections:
            category = self.category_index[str(label)]
            print(f"label:{category} scores:{score}")
            color = tuple(np.random.randint(0, 255, 3).tolist())
            # Plain Python ints: unsigned NumPy scalars wrap around (or turn
            # into floats) when the text offset is subtracted below.
            x1, y1, x2, y2 = (int(v) for v in box)
            cv2.rectangle(opencv_image, (x1, y1), (x2, y2), color, 2)
            # Keep the label inside the image when the box touches the top edge.
            cv2.putText(opencv_image, f"label:{category}_scores:{score}", (x1, max(y1 - 5, 12)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
            for mm in masks:
                # NOTE(review): every non-zero pixel of the soft mask is treated
                # as foreground; torchvision suggests thresholding at 0.5 --
                # confirm whether that is wanted here.
                ts = (mm * 255).astype("uint8")  # convert the mask to 8-bit unsigned
                contours, _ = cv2.findContours(ts, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(opencv_image, contours, -1, color, 2)
        return opencv_image

    @torch.inference_mode()
    def predict(self, weights_path, img_path, threshold_box=0.5, threshold_score=0.3):
        """Detect objects in one image and write ``predict_<filename>`` to the working directory.

        Args:
            weights_path: Path of the weights / checkpoint file to load.
            img_path: Path of the image to run inference on.
            threshold_box: IoU threshold of the class-agnostic NMS.
            threshold_score: Detections scoring below this value are dropped.
        """
        self._load_weights(weights_path)
        assert os.path.exists(img_path), f"{img_path} does not exits."
        filename = os.path.basename(img_path)
        annotated = self._annotate_image(img_path, threshold_box, threshold_score)
        cv2.imwrite("predict_" + filename, annotated)

    @torch.inference_mode()
    def predict_dir(self, weights_path, img_path, threshold_box=0.5, threshold_score=0.3):
        """Detect objects in every file of a directory and write results to ``<img_path>/predict/``.

        Args:
            weights_path: Path of the weights / checkpoint file to load.
            img_path: Directory containing the images.
            threshold_box: IoU threshold of the class-agnostic NMS.
            threshold_score: Detections scoring below this value are dropped.
        """
        self._load_weights(weights_path)
        assert os.path.exists(img_path), f"{img_path} does not exits."
        out_dir = img_path + "/predict/"
        # cv2.imwrite does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        for filename in sorted(os.listdir(img_path)):
            src = img_path + "/" + filename
            # Skip sub-directories, in particular the output directory itself.
            if not os.path.isfile(src):
                continue
            annotated = self._annotate_image(src, threshold_box, threshold_score)
            cv2.imwrite(out_dir + filename, annotated)

    @staticmethod
    def nms_without_classification(predictions, threshold_box=0.5, threshold_score=0.3):
        """Score-filter the detections and apply class-agnostic NMS.

        The model's built-in NMS works per class, but in this use case one
        object must not be reported more than once, so a second suppression
        pass is run that ignores the class: of two boxes whose IoU exceeds
        ``threshold_box`` only the higher-scoring one is kept.

        Args:
            predictions: Model output for one image, a dict with the keys
                ``"boxes"``, ``"labels"``, ``"scores"`` and ``"masks"``.
            threshold_box: Boxes with an IoU greater than this value are
                reduced to one, regardless of class. 0 removes the
                lower-scoring box as soon as there is any overlap; 0.5
                removes it once the IoU exceeds 50%.
            threshold_score: Detections scoring below this value are removed.

        Returns:
            An iterator of ``(box, label, score, mask)`` tuples of NumPy values.
        """
        boxes = predictions["boxes"].to("cpu")
        classes = predictions["labels"].to("cpu")
        scores = predictions["scores"].to("cpu")
        mask = predictions["masks"].to("cpu")
        # Drop low-scoring detections first so that they are always removed,
        # including when no other box is left to compare against.
        keep = scores >= threshold_score
        # One pairwise IoU matrix instead of one box_iou call per pair.
        iou = box_iou(boxes, boxes)
        score_list = scores.tolist()
        # Greedy suppression, highest score first (stable for equal scores).
        order = sorted(range(len(score_list)), key=score_list.__getitem__, reverse=True)
        for idx in order:
            if not keep[idx]:
                continue
            overlapping = iou[idx] > threshold_box
            overlapping[idx] = False  # a box never suppresses itself
            keep &= ~overlapping
        return zip(boxes[keep].numpy(), classes[keep].numpy(), scores[keep].numpy(), mask[keep].numpy())
def main(args):
    """Script entry point: build a ``Run`` instance and start training.

    Args:
        args: Accepted for call compatibility; not used.
    """
    # Other entry points used during experiments (kept for reference):
    #   Run().calc_mean_std()
    #   Run().predict("./out/20240418/model_19.pth", <image>, 0.3)
    #   Run().predict("./out/2024042502/model_29.pth", <image>, 0.3) with <image> one of
    #     "D:/py_project/dataset/aoi-fd/original/氧化/13B13101_638_1_79_428_100.856_52.004_Lens10X.JPG"
    #     "D:/py_project/dataset/aoi-fd/original/划痕/13C18001_493_1_76_245_69.582_53.836_Lens10X.JPG"
    #     "D:/py_project/dataset/aoi-fd/original/氧化/13B13101_532_1_129_429_101.057_26.223_Lens10X.JPG"
    #     "D:/py_project/dataset/aoi-fd/original/氧化/13B13101_536_1_128_432_101.538_26.989_Lens10X.JPG"
    #   Run().predict_dir("./out/2024042502/model_29.pth", "./predict_test", 0.1, 0.3)
    Run().train_model()


if __name__ == '__main__':
    main(None)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/TomCoCo/my_mask.git
git@gitee.com:TomCoCo/my_mask.git
TomCoCo
my_mask
my_mask
master

搜索帮助