代码拉取完成,页面将自动刷新
同步操作将从 lightning-trader/stock_robot 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import pandas as pd
import cvt_helper as cvt
from stock_holder import StockHolder
from stock_info import StockSceneInfo
import numpy as np
import const as cst
from data_center import DataCenter
from stable_baselines3 import SAC
import const as cst
import math
class StockPortfolioTest():
def __init__(self,model_path):
self.current_step = cst.HISTORY_DATA_COUNT
self.model = SAC.load(model_path)
self.data_center = DataCenter()
def run(self,start_date,zs_type):
self.protfolie_stock = self.data_center.query_concern_stock(start_date,zs_type)
#中证500 10大成分股
#self.protfolie_stock = ["sh.600522","sh.601615","sh.600256","sh.600157","sz.002407","sz.000009","sz.002180","sz.002384"]
#沪深300 10大成分股
#self.protfolie_stock = ["sh.600519","sh.601318","sh.600036","sh.601012","sz.000858","sz.002594","sz.000333","sh.600900","sh.601166"]
#上证 10大成分股
#self.protfolie_stock = ["sh.600519","sh.601398","sh.601288","sh.601857","sh.600036","sh.601988","sh.601628","sh.600900","sh.601088","sh.601012"]
if self.protfolie_stock is None :
print('run error : protfolie_stock is none')
return
result = dict()
stock_info = self.data_center.get_protfolie_info(self.protfolie_stock,start_date)
for key,value in stock_info.items () :
humbly = 0
self.current_step = 0
while not self._is_done(value):
humbly += self._step(value)
result[key] = humbly/len(value)
return result
def _step(self,stock_info):
frame_stock_data,last_info = self._get_frame_data(stock_info)
obs = cvt.get_obs(frame_stock_data)
action, _ = self.model.predict(obs,deterministic=True)
buy_price,sell_price = cvt.parse_action(last_info.Standard,action)
res = math.fabs(buy_price-last_info.Low)+math.fabs(sell_price-last_info.High)
self.current_step += 1
return res
def _is_done(self,stock_info):
max_step = len(stock_info)
return self.current_step>=max_step
def _get_frame_data(self,current_stock_data):
scene_info = []
for i in range(cst.HISTORY_DATA_COUNT):
scene : StockSceneInfo = current_stock_data[self.current_step-cst.HISTORY_DATA_COUNT+i]
scene_info.append(scene)
last_info:StockSceneInfo = current_stock_data[self.current_step]
return scene_info,last_info
if __name__ == "__main__":
model_name = "best_model_all_128c_2000_4"
test = StockPortfolioTest("./model/"+model_name)
result = test.run("2022-01-01","zz500")
result = sorted(result.items(),key=lambda x:x[1], reverse=False)
count = 10
if count > len(result):
count = len(result)
for i in range(count) :
print(f"{result[i][0]} {result[i][1]}")
#print(result)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。