
ArriettyTrader/stock_robot

save_model.py 1.54 KB
邹吉华 committed on 2023-04-12 16:27 . 1.6
import os
import threading

import numpy as np
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

from stock_env import StockEnv


class SaveModelCallback(BaseCallback):
    """
    Callback that evaluates the model every ``check_freq`` steps and saves it
    whenever the mean evaluation reward matches or beats the best seen so far
    (in practice, we recommend using ``EvalCallback``).

    :param check_freq: (int) Number of callback calls between evaluations.
    :param path: (str) Folder in which the best model is saved as ``best_model``.
    :param verbose: (int) Verbosity level.
    """

    def __init__(self, check_freq, path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = os.path.join(path, 'best_model')
        self.best_mean_reward = -np.inf
        # Separate environment used only for evaluation
        self.env = Monitor(StockEnv([2022]))
        self.lock = threading.RLock()

    def _init_callback(self) -> None:
        # Create the parent folder if needed; model.save() writes
        # best_model.zip into it, so best_model itself is not a directory
        save_dir = os.path.dirname(self.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            mean_reward, _ = evaluate_policy(self.model, self.env)
            # Hold the lock while comparing and saving; the with-block
            # releases it even if save() raises
            with self.lock:
                if mean_reward >= self.best_mean_reward:
                    self.best_mean_reward = mean_reward
                    self.model.save(self.save_path)
        # Returning True tells the training loop to continue
        return True
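
The keep-best logic in `_on_step` can be tried without Stable-Baselines3 or the stock environment. What follows is a minimal standard-library sketch, not the repository's code: `BestScoreTracker`, `evaluate`, and `save` are hypothetical stand-ins for the callback, `evaluate_policy`, and `model.save`, and the scores are made up for illustration.

```python
import threading
from typing import Callable, List


class BestScoreTracker:
    """Hypothetical stand-in for SaveModelCallback that needs no SB3."""

    def __init__(self, check_freq: int,
                 evaluate: Callable[[], float],
                 save: Callable[[float], None]) -> None:
        self.check_freq = check_freq
        self.evaluate = evaluate          # stand-in for evaluate_policy
        self.save = save                  # stand-in for model.save
        self.best_score = float("-inf")
        self.n_calls = 0
        self.lock = threading.RLock()

    def on_step(self) -> bool:
        self.n_calls += 1
        # Evaluate only on every check_freq-th call
        if self.n_calls % self.check_freq == 0:
            score = self.evaluate()
            # Compare and save under the lock so the update is atomic
            with self.lock:
                if score >= self.best_score:
                    self.best_score = score
                    self.save(score)
        return True


# Made-up evaluation scores, consumed one per evaluation
scores = iter([1.0, 3.0, 2.0, 3.0])
saved: List[float] = []

tracker = BestScoreTracker(check_freq=2,
                           evaluate=lambda: next(scores),
                           save=saved.append)
for _ in range(8):
    tracker.on_step()

print(saved)  # [1.0, 3.0, 3.0]
```

With `check_freq=2`, eight steps trigger four evaluations. The score 2.0 is skipped because it is below the best so far, while the second 3.0 is saved again because the comparison is `>=`, matching the tie behavior of the original callback.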
https://gitee.com/ArriettyTrader/stock_robot.git