1 Star 0 Fork 13

ArriettyTrader/stock_robot

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test_choose.py 9.96 KB
一键复制 编辑 原始数据 按行查看 历史
邹吉华 提交于 2023-04-12 16:27 . 1.6
import pandas as pd
import cvt_helper as cvt
from stock_holder import StockHolder
from stock_info import StockSceneInfo
from stock_info import StockPredictInfo
import const as cst
from data_center import DataCenter
from stable_baselines3 import SAC,PPO
from MyTT import *
import optuna
import random
import math
# File name of the saved model; the entry points in this file load it from
# "./model/" + MODEL_NAME.
MODEL_NAME = "best_model_11"
class StockChoose():
    """Base interface for strategies that decide which stocks to try to buy.

    Subclasses override :meth:`get_choose`. The base implementation selects
    nothing and returns ``None``.
    """

    def get_choose(self, all_stock, current_step):
        """Return the stock codes to consider buying at ``current_step``.

        Args:
            all_stock: Mapping of stock code to that stock's indexable history.
            current_step: Index of the step currently being traded.
        """


class DynamicStockChoose(StockChoose):
    """Select stocks by filtering on price change and technical indicators.

    Each filter is switched on by its ``*_trend`` argument; a value that
    matches none of the documented settings disables that filter.

    Args:
        min_increment: Lower bound for ``Close / Standard`` of the entry at
            ``current_step - 1``.
        max_increment: Upper bound for the same ratio.
        ma_trend: 1, 2, 3, 4 or 5 requires that entry's close to be no lower
            than its 5-, 10-, 20-, 30- or 60-period moving average.
        macd_trend: 1 requires the MACD value at ``current_step - 1`` to be
            non-negative.
        yesterday_trend: 1 rejects a stock whose entry at ``current_step - 2``
            closed below its ``Standard``; -1 rejects one that closed above.
        kdj_trend: 1 requires KDJ K to be no lower than KDJ D at
            ``current_step - 1``.
    """

    # ma_trend setting -> key (in the cached indicator dict) of the moving
    # average that the close is compared against.
    _MA_TREND_KEYS = {1: 'ma_5', 2: 'ma_10', 3: 'ma_20', 4: 'ma_30', 5: 'ma_60'}

    def __init__(self, min_increment, max_increment, ma_trend, macd_trend,
                 yesterday_trend, kdj_trend):
        self.min_increment = min_increment
        self.max_increment = max_increment
        self.yesterday_trend = yesterday_trend
        self.ma_trend = ma_trend
        self.macd_trend = macd_trend
        self.kdj_trend = kdj_trend
        # Per-stock indicator cache, filled lazily by get_tt_data().
        self.tt_data = dict()

    def get_choose(self, all_stock, current_step):
        """Return the codes in ``all_stock`` that pass :meth:`on_choose`.

        Args:
            all_stock: Mapping of stock code to that stock's indexable history.
            current_step: Index of the step currently being traded.

        Returns:
            list: Accepted codes, in the iteration order of ``all_stock``.
        """
        return [
            code
            for code, history in all_stock.items()
            if self.on_choose(code, history, current_step)
        ]

    def on_choose(self, code, stock, current_step):
        """Decide whether one stock passes every enabled filter.

        Args:
            code: Stock code, used as the indicator-cache key.
            stock: Indexable history whose entries expose ``Close`` and
                ``Standard`` (plus ``High``/``Low`` for the indicators).
            current_step: Index of the step currently being traded; entries
                ``current_step - 1`` and ``current_step - 2`` are inspected.

        Returns:
            bool: ``True`` when no enabled filter rejects the stock.
        """
        last = current_step - 1
        today_info = stock[last]
        increment = today_info.Close / today_info.Standard
        if increment < self.min_increment or increment > self.max_increment:
            return False

        yesterday_info = stock[current_step - 2]
        if self.yesterday_trend == 1 and yesterday_info.Close < yesterday_info.Standard:
            return False
        if self.yesterday_trend == -1 and yesterday_info.Close > yesterday_info.Standard:
            return False

        # Indicators are always fetched (and cached), even when every
        # indicator filter is disabled, exactly as before.
        tt_data = self.get_tt_data(code, stock)

        if self.macd_trend == 1 and tt_data['macd'][last] < 0:
            return False

        # Each ma_trend setting selects exactly one moving average. The
        # 60-period average belongs to setting 5; it used to be checked under
        # setting 3 as well, which left setting 5 without any filter.
        ma_key = self._MA_TREND_KEYS.get(self.ma_trend)
        # A NaN average makes the comparison False, so it never rejects.
        if ma_key is not None and today_info.Close < tt_data[ma_key][last]:
            return False

        if self.kdj_trend == 1 and tt_data['kdj_k'][last] < tt_data['kdj_d'][last]:
            return False
        return True

    def get_tt_data(self, code, stock_list):
        """Return the indicator series for ``code``, computing them once.

        Args:
            code: Stock code used as the cache key.
            stock_list: Full history of the stock; every entry exposes
                ``Close``, ``High`` and ``Low``.

        Returns:
            dict: Series keyed by ``'macd'``, ``'ma_5'``, ``'ma_10'``,
            ``'ma_20'``, ``'ma_30'``, ``'ma_60'``, ``'kdj_k'`` and ``'kdj_d'``,
            as produced by the MyTT ``MACD``/``MA``/``KDJ`` functions.
        """
        cached = self.tt_data.get(code)
        if cached is not None:
            return cached

        close_series = pd.Series(
            [np.float32(scene.Close) for scene in stock_list], copy=False)
        high_series = [np.float32(scene.High) for scene in stock_list]
        low_series = [np.float32(scene.Low) for scene in stock_list]

        k, d, j = KDJ(close_series, high_series, low_series)
        dif, dea, macd = MACD(close_series)
        indicators = {
            'macd': macd,
            'ma_5': MA(close_series, 5),
            'ma_10': MA(close_series, 10),
            'ma_20': MA(close_series, 20),
            'ma_30': MA(close_series, 30),
            'ma_60': MA(close_series, 60),
            'kdj_k': k,
            'kdj_d': d,
        }
        self.tt_data[code] = indicators
        return indicators
class StaticStockChoose(StockChoose):
    """Chooser that always proposes the same fixed list of stock codes."""

    def __init__(self, protfolie_stock):
        # Candidate codes, kept in the caller's order.
        self.protfolie_stock = protfolie_stock

    def get_choose(self, all_stock, current_step):
        """Return the configured codes that are present in ``all_stock``.

        Args:
            all_stock: Container of available stock codes (membership-tested).
            current_step: Unused; accepted to match the chooser interface.

        Returns:
            list: Configured codes found in ``all_stock``, in configured order.
        """
        return [code for code in self.protfolie_stock if code in all_stock]
class StockPortfolio():
    """Back-test a saved SAC model over the stocks picked by a chooser.

    Args:
        model_path: Path passed to ``SAC.load``.
        is_safe: Forwarded as the first argument of ``StockHolder``.
        stock_choose: Chooser whose ``get_choose(all_stock, current_step)``
            returns the codes to try to buy at each step.
    """

    def __init__(self, model_path, is_safe, stock_choose):
        self.current_step = cst.HISTORY_DATA_COUNT
        self.is_safe = is_safe
        self.model = SAC.load(model_path)
        self.choose = stock_choose
        self.data_center = DataCenter()
        # Running statistics, printed at the end of run().
        self.buy_dif = 0
        self.sell_dif = 0
        self.buy_count = 0

    def run(self, start_date, zs_type):
        """Simulate trading from ``start_date`` until a step reports the end.

        Args:
            start_date: Start date forwarded to the data center queries.
            zs_type: Second argument of ``query_concern_stock``
                (e.g. ``"zz500"`` in this file's callers).

        Returns:
            The final fortune divided by ``cst.INITIAL_CAPITAL``, or ``0``
            when the data center returns no stock list.
        """
        protfolie_stock = self.data_center.query_concern_stock(start_date, zs_type)
        if protfolie_stock is None:
            print('run error : protfolie_stock is none')
            return 0
        self.all_stock = self.data_center.get_protfolie_info(protfolie_stock, start_date)
        # Start at HISTORY_DATA_COUNT rather than 0: _get_frame_data() reads
        # the HISTORY_DATA_COUNT entries before current_step, and the chooser
        # reads current_step - 1 / - 2. Starting at 0 made those indices
        # negative, which on Python sequences wraps around to the END of the
        # history, so the first steps observed future data.
        self.current_step = cst.HISTORY_DATA_COUNT
        self.stock_holder = StockHolder(self.is_safe, False, True)
        while not self._step():
            pass
        print(f'buy dif {self.buy_dif} sell dif {self.sell_dif} buy count {self.buy_count}')
        last_money = self.stock_holder.get_fortune()
        return last_money / cst.INITIAL_CAPITAL

    def _predict_prices(self, stock_info, last_info):
        """Return ``(buy_price, sell_price)`` predicted for the current step."""
        obs = cvt.get_obs(self._get_frame_data(stock_info))
        action, _ = self.model.predict(obs, deterministic=True)
        return cvt.parse_action(last_info.Standard, action)

    def _step(self):
        """Advance the simulation by one step.

        Held positions are updated and offered for sale first, then the
        chooser's candidates are offered for purchase.

        Returns:
            bool: ``True`` when the simulation must stop (a processed stock
            has no data at ``current_step``, or the chooser selected nothing);
            ``False`` after a completed step.
        """
        # Iterate over a copy: selling may modify the holder's positions.
        held = self.stock_holder.stock_info.copy()
        for key, position in held.items():
            stock_info = self.all_stock[position.StockCode]
            last_info = self._get_last_info(stock_info)
            if last_info is None:
                return True
            self.stock_holder.update(key, last_info)
            _, sell_price = self._predict_prices(stock_info, last_info)
            self.stock_holder.sell(key, sell_price)

        buy_stock = self.choose.get_choose(self.all_stock, self.current_step)
        if len(buy_stock) <= 0:
            return True
        for code in buy_stock:
            stock_info = self.all_stock[code]
            last_info = self._get_last_info(stock_info)
            if last_info is None:
                return True
            buy_price, sell_price = self._predict_prices(stock_info, last_info)
            if self.stock_holder.buy(code, last_info, buy_price):
                self.buy_count += 1
                # Distance of the predicted prices from the step's low/high.
                self.buy_dif += math.fabs(buy_price - last_info.Low)
                self.sell_dif += math.fabs(sell_price - last_info.High)
                print(f'{sell_price/buy_price} {sell_price/last_info.Standard} {buy_price/last_info.Standard}')
        self.current_step += 1
        return False

    def _print(self, stock_code):
        """Print the current fortune and safe money, labelled with ``stock_code``."""
        print(f'{stock_code} Money : {self.stock_holder.get_fortune()},Safe Money : {self.stock_holder.safe_money}')

    def _get_frame_data(self, current_stock_data):
        """Return the ``cst.HISTORY_DATA_COUNT`` entries preceding ``current_step``."""
        first = self.current_step - cst.HISTORY_DATA_COUNT
        return [current_stock_data[first + i] for i in range(cst.HISTORY_DATA_COUNT)]

    def _get_last_info(self, current_stock_data):
        """Return the entry at ``current_step``, or ``None`` past the end."""
        if self.current_step >= len(current_stock_data):
            return None
        return current_stock_data[self.current_step]
def optimize_agent(trial):
    """Optuna objective: back-test one sampled ``DynamicStockChoose`` setup.

    Args:
        trial: Optuna trial used to sample the chooser parameters.

    Returns:
        The value returned by ``StockPortfolio.run`` for the sampled
        parameters; ``-1000`` when the sampled ``min_increment`` is not below
        ``max_increment``; ``-10000`` when any exception is raised.
    """
    try:
        # suggest_float replaces suggest_uniform, which Optuna deprecated in
        # 3.0; both sample uniformly between the two bounds.
        params = {
            'min_increment': trial.suggest_float('min_increment', 0.9, 1.07),
            'max_increment': trial.suggest_float('max_increment', 0.93, 1.1),
            'yesterday_trend': trial.suggest_categorical("yesterday_trend", [-1, 0, 1]),
            'macd_trend': trial.suggest_categorical("macd_trend", [0, 1]),
            'ma_trend': trial.suggest_categorical("ma_trend", [0, 1, 2, 3, 4, 5]),
            'kdj_trend': trial.suggest_categorical("kdj_trend", [0, 1]),
        }
        # The two ranges overlap, so an empty [min, max] window can be
        # sampled; score it badly without running a back-test.
        if params['min_increment'] >= params['max_increment']:
            return -1000
        choose = DynamicStockChoose(**params)
        test = StockPortfolio("./model/" + MODEL_NAME, False, choose)
        reward = test.run("2022-01-01", "zz500")
        print("reward", reward)
        return reward
    except Exception as e:
        # Best effort: a failing trial is reported and scored worst so the
        # study keeps running.
        print(e)
        return -10000
def optimize_dynamic_param():
    """Run a 100-trial Optuna search maximising ``optimize_agent``.

    Prints the best parameter set found.
    """
    search = optuna.create_study(direction='maximize')
    search.optimize(optimize_agent, n_trials=100)
    print(search.best_params)
    # Result recorded from an earlier search:
    # Trial 36 finished with value: 1.4075005326522951 and parameters: {'min_increment': 0.942136512292944, 'max_increment': 1.0535419284664314, 'yesterday_trend': 0}. Best is trial 36 with value: 1.4075005326522951.
def run_dynamic_model():
    """Back-test the saved model with a fixed ``DynamicStockChoose`` setup.

    Prints the reward returned by ``StockPortfolio.run``.
    """
    tuned = dict(
        min_increment=0.9672961947176241,
        max_increment=1.0359198462706576,
        yesterday_trend=0,
        macd_trend=0,
        ma_trend=3,
        kdj_trend=0,
    )
    selector = DynamicStockChoose(**tuned)
    portfolio = StockPortfolio("./model/" + MODEL_NAME, False, selector)
    reward = portfolio.run("2022-01-01", "zz500")
    print("reward", reward)
def run_staic_model():
    """Back-test the saved model on a fixed, hand-picked list of stocks.

    Prints the reward returned by ``StockPortfolio.run``, prefixed with the
    model name.
    """
    # CSI 500 (zz500) heavyweight constituents.
    protfolie_stock = [
        "sh.600522",
        "sh.601615",
        "sh.600256",
        "sh.600157",
        "sz.002407",
        "sz.000009",
        "sz.002180",
        "sz.002384",
    ]
    # Alternative lists, kept for manual switching.
    # CSI 300 heavyweight constituents:
    #protfolie_stock = ["sh.600519","sh.601318","sh.600036","sh.601012","sz.000858","sz.002594","sz.000333","sh.600900","sh.601166"]
    # SSE 50 heavyweight constituents:
    #protfolie_stock = ["sh.600519","sh.601398","sh.601288","sh.601857","sh.600036","sh.601988","sh.601628","sh.600900","sh.601088","sh.601012"]
    # The ten zz500 stocks with the smallest variance:
    #protfolie_stock = ["sh.601880","sz.002936","sh.600022","sh.600567","sh.600871","sh.601860","sh.601005","sz.002958","sz.002948","sh.601992"]
    #protfolie_stock = ["sh.600256"]
    selector = StaticStockChoose(protfolie_stock)
    portfolio = StockPortfolio("./model/" + MODEL_NAME, False, selector)
    reward = portfolio.run("2022-01-01", "zz500")
    print(MODEL_NAME + " reward", reward)
if __name__ == "__main__":
    # Script entry point. Exactly one mode is active; uncomment another call
    # to run the parameter search or the dynamic chooser instead.
    # optimize_dynamic_param()
    # run_dynamic_model()
    run_staic_model()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/ArriettyTrader/stock_robot.git
[email protected]:ArriettyTrader/stock_robot.git
ArriettyTrader
stock_robot
stock_robot
master

搜索帮助