1 Star 0 Fork 246

xfyzs/faiss_dog_cat_question_1

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 7.56 KB
一键复制 编辑 原始数据 按行查看 历史
xfyzs 提交于 2024-11-08 20:08 +08:00 . 1
import numpy as np
from util import createXY
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, VotingClassifier, BaggingClassifier, AdaBoostClassifier, GradientBoostingClassifier, StackingClassifier
from sklearn.svm import SVC
import logging
from tqdm import tqdm
import time
# Configure root logging. The format includes %(funcName)s so that every
# record shows the name of the function that emitted it, which is what the
# original comment promised but the original format string did not deliver.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(funcName)s - %(message)s',
)
def _fit_and_score(log_name, display_name, model, X_train, X_test, y_train, y_test):
    """Fit ``model``, score it on the test split, and log timings and accuracy.

    Args:
        log_name: Identifier used in the "training/evaluation finished" log lines.
        display_name: Human-readable name used in the accuracy log line.
        model: An unfitted estimator exposing ``fit`` and ``score``.
        X_train, X_test, y_train, y_test: The train/test split to use.

    Returns:
        The value returned by ``model.score(X_test, y_test)``.
    """
    start_time = time.time()
    model.fit(X_train, y_train)
    logging.info(f"{log_name}模型训练完成。用时{time.time() - start_time:.4f}秒")

    start_time = time.time()
    accuracy = model.score(X_test, y_test)
    logging.info(f"{log_name}模型评估完成。用时{time.time() - start_time:.4f}秒")
    logging.info(f'{display_name}模型准确率: {accuracy:.4f}')
    return accuracy


def main(train_folder=r"C:\Users\86187\Desktop\cat_dog_data\data\train"):
    """Train and compare several scikit-learn classifiers on one dataset.

    Loads features and labels through ``createXY``, splits them 75/25 with a
    fixed seed, then fits and scores each candidate model in turn, logging the
    wall-clock time and test accuracy of each. Finally it searches k in 1..5
    for the best K-nearest-neighbours classifier.

    Args:
        train_folder: Directory handed to ``createXY`` as ``train_folder``.
            Defaults to the path the script was originally written against.
    """
    # Fixed run parameters (only ``feature`` influences the computation; the
    # other two are reported in the log).
    mode = 'cpu'
    feature = 'flat'
    library = 'sklearn'
    logging.info(f"选择模式是 {mode.upper()}")
    logging.info(f"选择特征提取方法是 {feature.upper()}")
    logging.info(f"选择使用的库是 {library.upper()}")

    # Load and preprocess the data.
    # NOTE(review): createXY is a project helper; presumably it returns
    # per-sample feature vectors and their labels - confirm in util.py.
    X, y = createXY(train_folder=train_folder, dest_folder=".", method=feature)
    X = np.array(X).astype('float32')
    y = np.array(y)
    logging.info("数据加载和预处理完成。")

    # Hold out 25% of the samples for evaluation; seed fixed for a stable split.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2023)
    logging.info("数据集划分为训练集和测试集。")

    # (log name, display name, estimator) for every model to compare.
    candidates = [
        ('logistic_regression', '逻辑回归', LogisticRegression(max_iter=1000)),
        ('random forest', '随机森林', RandomForestClassifier(n_estimators=100)),
        ('svm', 'SVM', SVC()),
        # Hard voting: majority vote over predicted class labels.
        ('hard_voting', '硬投票', VotingClassifier(
            estimators=[
                ('rf', RandomForestClassifier(n_estimators=100)),
                ('svm', SVC()),
            ],
            voting='hard',
        )),
        # Soft voting averages predicted probabilities, so it must be
        # requested explicitly and SVC must be built with probability=True
        # to expose predict_proba.
        ('soft_voting', '软投票', VotingClassifier(
            estimators=[
                ('rf', RandomForestClassifier(n_estimators=100)),
                ('svm', SVC(probability=True)),
            ],
            voting='soft',
        )),
        # Bagging: each base estimator sees a bootstrap sample (with replacement).
        ('bagging', 'Bagging', BaggingClassifier(
            estimator=RandomForestClassifier(n_estimators=100),
            n_estimators=10,
        )),
        # Pasting: same meta-estimator, but subsets are drawn WITHOUT
        # replacement. max_samples must be below 1.0, otherwise every base
        # estimator would be trained on the identical full training set.
        ('pasting', 'Pasting', BaggingClassifier(
            estimator=RandomForestClassifier(n_estimators=100),
            n_estimators=10,
            bootstrap=False,
            max_samples=0.8,
        )),
        ('adaboost', 'AdaBoost', AdaBoostClassifier(
            estimator=RandomForestClassifier(n_estimators=100),
            n_estimators=10,
        )),
        ('gradient_boosting', 'Gradient Boosting', GradientBoostingClassifier(n_estimators=100)),
        ('stacking', 'Stacking', StackingClassifier(
            estimators=[
                ('rf', RandomForestClassifier(n_estimators=100)),
                ('svm', SVC()),
            ],
            final_estimator=LogisticRegression(),
        )),
    ]
    for log_name, display_name, model in candidates:
        _fit_and_score(log_name, display_name, model, X_train, X_test, y_train, y_test)

    # Track the best k and its accuracy across the KNN sweep.
    best_k = -1
    best_accuracy = 0.0
    k_values = range(1, 6)
    logging.info(f"使用的库为: {library.upper()}")

    # Fit and score one KNN model per k; ties keep the smaller k because the
    # comparison is strict.
    for k in tqdm(k_values, desc='寻找最佳k值'):
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(X_train, y_train)
        accuracy = knn.score(X_test, y_test)
        if accuracy > best_accuracy:
            best_k = k
            best_accuracy = accuracy

    # Report the winner.
    logging.info(f'最佳k值: {best_k}, 最高准确率: {best_accuracy}')
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/xfyzs/faiss_dog_cat_question_1.git
[email protected]:xfyzs/faiss_dog_cat_question_1.git
xfyzs
faiss_dog_cat_question_1
faiss_dog_cat_question_1
main

搜索帮助