
高性能golang/ai_quant_demo

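This repository does not declare an open-source license file (LICENSE). Before using it, check the project description and the upstream dependencies of its code.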
model.py 1.92 KB
Committed by 高性能golang on 2022-12-08 16:10: "first commmit"
# encoding: utf-8
import os

import lightgbm as lgb
import pandas as pd

from feature import FEATURE_DIR, LABEL_NAME  # project module: feature-file directory and label column name

MODEL_DIR = os.path.join("file", "model")
MODEL_FILE = os.path.join(MODEL_DIR, "lgb.txt")
os.makedirs(MODEL_DIR, exist_ok=True)


def train():
    '''
    Train a LightGBM (gradient-boosted decision tree) model.
    '''
    train_corpus_list, test_corpus_list = [], []
    for file in sorted(os.listdir(FEATURE_DIR)):  # sorted for a reproducible row order
        df = pd.read_pickle(os.path.join(FEATURE_DIR, file))  # load the sample data from the file
        pivot = int(len(df) * 0.6)  # 60/40 split: first 60% of rows for training, the rest for testing
        train_corpus_list.append(df[:pivot])
        test_corpus_list.append(df[pivot:])
    train_corpus = pd.concat(train_corpus_list, axis=0)
    test_corpus = pd.concat(test_corpus_list, axis=0)

    feature_names = [ele for ele in train_corpus.columns if ele != LABEL_NAME]
    train_x = train_corpus.loc[:, feature_names]
    train_y = train_corpus.loc[:, LABEL_NAME]
    test_x = test_corpus.loc[:, feature_names]
    test_y = test_corpus.loc[:, LABEL_NAME]

    # LightGBM treats NaN as a missing value by default
    dataset_params = {'num_threads': 8, 'use_missing': True, 'zero_as_missing': False, 'verbose': 0}
    dtrain = lgb.Dataset(data=train_x, label=train_y, feature_name=feature_names,
                         params=dict(dataset_params))
    dtest = lgb.Dataset(data=test_x, label=test_y, feature_name=feature_names,
                        params=dict(dataset_params),
                        reference=dtrain)  # reuse the training set's bin boundaries
    gbm = lgb.train(params={'learning_rate': 0.01,  # learning rate
                            'max_depth': 3,  # max depth of each tree; more features may justify a larger depth
                            },
                    num_boost_round=10,  # number of trees (boosting rounds)
                    train_set=dtrain, valid_sets=[dtrain, dtest],
                    )
    # Save the model file; the end of the file records each feature's importance
    gbm.save_model(MODEL_FILE)


if __name__ == '__main__':
    train()
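train() splits each feature file on its own: the first 60% of rows go to training and the remaining 40% to testing, and only then are the pieces concatenated. Assuming each file holds one instrument's rows sorted by date (the feature module is not shown here, so this is an assumption), every test row comes later than that file's training rows, which avoids the look-ahead leakage a random shuffle would introduce. The sketch below models the same logic with plain lists instead of DataFrames; chrono_split, split_all and the sample data are hypothetical.

```python
def chrono_split(rows, train_frac=0.6):
    """Split one file's time-ordered rows: the first train_frac go to training."""
    pivot = int(len(rows) * train_frac)  # same truncation as int(len(df) * 0.6)
    return rows[:pivot], rows[pivot:]


def split_all(files, train_frac=0.6):
    """Split every file separately, then concatenate the pieces (like pd.concat)."""
    train, test = [], []
    for rows in files.values():
        head, tail = chrono_split(rows, train_frac)
        train.extend(head)
        test.extend(tail)
    return train, test


# Hypothetical data: two instruments, each already sorted by date.
files = {
    "stock_a": [("2022-01-03", 1.0), ("2022-01-04", 2.0), ("2022-01-05", 3.0),
                ("2022-01-06", 4.0), ("2022-01-07", 5.0)],
    "stock_b": [("2022-01-03", 10.0), ("2022-01-04", 20.0), ("2022-01-05", 30.0)],
}
train, test = split_all(files)
print(len(train), len(test))  # 4 4
```

Note that int() truncates: a 3-row file contributes only 1 training row, because 3 * 0.6 evaluates to 1.7999999999999998 in floating point. Very short files therefore end up test-heavy.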
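Both Dataset objects are built with use_missing=True and zero_as_missing=False. Per the LightGBM parameter documentation, this combination handles NaN as a missing value while 0.0 stays an ordinary number, which matters when zero is a meaningful feature value (for example, a zero return). The function below is a simplified, hypothetical model of that rule for illustration only. It is not LightGBM's internal binning code.

```python
import math


def is_missing(x, use_missing=True, zero_as_missing=False):
    """Simplified model of which values get missing-value treatment.

    Assumption: follows the documented meaning of the two Dataset
    parameters, not LightGBM's actual implementation.
    """
    if not use_missing:
        return False  # special missing-value handling is switched off
    if math.isnan(x):
        return True  # NaN is the missing marker by default
    return zero_as_missing and x == 0.0  # zeros count as missing only if requested


values = [1.5, 0.0, float("nan")]
print([is_missing(v) for v in values])  # settings used in train(): [False, False, True]
print([is_missing(v, zero_as_missing=True) for v in values])  # [False, True, True]
```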
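train() sets only max_depth=3 and leaves num_leaves at its documented default of 31. A tree limited to 3 levels of splits has at most 2^3 = 8 leaves, so depth is the binding constraint here. Each root-to-leaf path can also test at most 3 features, which is one way to read the source comment that more features call for more depth. A small sketch of the arithmetic, with hypothetical helper names:

```python
def max_leaves(max_depth):
    """Upper bound on the leaves of a binary tree with max_depth levels of splits."""
    return 2 ** max_depth


def effective_leaves(max_depth, num_leaves=31):
    """Leaf cap when both limits apply; 31 is LightGBM's documented num_leaves default."""
    if max_depth <= 0:  # LightGBM documents max_depth <= 0 as "no limit"
        return num_leaves
    return min(num_leaves, max_leaves(max_depth))


print(effective_leaves(3))  # 8: max_depth=3 is the binding limit
print(effective_leaves(6))  # 31: num_leaves binds instead
```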
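No objective is passed to lgb.train, so LightGBM falls back to its default regression (L2) objective. Under the idealized assumption that every tree fits the current residual exactly, each round removes a fraction learning_rate of what remains, leaving (1 - learning_rate)^rounds of the initial gap. With learning_rate=0.01 and num_boost_round=10, this optimistic estimate closes under 10% of the gap, and real depth-3 trees should do no better in squared-error terms. The saved model therefore likely stays close to its initial constant prediction. The sketch below is only this back-of-envelope estimate, not a LightGBM API; the function names are hypothetical.

```python
import math


def remaining_gap(learning_rate, num_rounds):
    """Fraction of the initial residual left after boosting.

    Assumption: squared-error loss and every tree fits the current
    residual exactly, so each round removes learning_rate of what is left.
    """
    return (1.0 - learning_rate) ** num_rounds


def rounds_needed(learning_rate, target_gap):
    """Smallest number of rounds with remaining_gap <= target_gap (same assumption)."""
    return math.ceil(math.log(target_gap) / math.log(1.0 - learning_rate))


gap = remaining_gap(0.01, 10)  # the settings used in train()
print(f"{1 - gap:.1%} of the gap closed")  # 9.6% of the gap closed
print(rounds_needed(0.01, 0.1))  # 230 rounds to close 90% at this learning rate
```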
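The source comment notes that the end of the saved model file records each feature's importance. This sketch assumes LightGBM's text format stores them as a feature_importances: block of name=value lines (split counts by default) that ends at a blank line. The parser reads that block. The feature names in the sample are hypothetical. From Python, the same figures are also available without parsing, by pairing gbm.feature_name() with gbm.feature_importance().

```python
def parse_feature_importances(model_text):
    """Collect the 'name=value' lines that follow 'feature_importances:'.

    Assumption: the block ends at the first blank line, as in LightGBM's
    text model format.
    """
    importances = {}
    in_section = False
    for line in model_text.splitlines():
        line = line.strip()
        if line == "feature_importances:":
            in_section = True
            continue
        if in_section:
            if not line:
                break  # a blank line closes the block
            name, _, value = line.rpartition("=")
            importances[name] = float(value)
    return importances


# Hypothetical excerpt; a real file could be read with open(MODEL_FILE).read().
sample = """end of trees

feature_importances:
ma5=12
volume_ratio=7

parameters:
[boosting: gbdt]
"""
print(parse_feature_importances(sample))  # {'ma5': 12.0, 'volume_ratio': 7.0}
```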
Clone (HTTPS): https://gitee.com/orisunzhang/ai_quant_demo.git
Clone (SSH): git@gitee.com:orisunzhang/ai_quant_demo.git