1 Star 0 Fork 24

happyhzq/future_agent

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
optimize_a2c.py 2.93 KB
一键复制 编辑 原始数据 按行查看 历史
邹吉华 提交于 2023-04-04 17:02 . 1.2.4
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import A2C
from training_env import TrainingEnv
from stable_baselines3.common.vec_env import SubprocVecEnv
import torch as th
import optuna
from stable_baselines3.common.evaluation import evaluate_policy
# Log directory (not referenced elsewhere in this script).
TB_LOG_PATH = "../tb_log"

# Directory used for the Monitor logs and the per-trial saved models.
# NOTE(review): the path says "sac" although this script trains A2C — confirm.
MODEL_PATH = "./model/sac"

# total_timesteps handed to model.learn() in every Optuna trial.
LEARN_TIMES = 50_000

# Dates handed to every TrainingEnv used for training
# (presumably the days episodes are drawn from — confirm in training_env).
# NOTE(review): "2022-08-27" is listed twice; confirm whether the duplicate
# is intentional.
TRAINING_BEGIN_TIME = [
    "2022-08-14",
    "2022-08-15",
    "2022-08-18",
    "2022-08-19",
    "2022-08-20",
    "2022-08-21",
    "2022-08-22",
    "2022-08-25",
    "2022-08-26",
    "2022-08-27",
    "2022-08-27",
]

# Held-out date(s), presumably meant for evaluation.
# NOTE(review): not referenced anywhere in this script; evaluation currently
# runs on the training environments.
EVALUATE_BEGIN_TIME = ["2022-08-29"]
def make_env(rank, seed=0):
    """
    Build a factory that creates one monitored training environment.

    Meant to be handed to ``SubprocVecEnv``: the returned zero-argument
    callable is what constructs the environment.

    :param rank: (int) index of the subprocess; added to ``seed`` so each
        worker is seeded differently, and used to give each worker its own
        Monitor log file.
    :param seed: (int) the initial seed for the RNG
    :return: (callable) a function that returns the Monitor-wrapped env
    """
    import os

    # One Monitor file per worker ("<MODEL_PATH>/<rank>.monitor.csv"), the
    # same layout stable_baselines3's make_vec_env uses. Passing MODEL_PATH
    # itself made every subprocess open and truncate the same CSV, so the
    # workers overwrote each other's episode logs. The directory is created
    # up front because the Monitor log file is opened before any model save
    # has had a chance to create it.
    os.makedirs(MODEL_PATH, exist_ok=True)
    monitor_file = os.path.join(MODEL_PATH, str(rank))

    def _init():
        env = Monitor(TrainingEnv(TRAINING_BEGIN_TIME), monitor_file)
        # NOTE(review): env.seed() exists in gym < 0.26; newer gym/gymnasium
        # seed through reset(seed=...) — confirm against the installed version.
        env.seed(seed + rank)
        return env

    set_random_seed(seed)
    return _init
def optimize_ppo(trial):
    """
    Sample one set of A2C hyperparameters from an Optuna trial.

    Despite its name (kept so existing callers keep working), the result is
    used as keyword arguments for ``A2C`` in this script.

    :param trial: (optuna.trial.Trial) the trial to draw suggestions from
    :return: (dict) ``n_steps``, ``gamma``, ``learning_rate``, ``gae_lambda``
        and ``policy_kwargs`` (ReLU activation plus a sampled ``net_arch``)
    """
    # Widths of the two value-function-head layers.
    na_vf = [
        trial.suggest_int('na_vf1', 1024, 8192),
        trial.suggest_int('na_vf2', 1024, 8192),
    ]
    # Widths of the two policy-head layers.
    na_pi = [
        trial.suggest_int('na_pi1', 512, 4096),
        trial.suggest_int('na_pi2', 512, 4096),
    ]
    # Eight leading layers (na_p1..na_p8), suggested in that order.
    # NOTE(review): leading ints followed by a dict(vf=..., pi=...) is the
    # SB3 1.x "shared layers" net_arch format, which newer SB3 releases
    # deprecate/remove — confirm against the installed version.
    shared_layers = [trial.suggest_int(f'na_p{i}', 128, 1024) for i in range(1, 9)]
    policy = dict(
        activation_fn=th.nn.ReLU,
        net_arch=shared_layers + [dict(vf=na_vf, pi=na_pi)],
    )
    return {
        # NOTE(review): this range looks copied from a PPO search space. With
        # 8 envs and LEARN_TIMES = 50000, a rollout of >= 2048 steps per env
        # leaves only a handful of A2C updates per trial — confirm intent.
        'n_steps': trial.suggest_int('n_steps', 2048, 8192),
        # suggest_float(..., log=True) replaces suggest_loguniform and
        # suggest_float(...) replaces suggest_uniform; both old helpers are
        # deprecated since Optuna 3.0 and sample the same distributions.
        'gamma': trial.suggest_float('gamma', 0.8, 0.99, log=True),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-4, log=True),
        'gae_lambda': trial.suggest_float('gae_lambda', 0.8, 0.99),
        'policy_kwargs': policy,
    }
def optimize_agent(trial):
    """
    Optuna objective: train A2C with sampled hyperparameters and score it.

    Creates ``num_cpu`` environments in a ``SubprocVecEnv``, trains for
    ``LEARN_TIMES`` timesteps, saves the model as ``<MODEL_PATH>/trial_<n>``
    and returns the mean reward of 2 evaluation episodes.

    :param trial: (optuna.trial.Trial) the trial supplying hyperparameters
    :return: (float) mean evaluation reward, or -10000 if anything failed
    """
    import traceback

    num_cpu = 8
    env = None
    try:
        # Create the vectorized environment.
        env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
        model_params = optimize_ppo(trial)
        model = A2C('MlpPolicy', env, **model_params)
        model.learn(total_timesteps=LEARN_TIMES)
        model.save(MODEL_PATH + '/trial_{}'.format(trial.number))
        # NOTE(review): this evaluates on the same environments used for
        # training (TRAINING_BEGIN_TIME), so the study optimises in-sample
        # reward; EVALUATE_BEGIN_TIME looks intended for this — confirm.
        mean_reward, _ = evaluate_policy(model, env, 2)
        print("mean_reward", mean_reward)
        return mean_reward
    except Exception:
        # Deliberate best effort: a failing configuration scores a very low
        # value instead of raising, so study.optimize keeps running. Print
        # the full traceback (not just the message) so failures are debuggable.
        traceback.print_exc()
        return -10000
    finally:
        # Always shut down the worker processes. Previously the vec env was
        # never closed, so every trial left num_cpu subprocesses behind.
        if env is not None:
            try:
                env.close()
            except Exception:
                # A crashed worker can make close() fail; do not let cleanup
                # turn a scored trial into an exception.
                traceback.print_exc()
if __name__ == "__main__":
    # Run 50 trials of the objective, keeping the configuration with the
    # highest returned value, then print its hyperparameters.
    tuning_study = optuna.create_study(direction="maximize")
    tuning_study.optimize(optimize_agent, n_trials=50)
    best_params = tuning_study.best_params
    print(best_params)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/happyhzq/future_agent.git
git@gitee.com:happyhzq/future_agent.git
happyhzq
future_agent
future_agent
master

搜索帮助