1 Star 0 Fork 1

Tomcat/my_mask

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
config.py 2.19 KB
一键复制 编辑 原始数据 按行查看 历史
tom 提交于 2024-05-14 13:31 . 1.训练过程的代码完成
import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
class Config:
def print(self):
"""
打印当前配置参数
"""
print("now Config is:")
for key, value in self.__dict__.items():
print(f"{key}: {value}")
def __init__(self):
self.DATASET_PATH = "D:/py_project/dataset"
# 存label的json,来源于annotations.json的categories有信息
self.LABEL_JSON_PATH = "D:/py_project/dataset/aoi-fd/indices.json"
# 分类数量+1(背景)
self.NUM_CLASSES = 5
# 创建采样器,是否需要将相似宽高比的图像分配到一组组成batch
self.ASPECT_RATIO_GROUP_FACTOR = 0
self.BATCH_SIZE = 4
# 训练的数据集
self.TRAIN_IMG_FOLDER = "D:/py_project/dataset/icon-coco"
self.TRAIN_ANN_FILE = "D:/py_project/dataset/icon-coco/annotations.json"
# 验证的数据集
self.VAL_IMG_FOLDER = "D:/py_project/dataset/icon-coco"
self.VAL_ANN_FILE = "D:/py_project/dataset/icon-coco/annotations.json"
# 定义设备
self.DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 优化器参数
self.LR = 0.0005
self.MOMENTUM = 0.9
self.WEIGHT_DECAY = 0.0001
# lr调度器参数
self.LR_SCHEDULER_STEP_SIZE = 10
self.LR_SCHEDULER_GAMMA = 0.1
# 训练参数
# 训练轮数
self.EPOCHS_NUM = 40
# 骨干网络上训练(不冻结)的层
self.TRAINABLE_BACKBONE_LAYERS = 5
# rpn的anchor设定
anchor_sizes = ((32, 64), (64, 128), (128, 256), (256, 512), (512, 1024))
aspect_ratios = ((0.25, 0.5, 1.0, 2.0, 4.0),) * len(anchor_sizes)
# self.RPN_ANCHOR_GENERATOR = AnchorGenerator(anchor_sizes, aspect_ratios)
self.RPN_ANCHOR_GENERATOR = None
# 注意,一旦使用RESUME_PTH,则此类的所有配置,注意是所有,都将会被所指的PTH内的config所覆盖
# self.RESUME_PTH = "./out/model_0.pth"
self.RESUME_PTH = None
# 输出目录
self.OUT_DIR = "./out"
# 打印频率,多少的batch打印一次
self.PRINT_FREQ = 5
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/TomCoCo/my_mask.git
[email protected]:TomCoCo/my_mask.git
TomCoCo
my_mask
my_mask
master

搜索帮助