1 Star 1 Fork 0

Henry_Fung/Alzheimer_Disease_Classification

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目，私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件（LICENSE），使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_model_catboost.py 5.79 KB
一键复制 编辑 原始数据 按行查看 历史
Henry_Fung 提交于 2021-08-08 16:44 . add fusion
import torch
import argparse
from tensorboardX import SummaryWriter
from datasets.ad_ds import AD_Dataset, load_data
import os
from sklearn.model_selection import StratifiedKFold
import warnings
import numpy as np
from catboost import CatBoostClassifier
# import pickle
# import joblib
warnings.filterwarnings("ignore")
if __name__ == "__main__":
# python -m visdom.server
# mp.set_start_method('spawn')
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=150, help="epoch")
parser.add_argument("--batch_size", type=int, default=128, help="batch size")
parser.add_argument("--learning_rate", type=float, default=0.001, help="learning_rate")
parser.add_argument("--log_path", type=str, default='log/tensorboard/',
help="log_path")
parser.add_argument("--data_path", type=str, default='./train/train',
help="data_path")
parser.add_argument("--label_path", type=str, default=r'./train/train_open.csv',
help="label_path")
parser.add_argument("--data_url", type=str, default='',
help="data_url")
parser.add_argument("--train_url", type=str, default='',
help="train_url")
parser.add_argument("--log_url", type=str, default='',
help="log_url")
parser.add_argument("--init_method", type=str, default='',
help="init_method")
# parser.add_argument("--save_name", type=str, default='dnn0_5layers_focal',
# help="save_name")
# parser.add_argument("--save_name", type=str, default='dnn_residual_focal_2',
# help="save_name")
parser.add_argument("--save_name", type=str, default='lgbm_0721',
help="save_name")
parser.add_argument("--app_url", type=str, default='',
help="save_name")
parser.add_argument("--boot_file", type=str, default='',
help="save_name")
parser.add_argument("--log_file", type=str, default='',
help="save_name")
parser.add_argument("--num_gpus", type=int, default=1,
help="num_gpus")
# parser.add_argument("--model", type=str, default='cbam18',
# help="cbam18,resnet18,effnetb4,ca18,cbam34,cbam50")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
opt = parser.parse_args()
print(str(opt))
# os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
BATCH_SIZE = opt.batch_size
EPOCH = opt.epoch
learning_rate = opt.learning_rate
pretrain_w_path = opt.pretrain_weight_path if 'pretrain_weight_path' in opt else ''
n_samples = BATCH_SIZE * 20
writer = SummaryWriter(os.path.join(opt.log_path, opt.save_name), comment=opt.save_name,
flush_secs=2)
save_path = os.path.join('save', opt.save_name)
if not os.path.exists(opt.data_path):
import moxing as mox
mox.file.copy_parallel(opt.data_url, './train/')
print('数据已加载')
x, y = load_data(opt.data_path, opt.label_path)
x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)
mean = np.mean(x, axis=0)
std = np.std(x, axis=0)
x = (x - mean) / std
x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=2021).split(x, y)
loss_list = []
max_auc_list = []
chkp_list = []
# 0.6707
# model = CatBoostClassifier(iterations=1000,
# task_type="GPU")
# model = CatBoostClassifier(
# loss_function='MultiClass',
# eval_metric="AUC",
# task_type="GPU",
# # learning_rate=0.01,
# iterations=6000,
# od_type="Iter",
# # depth=4,
# early_stopping_rounds=500,
# # l2_leaf_reg=10,
# # border_count=96,
# random_seed=42,
# # use_best_model=use_best_model
# )
model = CatBoostClassifier(
loss_function='MultiClass',
eval_metric="TotalF1",
task_type="GPU",
# learning_rate=0.01,
iterations=6000,
od_type="Iter",
depth=8,
early_stopping_rounds=500,
# l2_leaf_reg=10,
# border_count=96,
random_seed=42,
# use_best_model=use_best_model
)
model.fit(x, y)
save_name = os.path.join(save_path)
if not os.path.exists(save_name):
os.makedirs(save_name)
model_save_path = os.path.join(save_name, 'catboost.model')
# output = open(model_save_path, 'wb')
# pickle.dump(model, output)
# output.close()
# joblib.dump(model, model_save_path)
#
# model.save_model('model.txt')
model.save_model(model_save_path)
# for fold, (trn_idx, val_idx) in enumerate(folds):
# print('------------------Fold %i--------------------' % fold)
# model = lgb.LGBMClassifier()
# model.fit(x[trn_idx], y[trn_idx])
# result = model.predict(test_data.drop(['version'], axis=1))
if opt.train_url != '':
import moxing as mox
# from deep_moxing.model_analysis.api import analyse, tmp_save
# model_path = 'obs://ad-competiton/my_baseline/model'
train_url = 'obs:' + opt.train_url.replace('s3:', '')
data_url = 'obs:' + opt.data_url.replace('s3:', '')
log_url = 'obs://' + opt.log_url
print('Start to save model to', train_url, 'from', save_path)
np.save('./mean.npy', mean)
np.save('./std.npy', std)
mox.file.copy('./mean.npy', train_url + '/mean.npy')
mox.file.copy('./std.npy', train_url + '/std.npy')
mox.file.copy('./std.npy', train_url + '/std.npy')
mox.file.copy_parallel(save_path, train_url)
# mox.file.copy_parallel('./log',log_url)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/henry_fung/alzheimer_disease_classification.git
git@gitee.com:henry_fung/alzheimer_disease_classification.git
henry_fung
alzheimer_disease_classification
Alzheimer_Disease_Classification
master

搜索帮助