代码拉取完成,页面将自动刷新
同步操作将从 lightning-trader/future_agent 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import SAC
from training_env import TrainingEnv
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.noise import NormalActionNoise
import torch as th
import numpy as np
import optuna
from stable_baselines3.common.evaluation import evaluate_policy
# Presumably meant as the TensorBoard log directory; not passed to SAC in this script.
TB_LOG_PATH = "../tb_log"
# Directory used for Monitor logs and for the per-trial saved models.
MODEL_PATH = "./model/sac"
# Number of environment steps given to model.learn() in every trial.
LEARN_TIMES = 100_000

# Dates handed to TrainingEnv when building the training environments
# (presumably session start dates -- confirm against training_env.TrainingEnv).
TRAINING_BEGIN_TIME = [
    "2022-08-14", "2022-08-15",
    "2022-08-18", "2022-08-19", "2022-08-20", "2022-08-21", "2022-08-22",
    "2022-08-25", "2022-08-26", "2022-08-27", "2022-08-28",
    "2022-08-29",
    "2022-09-01", "2022-09-02", "2022-09-03", "2022-09-04",
]

# Dates intended for evaluation.
# NOTE(review): "2022-08-29" is also listed in TRAINING_BEGIN_TIME, so it is not held out.
EVALUATE_BEGIN_TIME = ["2022-08-29"]
def make_env(rank, seed=0):
    """
    Build a factory that creates one Monitor-wrapped TrainingEnv for a vectorized-env worker.

    :param rank: (int) index of the subprocess; offsets the env seed and selects
        this worker's own Monitor log file
    :param seed: (int) the initial seed for RNG
    :return: (callable) zero-argument function that constructs and seeds the environment
    """
    def _init():
        import os

        # Give every worker its own "<rank>.monitor.csv" under MODEL_PATH. Passing the
        # same filename to all workers makes their Monitors overwrite and interleave
        # each other's rows in a single shared file.
        os.makedirs(MODEL_PATH, exist_ok=True)
        env = Monitor(TrainingEnv(TRAINING_BEGIN_TIME), os.path.join(MODEL_PATH, str(rank)))
        env.seed(seed + rank)
        return env

    set_random_seed(seed)
    return _init
def optimize_params(trial, actions):
    """
    Sample one set of SAC constructor keyword arguments from an Optuna trial.

    :param trial: (optuna.trial.Trial) trial used to suggest hyperparameters
    :param actions: (int) length of the action vector; sets the dimension of the
        Gaussian exploration noise
    :return: (dict) keyword arguments to pass to ``SAC(...)``
    """
    # Eight hidden layers with independently sampled widths. The names
    # "na_p1".."na_p8" and their suggestion order are unchanged.
    net_arch = [trial.suggest_int('na_p{}'.format(i), 128, 1024) for i in range(1, 9)]
    policy = dict(
        activation_fn=th.nn.ReLU,
        net_arch=net_arch,
    )
    # suggest_uniform / suggest_loguniform are deprecated since Optuna 3.0;
    # suggest_float (log=True for log-scaled ranges) samples the same distributions.
    sigma = trial.suggest_float('sigma', 0.01, 0.2)
    return {
        'gamma': trial.suggest_float('gamma', 0.8, 0.99, log=True),
        'batch_size': trial.suggest_categorical("batch_size", [16, 32, 64, 128, 256, 512, 1024, 2048]),
        'buffer_size': trial.suggest_categorical("buffer_size", [int(500000), int(1000000), int(2000000)]),
        'learning_starts': trial.suggest_categorical("learning_starts", [1, 10, 100, 200, 2000]),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-4, log=True),
        'tau': trial.suggest_categorical("tau", [0.001, 0.005, 0.01, 0.02, 0.05, 0.08, 0.1, 0.2]),
        # Zero-mean Gaussian noise with the same sampled sigma on every action dimension.
        'action_noise': NormalActionNoise(mean=np.zeros(actions), sigma=sigma * np.ones(actions)),
        'train_freq': trial.suggest_categorical("train_freq", [1, 4, 8, 16]),
        'policy_kwargs': policy,
    }
def optimize_agent(trial):
    """
    Optuna objective: train SAC with sampled hyperparameters and score the result.

    :param trial: (optuna.trial.Trial) the trial being evaluated
    :return: (float) mean reward reported by ``evaluate_policy`` over 2 episodes,
        or -10000 if any step raises (a fixed penalty so the study keeps running)
    """
    import traceback

    num_cpu = 32
    env = None
    try:
        # Create the vectorized environment
        env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
        model_params = optimize_params(trial, env.action_space.shape[-1])
        model = SAC('MlpPolicy', env, **model_params)
        model.learn(total_timesteps=LEARN_TIMES)
        model.save(MODEL_PATH + '/trial_{}'.format(trial.number))
        # NOTE(review): the policy is evaluated on the training env itself;
        # EVALUATE_BEGIN_TIME is not used here.
        mean_reward, _ = evaluate_policy(model, env, 2)
        print("mean_reward", mean_reward)
        return mean_reward
    except Exception:
        # Keep the study going, but report the full traceback rather than only str(e).
        traceback.print_exc()
        return -10000
    finally:
        # Always shut down the worker processes, otherwise every trial leaks them.
        # Guarded so a failure while closing cannot override the value returned above.
        if env is not None:
            try:
                env.close()
            except Exception:
                traceback.print_exc()
if __name__ == '__main__':
    # Search for the hyperparameters that maximize optimize_agent's score
    # over 100 trials, then print the best combination found.
    search = optuna.create_study(direction='maximize')
    search.optimize(optimize_agent, n_trials=100)
    best_found = search.best_params
    print(best_found)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。