Henry_Fung/Alzheimer_Disease_Classification
https://gitee.com/henry_fung/alzheimer_disease_classification.git

train_model.py (6.29 KB)
fengjiahao committed on 2021-08-10 18:33 · modify

Note: this repository does not declare an open-source license file (LICENSE); check the project description and the code's upstream dependencies before use.
import torch
from torch.utils.data import DataLoader
import os
import argparse
import torch_optimizer as optim
from tensorboardX import SummaryWriter
from utils.trainer import Trainer
import torch.nn as nn
from datasets.ad_ds import AD_Dataset, load_data
from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.model_selection import StratifiedKFold
import time
import copy
import warnings
import numpy as np
from models.loss.focal_loss import FocalLoss
from models.loss.contrastive_loss2 import Contrastive_Loss
warnings.filterwarnings("ignore")
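# Trains a fully connected DNN (models.dnn0.DNN) for Alzheimer's disease classification with
# 5-fold stratified cross-validation, focal loss, RAdam and TensorBoard logging; the moxing
# branches copy data and checkpoints to/from OBS when the script runs on Huawei ModelArts.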
if __name__ == "__main__":
    # python -m visdom.server
    # mp.set_start_method('spawn')
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=150, help="epoch")
    parser.add_argument("--batch_size", type=int, default=128, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="learning_rate")
    parser.add_argument("--log_path", type=str, default='log/tensorboard/',
                        help="log_path")
    parser.add_argument("--data_path", type=str, default='./train/train',
                        help="data_path")
    parser.add_argument("--label_path", type=str, default=r'./train/train_open.csv',
                        help="label_path")
    parser.add_argument("--data_url", type=str, default='',
                        help="data_url")
    parser.add_argument("--train_url", type=str, default='',
                        help="train_url")
    parser.add_argument("--log_url", type=str, default='',
                        help="log_url")
    parser.add_argument("--init_method", type=str, default='',
                        help="init_method")
    parser.add_argument("--save_name", type=str, default='dnn0_5layers_focal_1',
                        help="save_name")
    parser.add_argument("--num_gpus", type=int, default=1,
                        help="num_gpus")
    # parser.add_argument("--save_name", type=str, default='dnn_residual_focal_2',
    #                     help="save_name")
    # parser.add_argument("--save_name", type=str, default='dnn1_5_layers_focal_06242216',
    #                     help="save_name")
    # parser.add_argument("--model", type=str, default='cbam18',
    #                     help="cbam18,resnet18,effnetb4,ca18,cbam34,cbam50")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    opt = parser.parse_args()
    print(str(opt))
    # os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
    BATCH_SIZE = opt.batch_size
    EPOCH = opt.epoch
    learning_rate = opt.learning_rate
    # no --pretrain_weight_path argument is defined above, so this always falls back to ''
    pretrain_w_path = opt.pretrain_weight_path if 'pretrain_weight_path' in opt else ''
    n_samples = BATCH_SIZE * 20  # only used by the (commented-out) WeightedRandomSampler below
    writer = SummaryWriter(os.path.join(opt.log_path, opt.save_name), comment=opt.save_name,
                           flush_secs=2)
    save_path = os.path.join('save', opt.save_name)
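    # TensorBoard logs go to <log_path>/<save_name>; per-fold checkpoints are written
    # under save/<save_name>/f<fold> further down.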
    if not os.path.exists(opt.data_path):
        # running on ModelArts: pull the training data down from OBS first
        import moxing as mox
        mox.file.copy_parallel(opt.data_url, './train/')
        print('Data loaded')
    x, y = load_data(opt.data_path, opt.label_path)
    x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)
    # per-feature z-score normalization; nan_to_num again cleans up zero-variance columns
    mean = np.mean(x, axis=0)
    std = np.std(x, axis=0)
    x = (x - mean) / std
    x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)
    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=2021).split(x, y)
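    # mean/std computed here are saved as mean.npy / std.npy at the end of the run,
    # presumably so the same normalization can be applied to the test data at inference time.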
    from models.dnn0 import DNN
    # from models.dnn_1 import DNN
    # from models.dnn_residual import DNN
    init_model = DNN(28169, 4096, 512, 3, dropout_p=0.4)
    # init_model = DNN(28169)
    # sampler = WeightedRandomSampler(weights=train_data.sample_weight, num_samples=n_samples,
    #                                 replacement=True)
    loss_list = []
    max_auc_list = []
    chkp_list = []
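    # 5-fold stratified CV: each fold trains a fresh deepcopy of init_model so weights do not
    # leak across folds; Trainer (utils.trainer) is expected to checkpoint into save_name and
    # return the fold's minimum validation loss and maximum validation AUC.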
    for fold, (trn_idx, val_idx) in enumerate(folds):
        print('------------------Fold %i--------------------' % fold)
        train_data = AD_Dataset(x, y, trn_idx, device)
        val_data = AD_Dataset(x, y, val_idx, device)
        # load the fold's train/validation splits with DataLoader
        train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8,
                                       pin_memory=True)
        val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8,
                                     pin_memory=True)
        model = copy.deepcopy(init_model)
        model = nn.DataParallel(model)
        model = model.to(device)
        # note: the learning rate is hard-coded to 1e-3 here; opt.learning_rate is not used
        optimizer = optim.RAdam(
            model.parameters(),
            lr=1e-3,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0,
        )
        # criterion = nn.CrossEntropyLoss()
        criterion = FocalLoss(gamma=2)
        # criterion = Contrastive_Loss(smoothing_value=0.1)
        trainer = Trainer(model, optimizer, criterion, train_data_loader, val_data_loader, device, epoch=EPOCH)
        save_name = os.path.join(save_path, 'f%i' % fold)
        if not os.path.exists(save_name):
            os.makedirs(save_name)
        min_val_loss, max_val_auc = trainer.train(save_name, fold)
        # print('Fold' + str(fold), min_val_loss)
        # print('Fold' + str(fold), max_val_auc)
        # max_auc_list.append(max_val_auc)
        # chkp_list.append(save_name)
    if opt.train_url != '':
        if '/home/' not in opt.train_url:
            # ModelArts run: push normalization statistics, checkpoints and logs back to OBS
            import moxing as mox
            # from deep_moxing.model_analysis.api import analyse, tmp_save
            # model_path = 'obs://ad-competiton/my_baseline/model'
            train_url = 'obs:' + opt.train_url.replace('s3:', '')
            data_url = 'obs:' + opt.data_url.replace('s3:', '')
            log_url = 'obs://' + opt.log_url
            print('Start to save model to', train_url, 'from', save_path)
            np.save('./mean.npy', mean)
            np.save('./std.npy', std)
            mox.file.copy('./mean.npy', train_url + '/mean.npy')
            mox.file.copy('./std.npy', train_url + '/std.npy')
            mox.file.copy_parallel(save_path, train_url)
            mox.file.copy_parallel('./log', log_url)
        else:
            # local run: train_url is a directory on disk
            from shutil import copytree
            np.save(os.path.join(opt.train_url, 'mean.npy'), mean)
            np.save(os.path.join(opt.train_url, 'std.npy'), std)
            # copyfile cannot copy a directory, so copy the whole checkpoint tree instead
            copytree(save_path, os.path.join(opt.train_url, opt.save_name))