import torch
import argparse
from tensorboardX import SummaryWriter
from datasets.ad_ds import AD_Dataset, load_data
import os
from sklearn.model_selection import StratifiedKFold
import warnings
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedShuffleSplit  # stratified shuffle-split cross-validation
from sklearn.model_selection import GridSearchCV
import pickle
# import joblib
warnings.filterwarnings("ignore")
if __name__ == "__main__":
# python -m visdom.server
# mp.set_start_method('spawn')
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=150, help="epoch")
parser.add_argument("--batch_size", type=int, default=128, help="batch size")
parser.add_argument("--learning_rate", type=float, default=0.001, help="learning_rate")
parser.add_argument("--log_path", type=str, default='log/tensorboard/',
help="log_path")
parser.add_argument("--data_path", type=str, default='./train/train',
help="data_path")
parser.add_argument("--label_path", type=str, default=r'./train/train_open.csv',
help="label_path")
parser.add_argument("--data_url", type=str, default='',
help="data_url")
parser.add_argument("--train_url", type=str, default='',
help="train_url")
parser.add_argument("--log_url", type=str, default='',
help="log_url")
parser.add_argument("--init_method", type=str, default='',
help="init_method")
# parser.add_argument("--save_name", type=str, default='dnn0_5layers_focal',
# help="save_name")
# parser.add_argument("--save_name", type=str, default='dnn_residual_focal_2',
# help="save_name")
parser.add_argument("--save_name", type=str, default='lgbm_0721',
help="save_name")
parser.add_argument("--app_url", type=str, default='',
help="save_name")
parser.add_argument("--boot_file", type=str, default='',
help="save_name")
parser.add_argument("--log_file", type=str, default='',
help="save_name")
parser.add_argument("--num_gpus", type=int, default=1,
help="num_gpus")
# parser.add_argument("--model", type=str, default='cbam18',
# help="cbam18,resnet18,effnetb4,ca18,cbam34,cbam50")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
opt = parser.parse_args()
print(str(opt))
# os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
BATCH_SIZE = opt.batch_size
EPOCH = opt.epoch
learning_rate = opt.learning_rate
pretrain_w_path = opt.pretrain_weight_path if 'pretrain_weight_path' in opt else ''
n_samples = BATCH_SIZE * 20
writer = SummaryWriter(os.path.join(opt.log_path, opt.save_name), comment=opt.save_name,
flush_secs=2)
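    # tensorboardX writes standard TensorBoard event files, so anything logged through
    # `writer` can be viewed with: tensorboard --logdir log/tensorboard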
save_path = os.path.join('save', opt.save_name)
if not os.path.exists(opt.data_path):
import moxing as mox
mox.file.copy_parallel(opt.data_url, './train/')
        print('Data loaded')
x, y = load_data(opt.data_path, opt.label_path)
x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)
mean = np.mean(x, axis=0)
std = np.std(x, axis=0)
x = (x - mean) / std
x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)
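    # The train-set mean/std computed above are saved as mean.npy / std.npy further down,
    # so inference code can standardize new samples with the same statistics, e.g.
    # (x_test is illustrative):
    #   x_test = np.nan_to_num((x_test - mean) / std, nan=0.0, posinf=0, neginf=0)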
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=2021).split(x, y)
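    # Note: these stratified folds are only consumed by the commented-out per-fold
    # LightGBM loop below; the SVM grid search uses its own StratifiedShuffleSplit.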
loss_list = []
max_auc_list = []
chkp_list = []
parameters = {
"estimator__C": [1, 2, 4],
"estimator__kernel": ["poly", "rbf"],
"estimator__degree": [1, 2, 3],
}
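    # The "estimator__" prefix routes each grid value through OneVsRestClassifier to the
    # wrapped SVC, i.e. the search tries SVC(C=..., kernel=..., degree=...) per class.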
cv = StratifiedShuffleSplit(n_splits=10, test_size=0.2, random_state=42)
model = OneVsRestClassifier(SVC(probability=True))
    grid = GridSearchCV(model, param_grid=parameters, cv=cv, refit=True,
                        return_train_score=True)  # cross-validated grid search
grid.fit(x, y)
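    # The fitted search object exposes the winning configuration, e.g.:
    # print('best params:', grid.best_params_)
    # print('best CV score:', grid.best_score_)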
save_name = os.path.join(save_path)
if not os.path.exists(save_name):
os.makedirs(save_name)
model_save_path = os.path.join(save_name, 'svm.pkl')
output = open(model_save_path, 'wb')
    # GridSearchCV clones its estimator, so `model` itself is never fitted; with
    # refit=True the refitted winner is available as grid.best_estimator_.
    pickle.dump(grid.best_estimator_, output)
output.close()
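    # Rough sketch of loading the pickled model at inference time
    # (x_test is illustrative; apply the saved mean/std first):
    # with open(model_save_path, 'rb') as f:
    #     clf = pickle.load(f)
    # probs = clf.predict_proba(x_test)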
# for fold, (trn_idx, val_idx) in enumerate(folds):
# print('------------------Fold %i--------------------' % fold)
# model = lgb.LGBMClassifier()
# model.fit(x[trn_idx], y[trn_idx])
# result = model.predict(test_data.drop(['version'], axis=1))
if opt.train_url != '':
import moxing as mox
# from deep_moxing.model_analysis.api import analyse, tmp_save
# model_path = 'obs://ad-competiton/my_baseline/model'
train_url = 'obs:' + opt.train_url.replace('s3:', '')
data_url = 'obs:' + opt.data_url.replace('s3:', '')
log_url = 'obs://' + opt.log_url
print('Start to save model to', train_url, 'from', save_path)
np.save('./mean.npy', mean)
np.save('./std.npy', std)
mox.file.copy('./mean.npy', train_url + '/mean.npy')
mox.file.copy('./std.npy', train_url + '/std.npy')
mox.file.copy_parallel(save_path, train_url)
# mox.file.copy_parallel('./log',log_url)