Ky1eYang/ezdlearn

This repository has not declared an open-source license file (LICENSE); before using it, check the project description and the upstream dependencies of its code.
config.py 4.42 KB
kyle committed on 2023-05-22 03:00: upgrade
# -*- coding: utf-8 -*-
import torch as t
import os, sys
import yaml
import traceback
#%%
def load_yaml(path):
    with open(path, 'r', encoding='utf-8') as fp:
        result = yaml.load(fp.read(), yaml.FullLoader)
    return result

default_config_path = os.path.join(os.path.dirname(__file__), 'default.yaml')

def load_config(config_path=default_config_path, run_env=True):
    # NCCL_ASYNC_ERROR_HANDLING in the environment is taken as a sign of a
    # distributed launch, so the multi-GPU config class is used in that case
    if 'NCCL_ASYNC_ERROR_HANDLING' in os.environ:
        cfg = Config_mgpus(config_path, run_env)
    else:
        cfg = Config(config_path, run_env)
    print(cfg)
    return cfg

class Config():
    def __init__(self, config_path=None, run_env=True):
        if config_path is None:
            return
        # start from default.yaml, then overlay the custom config on top of it
        configs = load_yaml(default_config_path)
        if config_path != default_config_path:
            custom_cfgs = load_yaml(config_path)
            for k, v in configs.items():
                if isinstance(v, dict):
                    v.update(custom_cfgs.get(k, {}))
                elif isinstance(v, list):
                    v.extend(custom_cfgs.get(k, []))
        for path in configs['ExtraPath']:
            sys.path.append(path)
        if run_env:
            for k, v in configs['ExtraEnv'].items():
                os.environ[k] = str(v)
        self.load_args_from_dict(configs['DatasetConfig'])
        self.load_args_from_dict(configs['ModelConfig'])
        self.load_args_from_dict(configs['TrainerConfig'])
        self.load_args_from_dict(configs['BaseConfig'])
        for k, v in configs['ImportConfig'].items():
            if k in ['DatasetConfig', 'ModelConfig', 'TrainerConfig'] and v:
                tmp = load_yaml(v)
                self.load_args_from_dict(tmp)
        if hasattr(self, "seed") and self.seed:
            t.manual_seed(self.seed)
    def load_args_from_dict(self, args):
        for k, v in args.items():
            setattr(self, k, v)
        return args

    def load_args(self, args):
        self.load_args_from_dict(vars(args))

    def _parse(self, **kwargs):
        state_dict = self._state_dict()
        for k, v in kwargs.items():
            if self.print_newopt and k not in state_dict:
                print('add new option: "%s = %s"' % (k, v))
            setattr(self, k, v)

    def _state_dict(self):
        orig = [k for k in Config.__dict__.keys() if not k.startswith("_")]
        keys = list(set(self.__dict__.keys()).union(set(orig)))
        return {k: getattr(self, k) for k in keys}

    @property
    def device(self):
        device = t.device(f"cuda:{self._gpu_idx}"
                          if t.cuda.is_available() and self.use_cuda else "cpu")
        return device

    @property
    def device_count(self):
        return t.cuda.device_count()

    @device_count.setter
    def device_count(self, idx):
        pass

    @property
    def gpu_idx(self):
        return self._gpu_idx

    @gpu_idx.setter
    def gpu_idx(self, idx):
        self._gpu_idx = idx
        if self.use_cuda:
            t.cuda.set_device(idx)

    def __str__(self):
        orig = [k for k in Config.__dict__.keys() if not k.startswith("_")]
        keys = list(set(self.__dict__.keys()).union(set(orig)))
        keys.sort()
        s = ""
        for k in keys:
            if callable(getattr(self, k)) or k.startswith("_"):
                continue
            s += f"{k}:{getattr(self, k)}\n"
        return s
#%% To use multiple GPUs for distributed training, use this class,
# launched e.g. with:
# CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
class Config_mgpus(Config):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        try:
            t.distributed.init_process_group(backend="nccl")
            t.cuda.set_device(self.gpu_idx)
        except Exception:
            traceback.print_exc()
        if self.gpu_idx != 0:
            # only rank 0 keeps use_tsboard enabled
            self.use_tsboard = False
        self._lr = self.lr / self.device_count
        self.is_parallel = True

    @property
    def device(self):
        return t.device("cuda", self.gpu_idx)

    @property
    def gpu_idx(self):
        # the GPU index is the process rank assigned by torch.distributed
        return t.distributed.get_rank()

    @property
    def use_cuda(self):
        return True

    @gpu_idx.setter
    def gpu_idx(self, idx):
        print('multi-GPU config does not support the gpu_idx setter')

    @property
    def lr(self):
        # _lr stores the per-device rate; the exposed lr is scaled by the device count
        return self._lr * self.device_count

    @lr.setter
    def lr(self, lr):
        self._lr = lr
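
For orientation, a minimal usage sketch follows. It is an illustration rather than code from the repository: it assumes a default.yaml next to config.py that defines the top-level keys the loader reads (ExtraPath, ExtraEnv, DatasetConfig, ModelConfig, TrainerConfig, BaseConfig, ImportConfig) and the attributes the properties rely on (use_cuda, gpu_idx), and my_experiment.yaml is a hypothetical override file with the same structure.

# Usage sketch (illustrative only; my_experiment.yaml is hypothetical).
from config import load_config

cfg = load_config('my_experiment.yaml')  # returns Config, or Config_mgpus under a distributed launch
print(cfg.device)        # torch.device built from use_cuda and gpu_idx in the merged YAML
print(cfg.device_count)  # number of visible CUDA devices

# cfg.gpu_idx = 1        # in the single-process case, switches the active CUDA device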