# -*- coding: utf-8 -*-
import torch as t
import os, sys
import yaml
import traceback
#%%
def load_yaml(path):
with open(path, 'r', encoding='utf-8') as fp:
result = yaml.load(fp.read(), yaml.FullLoader)
return result
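# For orientation: Config below expects default.yaml to expose the following
# top-level sections (section names are taken from the code; the values shown
# are illustrative placeholders, not the shipped defaults):
#   ExtraPath:     []        # extra directories appended to sys.path
#   ExtraEnv:      {}        # environment variables exported before training
#   BaseConfig:    {use_cuda: true, gpu_idx: 0, seed: 0, print_newopt: true}
#   DatasetConfig: {}
#   ModelConfig:   {}
#   TrainerConfig: {lr: 0.001}
#   ImportConfig:  {DatasetConfig: '', ModelConfig: '', TrainerConfig: ''}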
default_config_path = os.path.join(os.path.dirname(__file__), 'default.yaml')
def load_config(config_path=default_config_path, run_env=True):
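    # torch.distributed.launch / torchrun typically export NCCL_ASYNC_ERROR_HANDLING
    # to their worker processes, so its presence is used here as a heuristic that we
    # were started by the multi-GPU launcher (see the comment above Config_mgpus).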
if 'NCCL_ASYNC_ERROR_HANDLING' in os.environ:
cfg = Config_mgpus(config_path, run_env)
else:
cfg = Config(config_path, run_env)
print(cfg)
return cfg
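# Typical usage (illustrative; the custom config path below is hypothetical):
#   cfg = load_config()                         # bundled defaults only
#   cfg = load_config('configs/my_exp.yaml')    # defaults overridden per section
#   print(cfg.device, cfg.seed)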
class Config():
def __init__(self, config_path=None, run_env=True):
if config_path is None:
return
configs = load_yaml(default_config_path)
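        # A custom yaml only overrides keys inside each top-level section of the
        # defaults: dict sections are merged key by key, list sections are extended,
        # and top-level keys that exist only in the custom file are ignored.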
if config_path!=default_config_path:
custom_cfgs = load_yaml(config_path)
for k, v in configs.items():
if isinstance(v, dict):
v.update(custom_cfgs.get(k, {}))
elif isinstance(v, list):
v.extend(custom_cfgs.get(k, []))
for path in configs['ExtraPath']:
sys.path.append(path)
if run_env:
for k, v in configs['ExtraEnv'].items():
os.environ[k] = str(v)
self.load_args_from_dict(configs['DatasetConfig'])
self.load_args_from_dict(configs['ModelConfig'])
self.load_args_from_dict(configs['TrainerConfig'])
self.load_args_from_dict(configs['BaseConfig'])
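        # ImportConfig may point a section at a separate yaml file; those values
        # are loaded last and therefore override anything set above.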
for k,v in configs['ImportConfig'].items():
if k in ['DatasetConfig', 'ModelConfig', 'TrainerConfig'] and v:
tmp = load_yaml(v)
self.load_args_from_dict(tmp)
if hasattr(self, "seed") and self.seed:
            t.manual_seed(self.seed)
def load_args_from_dict(self, args):
for k,v in args.items():
setattr(self, k, v)
return args
def load_args(self, args):
self.load_args_from_dict(vars(args))
def _parse(self, **kwargs):
state_dict = self._state_dict()
for k, v in kwargs.items():
if self.print_newopt and k not in state_dict:
print('add new option: "%s = %s"' % (k, v))
setattr(self, k, v)
def _state_dict(self):
orig = [k for k in Config.__dict__.keys() if not k.startswith("_")]
keys = list(set(self.__dict__.keys()).union(set(orig)))
return {k: getattr(self, k) for k in keys}
@property
def device(self):
device = t.device(f"cuda:{self._gpu_idx}" \
if t.cuda.is_available() and self.use_cuda else "cpu")
return device
@property
def device_count(self):
return t.cuda.device_count()
@device_count.setter
def device_count(self, idx):
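        # intentional no-op: presumably lets a device_count key in the yaml pass
        # through setattr without clashing with the read-only property above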
pass
@property
def gpu_idx(self):
return self._gpu_idx
@gpu_idx.setter
def gpu_idx(self, idx):
self._gpu_idx = idx
if self.use_cuda:
t.cuda.set_device(idx)
def __str__(self):
orig = [k for k in Config.__dict__.keys() if not k.startswith("_")]
keys = list(set(self.__dict__.keys()).union(set(orig)))
keys.sort()
s = ""
for k in keys:
if callable(getattr(self, k)) or k.startswith("_"): continue
s += f"{k}:{getattr(self, k)}\n"
return s
#%% To run distributed training across multiple GPUs, use this class, e.g.:
# CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
class Config_mgpus(Config):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
try:
t.distributed.init_process_group(backend="nccl")
t.cuda.set_device(self.gpu_idx)
except Exception as e:
traceback.print_exc()
if self.gpu_idx!=0:
self.use_tsboard = False
self._lr = self.lr / self.device_count
self.is_parallel = True
@property
def device(self):
return t.device("cuda", self.gpu_idx)
@property
def gpu_idx(self):
return t.distributed.get_rank()
@property
def use_cuda(self):
return True
@gpu_idx.setter
def gpu_idx(self, idx):
        print('multi-GPU config does not support the gpu_idx setter')
@property
def lr(self):
return self._lr * self.device_count
@lr.setter
def lr(self, lr):
self._lr = lr
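#%% Minimal smoke test (illustrative sketch, not part of the original API).
# It assumes default.yaml sits next to this file and defines use_cuda, gpu_idx
# and seed as in the shipped defaults. Run the file directly for a single-GPU
# check, or through the launcher shown above Config_mgpus for the multi-GPU path.
if __name__ == '__main__':
    cfg = load_config()           # picks Config or Config_mgpus automatically
    print('device:', cfg.device)
    print('gpu count:', cfg.device_count)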