1 Star 0 Fork 248

xfyzs/OPTIMAL_KNN_MNIST_QUESTION_1

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
pinecone_train.py 3.00 KB
一键复制 编辑 原始数据 按行查看 历史
xfyzs 提交于 2024-09-23 13:38 . 9
# 导入必要的库和模块
from pinecone import Pinecone, ServerlessSpec
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import numpy as np
from collections import Counter
import logging
from tqdm import tqdm
import matplotlib.pyplot as plt
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Pinecone API key 和配置
api_key = "970785b0-8983-40e6-acda-afd852f77db0 "
pinecone = Pinecone(api_key=api_key)
# 索引名称
index_name = "mnist-index"
# 获取现有索引列表
existing_indexes = pinecone.list_indexes()
# 检查索引是否存在,如果存在就删除
if any(index['name'] == index_name for index in existing_indexes):
logging.info(f"索引 '{index_name}' 已存在,正在删除...")
pinecone.delete_index(index_name)
logging.info(f"索引 '{index_name}' 已成功删除。")
else:
logging.info(f"索引 '{index_name}' 不存在,将创建新索引。")
# 创建新索引
logging.info(f"正在创建新索引 '{index_name}'...")
pinecone.create_index(
name=index_name,
dimension=64, # MNIST 每个图像展平后是一个 64 维向量
metric="euclidean", # 使用欧氏距离
spec=ServerlessSpec(
cloud="aws",
region="us-east-1"
)
)
logging.info(f"索引 '{index_name}' 创建成功。")
# 连接到索引
index = pinecone.Index(index_name)
logging.info(f"已成功连接到索引 '{index_name}'。")
# 加载 MNIST 数据集
digits = load_digits(n_class=10)
X = digits.data
y = digits.target
# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 初始化一个空列表,用于存储转换后的向量数据
vectors = []
# 使用 tqdm 包装 range 函数以显示进度条
for i in tqdm(range(len(X_train)), desc="Preparing vectors"):
vector_id = str(i)
vector_values = X_train[i].tolist()
metadata = {"label": int(y_train[i])}
vectors.append((vector_id, vector_values, metadata))
# 定义批处理大小,每批最多包含 1000 个向量
batch_size = 1000
# 使用 tqdm 包装 range 函数以显示进度条
for i in tqdm(range(0, len(vectors), batch_size), desc="Uploading vectors"):
batch = vectors[i:i + batch_size]
index.upsert(batch)
logging.info(f"成功创建索引,并上传了 {len(X_train)} 条数据。")
# 测试准确率
correct_predictions = 0
for x_test, y_test_label in tqdm(zip(X_test, y_test), total=len(X_test), desc="Testing accuracy"):
query_data = x_test.tolist()
results = index.query(vector=query_data, top_k=11, include_metadata=True)
labels = [match['metadata']['label'] for match in results['matches']]
if labels:
final_prediction = Counter(labels).most_common(1)[0][0]
else:
final_prediction = None
if final_prediction == y_test_label:
correct_predictions += 1
accuracy = correct_predictions / len(X_test)
logging.info(f"当k=11时,使用Pinecone的准确率为: {accuracy:.4f}")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/xfyzs/optimal_knn_mnist_question_1.git
[email protected]:xfyzs/optimal_knn_mnist_question_1.git
xfyzs
optimal_knn_mnist_question_1
OPTIMAL_KNN_MNIST_QUESTION_1
main

搜索帮助