1 Star 0 Fork 0

xxxqhloveu/RepOptimizers

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
utils.py 6.48 KB
一键复制 编辑 原始数据 按行查看 历史
Xiaohan Ding 提交于 2022-12-12 10:30 . update RepOpt-GhostNet
# --------------------------------------------------------
# Re-parameterizing Your Optimizers rather than Architectures (https://arxiv.org/abs/2205.15242)
# Github source: https://github.com/DingXiaoH/RepOptimizers
# Licensed under The MIT License [see LICENSE for details]
# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
# --------------------------------------------------------
import math
import os

import numpy as np
import torch
import torch.distributed as dist

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
def unwrap_model(model):
    """Return the module inside a DistributedDataParallel wrapper, else *model*.

    Only ``torch.nn.parallel.distributed.DistributedDataParallel`` instances are
    unwrapped; every other object is handed back unchanged.
    """
    if not isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        return model
    return model.module
def load_checkpoint(config, model, optimizer, lr_scheduler, logger, model_ema=None):
    """Restore a run from the checkpoint named by ``config.MODEL.RESUME``.

    ``checkpoint['model']`` is always loaded into *model* (non-strictly).
    The following are restored only when ``config.EVAL_MODE`` is false and the
    checkpoint holds ``'optimizer'``, ``'lr_scheduler'`` and ``'epoch'``:

    - optimizer and LR-scheduler state;
    - ``config.TRAIN.START_EPOCH`` (set to ``epoch + 1``);
    - apex AMP state (when present and both configs use AMP);
    - ``max_accuracy``.

    When *model_ema* is given, its weights come from ``checkpoint['ema']``.
    If that entry is missing, a warning is logged instead of failing.

    Args:
        config: config object supporting ``defrost()``/``freeze()``; reads
            ``MODEL.RESUME``, ``EVAL_MODE`` and ``AMP_OPT_LEVEL``.
        model: module receiving ``checkpoint['model']``.
        optimizer: object whose ``load_state_dict`` receives the optimizer state.
        lr_scheduler: object whose ``load_state_dict`` receives the scheduler state.
        logger: logger for progress messages.
        model_ema: optional EMA model, possibly DDP-wrapped.

    Returns:
        The restored ``max_accuracy``, or 0.0 if none was restored.

    Raises:
        ImportError: AMP state must be restored but apex is not installed.
    """
    logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
    # NOTE(review): both loaders unpickle the file; only resume from trusted sources.
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
            if amp is None:
                # The module-level apex import failed; fail with a clear message
                # instead of an AttributeError on None.
                raise ImportError(
                    f"apex is required to restore the AMP state stored in "
                    f"'{config.MODEL.RESUME}', but it is not installed")
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
    if model_ema is not None:
        if 'ema' in checkpoint:
            unwrap_model(model_ema).load_state_dict(checkpoint['ema'])
            print('=================================================== EMAloaded')
        else:
            # e.g. resuming from weights saved without EMA: keep the current EMA
            # weights rather than aborting with a KeyError.
            logger.warning(f"=> no 'ema' entry in '{config.MODEL.RESUME}'; EMA weights were not restored")
    # Free the CPU copy of the checkpoint before training continues.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
def load_weights(model, path):
    """Load the weights stored at *path* into *model*, ignoring missing/extra keys.

    If the loaded object has a ``'model'`` entry, that entry is used as the
    state dict; otherwise the loaded object itself is. A DistributedDataParallel
    wrapper around *model* is stripped before loading.
    """
    state = torch.load(path, map_location='cpu')
    state = state['model'] if 'model' in state else state
    target = unwrap_model(model)
    target.load_state_dict(state, strict=False)
    print('=================== loaded from', path)
def save_latest(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, model_ema=None):
    """Save a resumable snapshot of the run to ``<config.OUTPUT>/latest.pth``.

    ``latest.pth`` is overwritten on every call. The state is therefore written
    to ``latest.pth.tmp`` first and then moved into place with ``os.replace``.
    An interrupted save leaves the previous ``latest.pth`` intact instead of
    truncating it.

    Args:
        config: config object; reads ``AMP_OPT_LEVEL`` and ``OUTPUT`` and is
            itself stored under ``'config'``.
        epoch: epoch index stored under ``'epoch'``.
        model: module whose ``state_dict()`` is stored under ``'model'``.
        max_accuracy: best accuracy so far, stored under ``'max_accuracy'``.
        optimizer: optimizer whose state is stored.
        lr_scheduler: scheduler whose state is stored.
        logger: logger for progress messages.
        model_ema: optional EMA model; its unwrapped state goes under ``'ema'``.
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}
    if config.AMP_OPT_LEVEL != "O0":
        save_state['amp'] = amp.state_dict()
    if model_ema is not None:
        save_state['ema'] = unwrap_model(model_ema).state_dict()
    save_path = os.path.join(config.OUTPUT, 'latest.pth')
    # The temporary name does not end in 'pth', so auto_resume_helper's
    # endswith('pth') filter never selects a half-written file.
    tmp_path = save_path + '.tmp'
    logger.info(f"{save_path} saving......")
    torch.save(save_state, tmp_path)
    os.replace(tmp_path, save_path)
    logger.info(f"{save_path} saved !!!")
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, is_best=False, model_ema=None):
    """Save the training state as ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    When *is_best* is true, the same state is first written to
    ``<config.OUTPUT>/best_ckpt.pth``. Each file is written to a ``.tmp``
    sibling and then moved into place with ``os.replace``, so an interrupted
    save never leaves a truncated ``.pth`` file. In particular, the repeatedly
    overwritten ``best_ckpt.pth`` keeps its previous contents.

    Args:
        config: config object; reads ``AMP_OPT_LEVEL`` and ``OUTPUT`` and is
            itself stored under ``'config'``.
        epoch: epoch index, stored and used in the file name.
        model: module whose ``state_dict()`` is stored under ``'model'``.
        max_accuracy: best accuracy so far, stored under ``'max_accuracy'``.
        optimizer: optimizer whose state is stored.
        lr_scheduler: scheduler whose state is stored.
        logger: logger for progress messages.
        is_best: also (over)write ``best_ckpt.pth``.
        model_ema: optional EMA model; its unwrapped state goes under ``'ema'``.
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}
    if config.AMP_OPT_LEVEL != "O0":
        save_state['amp'] = amp.state_dict()
    if model_ema is not None:
        save_state['ema'] = unwrap_model(model_ema).state_dict()

    def _save_atomically(path):
        # Write next to the target, then rename over it. The '.tmp' name does
        # not end in 'pth', so a half-written file is never picked by
        # auto_resume_helper's endswith('pth') filter.
        tmp_path = path + '.tmp'
        torch.save(save_state, tmp_path)
        os.replace(tmp_path, path)

    if is_best:
        _save_atomically(os.path.join(config.OUTPUT, 'best_ckpt.pth'))
    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    _save_atomically(save_path)
    logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
    """Return the total *norm_type*-norm of the gradients of *parameters*.

    The result equals the norm of all gradients flattened into one vector.
    Parameters whose ``.grad`` is ``None`` are skipped.

    Args:
        parameters: a single tensor or an iterable of tensors.
        norm_type: order of the norm (anything ``float()`` accepts). Use
            ``float('inf')`` or ``'inf'`` for the max-norm.

    Returns:
        The total norm as a Python float; 0.0 when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    norm_type = float(norm_type)
    # Per-tensor norms are pulled to Python floats and combined in double precision.
    per_tensor = [p.grad.detach().norm(norm_type).item()
                  for p in parameters if p.grad is not None]
    if norm_type == math.inf:
        # The p -> inf limit is the largest per-tensor max-norm. The generic
        # sum-of-powers formula below would collapse to 1.0 here.
        return max(per_tensor, default=0.0)
    total_norm = sum(n ** norm_type for n in per_tensor)
    return total_norm ** (1. / norm_type)
def auto_resume_helper(output_dir):
    """Return the path of the most recently modified checkpoint in *output_dir*.

    Candidates are regular files whose name ends with ``'pth'`` and does not
    contain ``'ema'``. Subdirectories are ignored even if their name matches.

    Returns:
        The newest candidate path (by modification time), or ``None`` when
        *output_dir* does not exist or holds no candidate.
    """
    if not os.path.isdir(output_dir):
        # Nothing to resume from; avoids FileNotFoundError from os.listdir.
        return None
    checkpoints = [name for name in os.listdir(output_dir)
                   if name.endswith('pth') and 'ema' not in name
                   and os.path.isfile(os.path.join(output_dir, name))]
    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
    if not checkpoints:
        return None
    latest_checkpoint = max((os.path.join(output_dir, name) for name in checkpoints),
                            key=os.path.getmtime)
    print(f"The latest checkpoint founded: {latest_checkpoint}")
    return latest_checkpoint
def reduce_tensor(tensor):
    """Return the mean of *tensor* across all processes of the default group.

    A clone is reduced, so the caller's *tensor* is left untouched; every rank
    gets back the same averaged tensor.
    """
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    world_size = dist.get_world_size()
    reduced /= world_size
    return reduced
def update_model_ema(cfg, num_gpus, model, model_ema, cur_epoch, cur_iter):
    """Update the exponential moving average (EMA) of *model*'s weights in *model_ema*.

    Runs only on iterations divisible by ``cfg.TRAIN.EMA_UPDATE_PERIOD``, and
    never when that period is ``None`` or 0. The blend factor is
    ``cfg.TRAIN.EMA_ALPHA`` scaled by
    ``num_gpus * cfg.DATA.BATCH_SIZE / cfg.TRAIN.EPOCHS * update_period`` and
    capped at 1.0. Before ``cfg.TRAIN.WARMUP_EPOCHS`` it is 1.0, so the
    weights are simply copied.

    Floating-point state entries are blended. Non-floating entries (integer
    buffers such as BatchNorm's ``num_batches_tracked``) are copied, because
    float blending would truncate them.

    Args:
        cfg: config providing ``TRAIN.*`` and ``DATA.BATCH_SIZE``.
        num_gpus: number of processes contributing to the global batch.
        model: source model, possibly DDP-wrapped.
        model_ema: EMA model updated in place, possibly DDP-wrapped.
        cur_epoch: current epoch index.
        cur_iter: current iteration index.
    """
    update_period = cfg.TRAIN.EMA_UPDATE_PERIOD
    if update_period is None or update_period == 0 or cur_iter % update_period != 0:
        return
    # Adjust alpha to be fairly independent of other parameters
    total_batch_size = num_gpus * cfg.DATA.BATCH_SIZE
    adjust = total_batch_size / cfg.TRAIN.EPOCHS * update_period
    alpha = min(1.0, cfg.TRAIN.EMA_ALPHA * adjust)
    # During warmup simply copy over weights instead of using ema
    alpha = 1.0 if cur_epoch < cfg.TRAIN.WARMUP_EPOCHS else alpha
    # Take ema of all parameters (not just named parameters), i.e. buffers too
    params = unwrap_model(model).state_dict()
    with torch.no_grad():
        for name, param in unwrap_model(model_ema).state_dict().items():
            src = params[name]
            if param.is_floating_point():
                param.copy_(param * (1.0 - alpha) + src * alpha)
            else:
                # Averaging an integer buffer through float math and truncating
                # it back would make it lag; take the current value instead.
                param.copy_(src)
def update_model_ema_v2(model, model_ema, ema_alpha):
    """Blend *model*'s current state into *model_ema* in place.

    Each floating-point entry of the EMA state dict becomes
    ``ema * (1 - ema_alpha) + current * ema_alpha``. Non-floating entries
    (integer buffers such as BatchNorm's ``num_batches_tracked``) are copied,
    because float blending would truncate them.

    Args:
        model: source model, possibly DDP-wrapped.
        model_ema: EMA model updated in place, possibly DDP-wrapped.
        ema_alpha: weight given to the current model's values.
    """
    params = unwrap_model(model).state_dict()
    with torch.no_grad():
        for name, param in unwrap_model(model_ema).state_dict().items():
            src = params[name]
            if param.is_floating_point():
                param.copy_(param * (1.0 - ema_alpha) + src * ema_alpha)
            else:
                # Integer buffers cannot be meaningfully averaged; copy them.
                param.copy_(src)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/corrshaw/RepOptimizers.git
[email protected]:corrshaw/RepOptimizers.git
corrshaw
RepOptimizers
RepOptimizers
main

搜索帮助