1 Star 0 Fork 14

张九经/stock_robot

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
runtime.py 3.00 KB
一键复制 编辑 原始数据 按行查看 历史
邹吉华 提交于 2023-04-12 16:27 . 1.6
import numpy as np
import cvt_helper as cvt
from stock_info import StockSceneInfo
from stock_info import StockPredictInfo
import const as cst
from data_center import DataCenter
from stable_baselines3 import SAC
class StockRuntime:
    """Run a trained SAC policy over stock histories and produce trade predictions.

    ``update`` predicts a list of stock codes and persists the results through
    ``DataCenter.save_predict_info``; ``recommend`` predicts a group of stocks
    and picks the single best candidate for a given date.
    """

    # Defaults that used to be hard-coded inside the methods. They live on the
    # class so that instances created without ``__init__`` still see them.
    model_path = "./model/best_model_3"
    # NOTE(review): passed as the date argument of the DataCenter queries;
    # presumably the start of the history window -- confirm against DataCenter.
    query_date = "2022-05-15"

    def __init__(self, model_path=None, query_date=None):
        """Load the SAC model and create the data center.

        Args:
            model_path: path handed to ``SAC.load``. Defaults to
                ``"./model/best_model_3"``.
            query_date: date string passed to ``DataCenter.query_stock_info``
                and ``DataCenter.query_concern_stock``. Defaults to
                ``"2022-05-15"``.
        """
        if model_path is not None:
            self.model_path = model_path
        if query_date is not None:
            self.query_date = query_date
        self.model = SAC.load(self.model_path)
        self.data_center = DataCenter()

    def update(self, all_stock, chunks=30):
        """Predict every stock in ``all_stock`` and persist the predictions.

        The codes are split into ``chunks`` roughly equal batches with
        ``np.array_split``. After each batch, its predictions are saved per
        date via ``DataCenter.save_predict_info``.

        Args:
            all_stock: sequence of stock codes (e.g. ``"sh.600522"``).
            chunks: number of batches to split ``all_stock`` into; batches may
                be empty when there are fewer codes than chunks.

        Returns:
            False if ``all_stock`` is empty, True otherwise.
        """
        # len() instead of truthiness: numpy arrays raise on ``not arr``.
        if len(all_stock) == 0:
            return False
        for batch in np.array_split(all_stock, chunks):
            for date, predictions in self._run_predict(batch).items():
                self.data_center.save_predict_info(date, predictions)
                print(f"save_predict_info:{len(predictions)}")
        return True

    def recommend(self, date, min_profit=2.2, index_name="hs300"):
        """Return the prediction with the highest ``ExpectedProfits`` for ``date``.

        Args:
            date: key looked up in the prediction results, i.e. compared with
                the ``Date`` of each stock's latest record.
            min_profit: predictions whose ``ExpectedProfits`` is below this
                value are ignored.
            index_name: group name passed to
                ``DataCenter.query_concern_stock`` (presumably an index
                constituent list such as CSI 300 -- confirm).

        Returns:
            None when there is no prediction for ``date``. Otherwise the best
            qualifying prediction; on ties the later one wins.
            NOTE(review): if nothing reaches ``min_profit``, the
            default-constructed ``StockPredictInfo()`` is returned rather than
            None -- callers must check for that sentinel.
        """
        best_stock: StockPredictInfo = StockPredictInfo()
        concern_stock = self.data_center.query_concern_stock(self.query_date, index_name)
        result = self._run_predict(concern_stock)
        if date not in result:
            return None
        for value in result[date]:
            if value.ExpectedProfits < min_profit:
                continue
            print(f"recommend -> code:{value.StockCode}")
            # >= so that, among equal profits, the last candidate is kept.
            if value.ExpectedProfits >= best_stock.ExpectedProfits:
                best_stock = value
        return best_stock

    def _run_predict(self, all_stock):
        """Run the model on each stock and group the predictions by date.

        Stocks for which ``_get_frame_data`` returns ``(None, None)`` (not
        enough history) are skipped.

        Returns:
            dict mapping the latest record's ``Date`` -> list of
            ``StockPredictInfo`` in input order.
        """
        result = {}
        for stock in all_stock:
            scene, last_info = self._get_frame_data(stock)
            if scene is None or last_info is None:
                continue
            obs = cvt.get_obs(scene)
            action, _ = self.model.predict(obs, deterministic=True)
            # Prices are derived relative to the latest close.
            buy_price, sell_price = cvt.parse_action(last_info.Close, action)
            predict_info = StockPredictInfo(stock, buy_price, sell_price)
            result.setdefault(last_info.Date, []).append(predict_info)
        return result

    def _get_frame_data(self, code):
        """Fetch the newest ``cst.HISTORY_DATA_COUNT`` records for ``code``.

        Returns:
            ``(scene_data, last_info)``: the last ``HISTORY_DATA_COUNT``
            records in their stored order, and the final record. Returns
            ``(None, None)`` when fewer records are available.
        """
        current_stock_data = self.data_center.query_stock_info(code, self.query_date)
        count = cst.HISTORY_DATA_COUNT
        total = len(current_stock_data)
        if total < count:
            return None, None
        # Positional indexing (not slicing) keeps the original access pattern
        # for whatever container query_stock_info returns.
        scene_data = [current_stock_data[i] for i in range(total - count, total)]
        last_info: StockSceneInfo = current_stock_data[total - 1]
        return scene_data, last_info
if __name__ == "__main__":
    # Smoke run: predict a small hand-picked list of stock codes and persist
    # the predictions through StockRuntime.update().
    stock_runtime = StockRuntime()
    portfolio = [
        "sh.600522",
        "sh.601615",
        "sh.600256",
        "sh.600157",
        "sz.002407",
        "sz.000009",
        "sz.002180",
        "sz.002384",
    ]
    stock_runtime.update(portfolio)
    # Disabled example of asking for a single recommendation. Note that
    # recommend() returns None when there are no predictions for the date.
    # recommend = stock_runtime.recommend('2022-06-23')
    # print(f"recommend -> code:{recommend.StockCode}, buy:{recommend.BuyPrice}, sell:{recommend.SellPrice}, expected:{recommend.ExpectedProfits}%")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/StepZ/stock_robot.git
git@gitee.com:StepZ/stock_robot.git
StepZ
stock_robot
stock_robot
master

搜索帮助