1 Star 0 Fork 16

张九经/stock_robot

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test_with_choose.py 2.91 KB
一键复制 编辑 原始数据 按行查看 历史
邹吉华 提交于 2023-04-12 16:27 +08:00 . 1.6
import pandas as pd
import cvt_helper as cvt
from stock_holder import StockHolder
from stock_info import StockSceneInfo
import numpy as np
import const as cst
from data_center import DataCenter
from stable_baselines3 import SAC
import const as cst
import math
class StockPortfolioTest():
    """Evaluate a trained SAC model's buy/sell price predictions per stock.

    For every stock of a portfolio the model is stepped through that stock's
    scene sequence. At each step it predicts a buy price and a sell price; the
    absolute deviations from the scene's ``Low`` and ``High`` are summed.
    ``run`` reports the mean of that error per evaluated step.
    """

    def __init__(self, model_path):
        """Load the SAC model from *model_path* and create a DataCenter.

        Args:
            model_path: Path handed to ``SAC.load``.
        """
        # Index of the scene being predicted. It must never be smaller than
        # cst.HISTORY_DATA_COUNT, because _get_frame_data reads that many
        # preceding scenes as the observation window.
        self.current_step = cst.HISTORY_DATA_COUNT
        self.model = SAC.load(model_path)
        self.data_center = DataCenter()

    def run(self, start_date, zs_type):
        """Evaluate the model on the concern stocks of *zs_type* from *start_date*.

        Args:
            start_date: Date string forwarded to the DataCenter queries
                (``"2022-01-01"`` in ``__main__``).
            zs_type: Index identifier forwarded to ``query_concern_stock``
                (``"zz500"`` in ``__main__``).

        Returns:
            A dict mapping each key returned by ``get_protfolie_info``
            (presumably a stock code -- confirm against DataCenter) to the mean
            per-step prediction error, or ``None`` when no portfolio could be
            queried. Stocks with at most ``cst.HISTORY_DATA_COUNT`` scenes have
            no evaluable step and are skipped.
        """
        self.protfolie_stock = self.data_center.query_concern_stock(start_date, zs_type)
        # Hard-coded alternatives, kept for manual experiments:
        # CSI 500, top-10 constituents
        # self.protfolie_stock = ["sh.600522","sh.601615","sh.600256","sh.600157","sz.002407","sz.000009","sz.002180","sz.002384"]
        # CSI 300, top-10 constituents
        # self.protfolie_stock = ["sh.600519","sh.601318","sh.600036","sh.601012","sz.000858","sz.002594","sz.000333","sh.600900","sh.601166"]
        # SSE, top-10 constituents
        # self.protfolie_stock = ["sh.600519","sh.601398","sh.601288","sh.601857","sh.600036","sh.601988","sh.601628","sh.600900","sh.601088","sh.601012"]
        if self.protfolie_stock is None:
            print('run error : protfolie_stock is none')
            return None

        history = cst.HISTORY_DATA_COUNT
        result = dict()
        stock_info = self.data_center.get_protfolie_info(self.protfolie_stock, start_date)
        for key, value in stock_info.items():
            # Start after the first `history` scenes so every observation window
            # consists of real past data. Starting at 0 made _get_frame_data use
            # negative indices, which wrap to the END of the sequence and feed
            # future scenes into the observation.
            self.current_step = history
            total_error = 0
            steps = 0
            while not self._is_done(value):
                total_error += self._step(value)
                steps += 1
            if steps == 0:
                # Averaging over zero steps is undefined; skip instead of
                # raising ZeroDivisionError.
                print(f'run warning : {key} has only {len(value)} scenes, skipped')
                continue
            result[key] = total_error / steps
        return result

    def _step(self, stock_info):
        """Predict prices for the current scene, return the error, and advance.

        Returns:
            ``|buy_price - Low| + |sell_price - High|`` for the scene at
            ``self.current_step``. Increments ``self.current_step`` by one.
        """
        frame_stock_data, last_info = self._get_frame_data(stock_info)
        obs = cvt.get_obs(frame_stock_data)
        action, _ = self.model.predict(obs, deterministic=True)
        # NOTE(review): parse_action presumably converts the action into
        # absolute prices relative to the scene's Standard price -- confirm.
        buy_price, sell_price = cvt.parse_action(last_info.Standard, action)
        res = math.fabs(buy_price - last_info.Low) + math.fabs(sell_price - last_info.High)
        self.current_step += 1
        return res

    def _is_done(self, stock_info):
        """Return True once ``current_step`` has moved past the last scene."""
        return self.current_step >= len(stock_info)

    def _get_frame_data(self, current_stock_data):
        """Return the history window and the current scene.

        Returns:
            ``(scene_info, last_info)``. ``scene_info`` is the list of the
            ``cst.HISTORY_DATA_COUNT`` scenes preceding ``self.current_step``.
            ``last_info`` is the scene at ``self.current_step``.

        Raises:
            IndexError: if fewer than ``cst.HISTORY_DATA_COUNT`` scenes precede
                the current step. Without this check the negative indices
                would silently wrap around to the end of the sequence.
        """
        history = cst.HISTORY_DATA_COUNT
        start = self.current_step - history
        if start < 0:
            raise IndexError(
                f"current_step {self.current_step} is smaller than "
                f"HISTORY_DATA_COUNT {history}")
        scene_info = [current_stock_data[start + i] for i in range(history)]
        last_info: StockSceneInfo = current_stock_data[self.current_step]
        return scene_info, last_info
if __name__ == "__main__":
    # Evaluate a saved model on the "zz500" concern list (presumably CSI 500)
    # starting 2022-01-01, and print the stocks with the lowest mean error.
    model_name = "best_model_all_128c_2000_4"
    top_n = 10
    test = StockPortfolioTest("./model/" + model_name)
    result = test.run("2022-01-01", "zz500")
    if result is None:
        # run() has already printed the reason; there is nothing to rank.
        raise SystemExit(1)
    # Ascending order: a smaller error means closer buy/sell predictions.
    ranked = sorted(result.items(), key=lambda item: item[1])
    # Slicing clamps to len(ranked), so no manual bounds check is needed.
    for code, error in ranked[:top_n]:
        print(f"{code} {error}")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/StepZ/stock_robot.git
[email protected]:StepZ/stock_robot.git
StepZ
stock_robot
stock_robot
master

搜索帮助