1 Star 0 Fork 248

低矮哦大/OPTIMAL_KNN_MNIST_QUESTION_1

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
pinecone_train.py 3.87 KB
一键复制 编辑 原始数据 按行查看 历史
李鹏飞 提交于 2024-09-18 23:45 . added a file
import logging
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from collections import Counter
from pinecone import Pinecone, ServerlessSpec
from tqdm import tqdm
import pandas as pd
# 设置日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# API密钥和索引名称
api_key = "e4df12f5-ebc2-4814-b493-3e166d6de302"
index_name = "mnist-index"
# 创建Pinecone客户端
pinecone = Pinecone(api_key=api_key)
# 检查索引是否存在,如果存在则删除
existing_indexes = pinecone.list_indexes()
if any(index['name'] == index_name for index in existing_indexes):
logger.info(f"索引 '{index_name}' 已存在,正在删除...")
pinecone.delete_index(index_name)
logger.info(f"索引 '{index_name}' 已成功删除。")
# 创建新索引
logger.info(f"正在创建新索引 '{index_name}'...")
pinecone.create_index(
name=index_name,
dimension=784, # MNIST 每个图像展平后是一个 784 维向量
metric="euclidean", # 使用欧氏距离
spec=ServerlessSpec(
cloud="aws",
region="us-east-1"
)
)
logger.info(f"索引 '{index_name}' 创建成功。")
# 加载MNIST数据集
logger.info("加载MNIST数据集...")
mnist = fetch_openml('mnist_784', version=1, parser='auto')
X, y = mnist.data, mnist.target
# 确保 X 是 NumPy 数组
X = X.values if isinstance(X, pd.DataFrame) else X
y = y.values if isinstance(y, pd.Series) else y
logger.info("MNIST 数据集加载完成。")
# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 初始化一个空列表,用于存储转换后的向量数据
vectors = []
# 遍历训练集的所有样本,将数据转换为 Pinecone 可接受的格式
logger.info("开始处理并上传训练数据...")
for i in tqdm(range(len(X_train)), desc="上传数据到Pinecone"):
# 使用样本的索引作为向量的唯一标识符
vector_id = str(i)
# 将 NumPy 数组转换为 Python 列表,并确保值为 float 类型
vector_values = [float(val) for val in X_train[i]]
# 创建元数据字典,包含该样本的真实标签
metadata = {"label": int(y_train[i])}
# 将转换后的数据(ID、向量值、元数据)作为元组添加到 vectors 列表中
vectors.append((vector_id, vector_values, metadata))
# 连接到索引
index = pinecone.Index(index_name)
logger.info(f"已成功连接到索引 '{index_name}'。")
# 定义批处理大小,每批最多包含 1000 个向量
batch_size = 100
# 使用步长为 batch_size 的 range 函数,实现分批处理
for i in tqdm(range(0, len(vectors), batch_size), desc="上传数据到Pinecone"):
# 从 vectors 列表中切片获取一批数据
batch = vectors[i:i + batch_size]
# 使用 upsert 方法将这批数据上传到 Pinecone 索引中
index.upsert(batch)
logger.info("成功创建索引,并上传了1437条数据")
# 使用测试集评估k=11时的准确率
correct_predictions = 0
logger.info("开始测试准确率...")
for i in tqdm(range(len(X_test)), desc="测试K=11准确率"):
query_data = [float(val) for val in X_test[i]] # 确保查询数据也是 float 类型
results = index.query(
vector=query_data,
top_k=11, # 返回距离最近的 11 个结果
include_metadata=True # 同时返回每个向量的元数据(包括标签)
)
labels = [int(match['metadata']['label']) for match in results['matches']]
if labels:
final_prediction = Counter(labels).most_common(1)[0][0]
else:
final_prediction = None
if final_prediction == int(y_test[i]):
correct_predictions += 1
accuracy = correct_predictions / len(X_test)
logger.info(f"当k=11时,使用Pinecone的准确率为:{accuracy:.4f}")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/low-short-big/optimal_knn_mnist_question_1.git
[email protected]:low-short-big/optimal_knn_mnist_question_1.git
low-short-big
optimal_knn_mnist_question_1
OPTIMAL_KNN_MNIST_QUESTION_1
main

搜索帮助