1 Star 1 Fork 0

Henry_Fung/Alzheimer_Disease_Classification

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_model_svm.py 5.39 KB
一键复制 编辑 原始数据 按行查看 历史
fengjiahao 提交于 2021-08-09 18:59 . modify
import torch
import argparse
from tensorboardX import SummaryWriter
from datasets.ad_ds import AD_Dataset, load_data
import os
from sklearn.model_selection import StratifiedKFold
import warnings
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedShuffleSplit#分层洗牌分割交叉验证
from sklearn.model_selection import GridSearchCV
import pickle
# import joblib
warnings.filterwarnings("ignore")
if __name__ == "__main__":
    # Script entry point: grid-search a one-vs-rest SVM over the standardized
    # features returned by load_data(), pickle the best fitted estimator to
    # save/<save_name>/svm.pkl and, when --train_url is given, upload the
    # model together with the normalization statistics through moxing.
    torch.backends.cudnn.benchmark = True

    # Several of the options below are not read by this script. They are kept
    # so that command lines which pass them still parse.
    # NOTE(review): presumably they are injected by the training platform's
    # launcher -- confirm before removing any of them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=150, help="epoch")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="batch size")
    parser.add_argument("--learning_rate", type=float, default=0.001,
                        help="learning_rate")
    parser.add_argument("--log_path", type=str, default='log/tensorboard/',
                        help="log_path")
    parser.add_argument("--data_path", type=str, default='./train/train',
                        help="data_path")
    parser.add_argument("--label_path", type=str,
                        default=r'./train/train_open.csv', help="label_path")
    parser.add_argument("--data_url", type=str, default='', help="data_url")
    parser.add_argument("--train_url", type=str, default='',
                        help="train_url")
    parser.add_argument("--log_url", type=str, default='', help="log_url")
    parser.add_argument("--init_method", type=str, default='',
                        help="init_method")
    parser.add_argument("--save_name", type=str, default='lgbm_0721',
                        help="save_name")
    parser.add_argument("--app_url", type=str, default='', help="app_url")
    parser.add_argument("--boot_file", type=str, default='',
                        help="boot_file")
    parser.add_argument("--log_file", type=str, default='', help="log_file")
    parser.add_argument("--num_gpus", type=int, default=1, help="num_gpus")
    opt = parser.parse_args()
    print(str(opt))

    # The writer is never given any scalars here; it is kept because creating
    # it produces the run's log directory, and it is closed explicitly below.
    writer = SummaryWriter(os.path.join(opt.log_path, opt.save_name),
                           comment=opt.save_name, flush_secs=2)
    save_path = os.path.join('save', opt.save_name)

    # When the data directory is missing locally, fetch it from --data_url.
    if not os.path.exists(opt.data_path):
        import moxing as mox
        mox.file.copy_parallel(opt.data_url, './train/')
    print('数据已加载')

    # NOTE(review): load_data is assumed to return a 2-D feature array and a
    # matching label array -- confirm against datasets.ad_ds.
    x, y = load_data(opt.data_path, opt.label_path)

    # Replace NaN/inf, standardize per feature, then clean up again: columns
    # with zero standard deviation divide by zero and yield NaN/inf, which the
    # second nan_to_num call maps back to 0.
    x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)
    mean = np.mean(x, axis=0)
    std = np.std(x, axis=0)
    x = (x - mean) / std
    x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)

    # The "estimator__" prefix routes each parameter to the SVC wrapped by
    # OneVsRestClassifier. "degree" only affects the "poly" kernel.
    parameters = {
        "estimator__C": [1, 2, 4],
        "estimator__kernel": ["poly", "rbf"],
        "estimator__degree": [1, 2, 3],
    }
    cv = StratifiedShuffleSplit(n_splits=10, test_size=0.2, random_state=42)
    model = OneVsRestClassifier(SVC(probability=True))
    # Cross-validated grid search; refit=True retrains the best parameter
    # combination on the full data set and exposes it as best_estimator_.
    grid = GridSearchCV(model, param_grid=parameters, cv=cv, refit=True,
                        return_train_score=True)
    grid.fit(x, y)

    os.makedirs(save_path, exist_ok=True)
    model_save_path = os.path.join(save_path, 'svm.pkl')
    # GridSearchCV fits clones, so `model` itself is still unfitted at this
    # point. Persist the refitted best estimator instead of the template.
    with open(model_save_path, 'wb') as output:
        pickle.dump(grid.best_estimator_, output)

    writer.close()

    if opt.train_url != '':
        import moxing as mox
        train_url = 'obs:' + opt.train_url.replace('s3:', '')
        print('Start to save model to', train_url, 'from', save_path)
        # Inference has to apply the same standardization, so the statistics
        # are uploaded next to the model.
        np.save('./mean.npy', mean)
        np.save('./std.npy', std)
        mox.file.copy('./mean.npy', train_url + '/mean.npy')
        mox.file.copy('./std.npy', train_url + '/std.npy')
        mox.file.copy_parallel(save_path, train_url)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/henry_fung/alzheimer_disease_classification.git
[email protected]:henry_fung/alzheimer_disease_classification.git
henry_fung
alzheimer_disease_classification
Alzheimer_Disease_Classification
master

搜索帮助