代码拉取完成,页面将自动刷新
import os
import pickle
import pandas as pd
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2
from rlenv.StockTradingEnv0 import StockTradingEnv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
# Switch off the Unicode minus sign so negative tick labels render with a plain
# ASCII hyphen (the Unicode glyph may be missing from the chosen font).
plt.rcParams.update({'axes.unicode_minus': False})
# Font loaded from a file path relative to the current working directory; it is
# passed to the legend in test_a_stock_trade. Presumably chosen as a CJK-capable
# font -- confirm. (A system font family could instead be selected with
# plt.rc('font', family=...).)
font = fm.FontProperties(fname='font/wqy-microhei.ttc')
def stock_trade(stock_file, total_timesteps=int(1e4)):
    """Train a PPO2 agent on one stock's training data, then replay it on test data.

    Args:
        stock_file: Path to the training CSV (must contain a ``date`` column).
            The test CSV path is derived with ``stock_file.replace('train', 'test')``,
            which replaces *every* occurrence of ``'train'`` in the path.
        total_timesteps: Number of environment steps passed to ``model.learn``.
            Defaults to 10,000, the value that used to be hard-coded.

    Returns:
        A list holding the value returned by ``env.render()`` after each test
        step (presumably the running profit reported by ``StockTradingEnv`` --
        confirm against the env implementation).
    """
    df = pd.read_csv(stock_file).sort_values('date')
    # The algorithms require a vectorized environment to run
    train_env = DummyVecEnv([lambda: StockTradingEnv(df)])
    model = PPO2(MlpPolicy, train_env, verbose=0, tensorboard_log='./log')
    model.learn(total_timesteps=total_timesteps)

    # Order the test rows the same way as the training rows so the replay
    # walks the test period chronologically.
    df_test = pd.read_csv(stock_file.replace('train', 'test')).sort_values('date')
    test_env = DummyVecEnv([lambda: StockTradingEnv(df_test)])
    obs = test_env.reset()

    day_profits = []
    for _ in range(len(df_test) - 1):
        action, _states = model.predict(obs)
        obs, rewards, done, info = test_env.step(action)
        day_profits.append(test_env.render())
        # A vectorized env returns one done flag per sub-environment; there is
        # exactly one sub-environment here.
        if done[0]:
            break
    return day_profits
def find_file(path, name):
    """Return the path of the first file under *path* whose name contains *name*.

    The tree is walked top-down (files in a directory are checked before its
    subdirectories), and directory and file names are visited in sorted order.
    The result therefore does not depend on the filesystem's listing order when
    several files match.

    Args:
        path: Root directory to search recursively.
        name: Substring to look for in file names (a substring test, not an
            exact-name match).

    Returns:
        ``os.path.join(root, fname)`` for the first match, or ``None`` if no
        file matches (including when *path* does not exist).
    """
    for root, dirs, files in os.walk(path):
        # Sorting dirs in place also controls the order os.walk descends into them.
        dirs.sort()
        for fname in sorted(files):
            if name in fname:
                return os.path.join(root, fname)
    return None
def test_a_stock_trade(stock_code):
    """Train and evaluate one stock, then save its per-step profit curve as a PNG.

    Args:
        stock_code: Used to locate the training CSV under ``./stockdata/train``
            (substring match on the file name). It is also the legend label and
            the output file name ``./img/<stock_code>.png``.

    Raises:
        FileNotFoundError: If no training file matching *stock_code* exists.
    """
    stock_file = find_file('./stockdata/train', str(stock_code))
    if stock_file is None:
        # Fail with a clear message instead of handing None to pd.read_csv.
        raise FileNotFoundError(
            f'no training data matching {stock_code!r} under ./stockdata/train')
    daily_profits = stock_trade(stock_file)

    fig, ax = plt.subplots()
    # The '-o' format string already sets the marker. Passing marker='o' as well
    # makes matplotlib warn about a redundantly defined marker.
    ax.plot(daily_profits, '-o', label=stock_code, ms=10, alpha=0.7, mfc='orange')
    ax.grid()
    ax.set_xlabel('step')
    ax.set_ylabel('profit')
    ax.legend(prop=font)
    # savefig does not create missing directories.
    os.makedirs('./img', exist_ok=True)
    fig.savefig(f'./img/{stock_code}.png')
    # Release the figure; otherwise every call leaves an open figure behind.
    plt.close(fig)
def multi_stock_trade(start_code=600000, max_num=3000):
    """Run ``stock_trade`` for a range of numeric stock codes and pickle the results.

    Args:
        start_code: First numeric code to try (default 600000).
        max_num: Number of consecutive codes to try starting at *start_code*
            (default 3000, i.e. codes 600000..602999).

    Side effects:
        Writes ``code-<start_code>-<start_code + max_num>.pkl`` to the current
        directory. The file holds a list with one ``stock_trade`` result per
        code that had training data and ran without raising, in ascending code
        order. Codes without data are skipped silently. Codes that raise are
        reported on stdout and skipped.

    NOTE(review): the pickled list does not record which code each entry
    belongs to; consider storing (code, profits) pairs if consumers can adapt.
    """
    end_code = start_code + max_num
    group_result = []
    for code in range(start_code, end_code):
        stock_file = find_file('./stockdata/train', str(code))
        if not stock_file:
            continue
        try:
            profits = stock_trade(stock_file)
        except Exception as err:
            # Best-effort batch run: say which code failed and keep going.
            print(f'{code}: {err}')
            continue
        group_result.append(profits)

    with open(f'code-{start_code}-{end_code}.pkl', 'wb') as f:
        pickle.dump(group_result, f)
if __name__ == '__main__':
    # Demo run on a single stock. For the batch run over a range of codes,
    # call multi_stock_trade() here instead.
    test_a_stock_trade('sh.600036')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。