1 Star 0 Fork 13

ArriettyTrader/stock_robot

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
data_center.py 11.75 KB
一键复制 编辑 原始数据 按行查看 历史
邹吉华 提交于 2023-04-12 16:27 . 1.6
import pymysql # 连接mysql数据库的模块
import numpy as np
import baostock.common.contants as cons
import baostock as bs
import time
from stock_info import StockSceneInfo
from stock_info import StockPredictInfo
class DataCenter:
    """Data access layer combining a MySQL store with the baostock service.

    Each instance owns one pymysql connection. The baostock session is shared
    by all instances: the first instance created logs in, and the last one
    destroyed logs out (tracked by ``__bs_ref_num__``).
    """

    # Number of live instances that share the baostock login.
    __bs_ref_num__ = 0

    # Only codes starting with one of these prefixes are returned by the
    # stock-list queries.
    _MAIN_BOARD_PREFIXES = ("sh.60", "sz.00")

    # Field list requested from bs.query_history_k_data_plus; row indexes used
    # in query_stock_info follow this order.
    _KLINE_FIELDS = "date,code,open,high,low,close,preclose,volume,amount,turn,isST"

    def __init__(self, host="127.0.0.1", port=3306, user="root",
                 password="111111", database="stock_sb3_db"):
        """Open the MySQL connection and join the shared baostock session.

        The defaults reproduce the connection settings that used to be
        hard-coded here, so ``DataCenter()`` behaves as before.
        """
        # Set first so that __del__ is safe even when connect() raises.
        self.db_client = None
        self._bs_counted = False
        self.db_client = pymysql.connect(
            host=host,
            port=port,
            user=user,
            password=password,
            database=database,
            charset='utf8mb4',  # must not be written as "utf-8"
            # With autocommit enabled, INSERT/UPDATE/DELETE statements are
            # committed automatically; otherwise every execute() would need
            # an explicit commit() on the connection.
            autocommit=True
        )
        if DataCenter.__bs_ref_num__ == 0:
            bs.login()
        DataCenter.__bs_ref_num__ += 1
        self._bs_counted = True

    def __del__(self):
        """Close the connection and log out of baostock with the last instance."""
        client = getattr(self, "db_client", None)
        if client is not None:
            client.close()
        # Only instances that were counted in __init__ may decrement, so a
        # failed constructor cannot push the counter below zero.
        if getattr(self, "_bs_counted", False):
            DataCenter.__bs_ref_num__ -= 1
            if DataCenter.__bs_ref_num__ == 0:
                bs.logout()

    def _exe_sql(self, sql, args=None):
        """Execute one SQL statement and return all fetched rows.

        Args:
            sql: Statement text. When ``args`` is given, ``%s`` placeholders
                in ``sql`` are filled in by pymysql.
            args: Optional parameters for the placeholders.

        Returns:
            The rows as dictionaries (DictCursor), or ``None`` when the
            statement failed; the error is printed, not raised.
        """
        cursor_obj = self.db_client.cursor(pymysql.cursors.DictCursor)
        try:
            cursor_obj.execute(sql, args)
            return cursor_obj.fetchall()
        except Exception as e:
            # Best-effort by design: callers treat None as "no data".
            print(e)
            return None
        finally:
            cursor_obj.close()

    @staticmethod
    def _scene_from_row(item):
        """Convert one ``stock_daily_info`` row (a dict) to a StockSceneInfo."""
        scene = StockSceneInfo()
        scene.Date = item['date']
        scene.Open = np.float64(item['open'])
        scene.Close = np.float64(item['close'])
        scene.High = np.float64(item['high'])
        scene.Low = np.float64(item['low'])
        scene.Deal = np.float64(item['deal'])
        scene.Standard = np.float64(item['stardand'])
        scene.Turn = np.float64(item['turn'])
        scene.IsST = bool(item['is_st'] == 1)
        return scene

    @staticmethod
    def _collect_codes(rs):
        """Read every row of a baostock result set and keep the accepted codes.

        The stock code is taken from column 1 of each row.
        """
        codes = []
        while rs.next():
            code = rs.get_row_data()[1]
            if code.startswith(DataCenter._MAIN_BOARD_PREFIXES):
                codes.append(code)
        return codes

    @staticmethod
    def _stock_insert_statement(code, stock_info):
        """Build the multi-row INSERT for ``save_stock_info``.

        Returns:
            ``(sql, args)``: code and date are passed as parameters, numeric
            columns are rendered with ``str()`` exactly as before.
        """
        row_sqls = []
        args = []
        for it in stock_info:
            numbers = ",".join(str(value) for value in (
                it.Open, it.High, it.Low, it.Close,
                it.Standard, it.Deal, it.Turn,
            ))
            st_flag = "1" if it.IsST else "0"
            row_sqls.append("(%s,%s," + numbers + "," + st_flag + ")")
            args.extend((code, it.Date))
        sql = ("INSERT INTO stock_daily_info "
               "(code,date,open,high,low,close,stardand,deal,turn,is_st) VALUES "
               + ",".join(row_sqls))
        return sql, args

    @staticmethod
    def _predict_replace_statement(date, predict_list):
        """Build the multi-row REPLACE for ``save_predict_info``.

        Returns:
            ``(sql, args)``: code and date are passed as parameters; the
            ``expected`` column is always written as 0.
        """
        row_sqls = []
        args = []
        for predict_info in predict_list:
            row_sqls.append(
                "(%s,%s," + str(predict_info.BuyPrice) + ","
                + str(predict_info.SellPrice) + ",0)"
            )
            args.extend((predict_info.StockCode, date))
        sql = ("REPLACE INTO predict_info "
               "(code,date,buy_price,sell_price,expected) VALUES "
               + ",".join(row_sqls))
        return sql, args

    def get_start_date(self, code):
        """Return the newest stored date for ``code`` as ``YYYY-MM-DD``.

        Returns ``None`` when nothing is stored or the query failed.
        """
        sql = "SELECT date FROM stock_daily_info WHERE code=%s ORDER BY date DESC LIMIT 1"
        res = self._exe_sql(sql, (code,))
        if res:
            return res[0]['date'].strftime('%Y-%m-%d')
        return None

    def save_stock_info(self, code, stock_info):
        """Insert the daily records in ``stock_info`` for ``code``.

        Does nothing when ``stock_info`` is empty.
        """
        if len(stock_info) == 0:
            return
        sql, args = self._stock_insert_statement(code, stock_info)
        self._exe_sql(sql, args)

    def get_stock_info(self, code, year):
        """Return the stored daily records of ``code`` for one calendar year.

        Returns:
            A list of StockSceneInfo ordered by date ascending; empty when
            there is no data or the query failed.
        """
        sql = ("SELECT * FROM stock_daily_info WHERE code=%s "
               "AND date>=%s AND date<=%s ORDER BY date ASC")
        args = (code, str(year) + "-01-01", str(year) + "-12-31")
        res = self._exe_sql(sql, args)
        # _exe_sql returns None on failure; treat that as "no rows".
        return [self._scene_from_row(item) for item in (res or [])]

    def get_protfolie_info(self, code_list, start_date):
        """Return stored daily records for several codes from ``start_date`` on.

        Returns:
            A dict mapping each code to its list of StockSceneInfo, ordered by
            date ascending; empty when ``code_list`` is empty, nothing matches
            or the query failed.
        """
        data_dict = dict()
        codes = list(code_list)
        if not codes:
            # "IN()" is not valid SQL, so answer directly.
            return data_dict
        placeholders = ",".join(["%s"] * len(codes))
        sql = ("SELECT * FROM stock_daily_info WHERE code IN(" + placeholders
               + ") AND date>=%s ORDER BY date ASC")
        res = self._exe_sql(sql, codes + [start_date])
        for item in (res or []):
            data_dict.setdefault(item['code'], []).append(self._scene_from_row(item))
        return data_dict

    def save_predict_info(self, date, predict_list):
        """Insert or replace the predictions in ``predict_list`` for ``date``.

        Does nothing when ``predict_list`` is empty.
        """
        if len(predict_list) == 0:
            return
        sql, args = self._predict_replace_statement(date, predict_list)
        self._exe_sql(sql, args)

    def get_predict_date(self):
        """Return a dict mapping each code to its newest prediction date.

        The dict is empty when there are no predictions or the query failed.
        """
        sql = "SELECT code,MAX(date) AS date FROM predict_info GROUP BY code"
        res = self._exe_sql(sql)
        return {item['code']: item['date'] for item in (res or [])}

    def query_all_stock(self, start_date="2022-05-16"):
        """Query baostock for all securities on ``start_date``.

        Returns:
            The codes starting with ``sh.60`` or ``sz.00``; an empty list when
            baostock reports an error.
        """
        result = bs.query_all_stock(start_date)
        if result.error_code != cons.BSERR_SUCCESS:
            print(f"bs.get_all_stock Error : {result.error_code} , {result.error_msg}")
            return []
        return [
            row[0] for row in result.data
            if row[0].startswith(DataCenter._MAIN_BOARD_PREFIXES)
        ]

    def query_stock_info(self, code, start_date=None):
        """Download daily K-line data for ``code`` from baostock.

        Args:
            code: baostock security code.
            start_date: First date to request, passed through to baostock.

        Returns:
            A list of StockSceneInfo, or ``None`` when either baostock query
            reports an error.
        """
        rs_factor = bs.query_adjust_factor(code, "1990-01-01")
        if rs_factor.error_code != cons.BSERR_SUCCESS:
            print(f"bs.query_adjust_factor Error : {rs_factor.error_code} , {rs_factor.error_msg}")
            return None
        factor_rows = []
        while rs_factor.next():
            factor_rows.append(rs_factor.get_row_data())

        result = bs.query_history_k_data_plus(
            code, DataCenter._KLINE_FIELDS, start_date, adjustflag='2')
        if result.error_code != cons.BSERR_SUCCESS:
            print(f"bs.query_history_k_data_plus Error : {result.error_code} , {result.error_msg}")
            return None

        data_list = []
        # The first returned row is always dropped.
        # NOTE(review): presumably because start_date is the last day already
        # stored; confirm this is also wanted when start_date is None.
        skip = True
        while result.next():
            item = result.get_row_data()
            if skip:
                skip = False
                continue
            scene = StockSceneInfo()
            scene.Date = item[0]
            scene.Open = np.float64(item[2])
            scene.Close = np.float64(item[5])
            scene.High = np.float64(item[3])
            scene.Low = np.float64(item[4])
            # Deal = amount / volume scaled by the adjust factor; left at the
            # StockSceneInfo default when the volume is missing or zero.
            if item[7] != '' and np.float64(item[7]) != 0:
                adjust = self._adjust_factor(factor_rows, item[0])
                scene.Deal = np.float64(item[8]) / np.float64(item[7]) * adjust
            scene.Standard = np.float64(item[6])  # preclose
            if item[9] != '':
                scene.Turn = np.float64(item[9])
            scene.IsST = item[10] == '1'
            data_list.append(scene)
        return data_list

    def _adjust_factor(self, current_adjust_factor, date):
        """Return the adjust factor that applies on ``date``.

        Args:
            current_adjust_factor: Rows whose column 1 is a ``YYYY-MM-DD``
                date and column 2 the factor; scanned from last to first.
            date: Trading date as ``YYYY-MM-DD``.

        Returns:
            The factor of the first row, scanning backwards, whose date is not
            after ``date``; ``1.0`` when no row qualifies.
        """
        current_date = time.strptime(date, '%Y-%m-%d')
        for adjust_row in reversed(current_adjust_factor):
            if current_date >= time.strptime(adjust_row[1], '%Y-%m-%d'):
                return np.float64(adjust_row[2])
        return np.float64(1)

    def query_concern_stock(self, date="", type="hs300"):
        """Return the constituents of one index from baostock.

        Args:
            date: Query date passed through to baostock.
            type: ``"hs300"``, ``"zz500"`` or ``"sz50"``.

        Returns:
            The constituent codes starting with ``sh.60`` or ``sz.00``, or
            ``None`` for an unknown ``type`` or a baostock error.
        """
        query_names = {
            "hs300": "query_hs300_stocks",
            "zz500": "query_zz500_stocks",
            "sz50": "query_sz50_stocks",
        }
        query_name = query_names.get(type)
        if query_name is None:
            print('get_concern_stock error type :'+type)
            return None
        rs = getattr(bs, query_name)(date)
        if rs.error_code != cons.BSERR_SUCCESS:
            print('get_concern_stock error_msg: ' + rs.error_msg + "error_code: "+rs.error_code)
            return None
        return self._collect_codes(rs)

    def query_weighted_stock(self, start_date="2022-05-16"):
        """Return the hs300, zz500 and sz50 constituents, in that order.

        Only codes starting with ``sh.60`` or ``sz.00`` are kept. A code that
        belongs to more than one index appears once per index.
        NOTE(review): duplicates are kept as before; confirm callers want that.
        """
        concern_stocks = []
        for query in (bs.query_hs300_stocks, bs.query_zz500_stocks, bs.query_sz50_stocks):
            concern_stocks.extend(self._collect_codes(query(start_date)))
        return concern_stocks
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/ArriettyTrader/stock_robot.git
[email protected]:ArriettyTrader/stock_robot.git
ArriettyTrader
stock_robot
stock_robot
master

搜索帮助