import os
import threading

import numpy as np
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

from stock_env import StockEnv


class SaveModelCallback(BaseCallback):
    """
    Callback that saves the model whenever its mean evaluation reward
    reaches a new best. The check is done every ``check_freq`` steps
    (in practice, we recommend using ``EvalCallback`` instead).

    :param check_freq: (int) number of training steps between evaluations.
    :param path: (str) path to the folder where the best model will be saved.
    :param verbose: (int) verbosity level.
    """

    def __init__(self, check_freq, path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = os.path.join(path, 'best_model')
        self.best_mean_reward = -np.inf
        # Dedicated evaluation environment, wrapped in Monitor for
        # episode statistics.
        self.env = Monitor(StockEnv([2022]))
        self.lock = threading.RLock()

    def _init_callback(self) -> None:
        # Create the save folder if needed.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            mean_reward, _ = evaluate_policy(self.model, self.env)
            # ``with`` releases the lock even if ``model.save`` raises,
            # unlike a bare acquire()/release() pair.
            with self.lock:
                if mean_reward >= self.best_mean_reward:
                    self.best_mean_reward = mean_reward
                    self.model.save(self.save_path)
        return True
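The thread-safe "save only on a new best reward" check inside `_on_step` can be sketched in isolation, without the stable-baselines3 or `StockEnv` dependencies. `BestTracker` below is a hypothetical helper written only for illustration; it mirrors the callback's `>=` comparison and `RLock` usage:

```python
import threading


class BestTracker:
    """Minimal stand-in for the callback's best-reward bookkeeping."""

    def __init__(self):
        self.best = float("-inf")
        self.lock = threading.RLock()

    def update(self, reward):
        # Returns True when ``reward`` ties or beats the best so far --
        # the condition under which the callback would save the model.
        with self.lock:
            if reward >= self.best:
                self.best = reward
                return True
            return False


tracker = BestTracker()
print(tracker.update(1.0))  # True: the first reward is always a new best
print(tracker.update(0.5))  # False: lower reward, no save
print(tracker.update(1.0))  # True: ties count as improvements (>= comparison)
```

Note that because the comparison is `>=` rather than `>`, the model is re-saved on ties; switching to `>` would skip those redundant saves.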