1 Star 0 Fork 248

吼吼/OPTIMAL_KNN_MNIST_QUESTION

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
optimal_knn.py 2.74 KB
一键复制 编辑 原始数据 按行查看 历史
吼吼 提交于 2024-09-21 16:17 . added a file
import os
from pinecone import Pinecone, ServerlessSpec
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
# 设置 Pinecone API 密钥
api_key = "d86bcb40-a767-490b-b198-d5101fefbe7b"
# 创建 Pinecone 实例
pc = Pinecone(api_key=api_key)
# 检查并创建 Pinecone 索引
index_name = "mnist-index"
if index_name not in pc.list_indexes().names():
pc.create_index(
name=index_name,
dimension=64, # 对于手写数字数据集,每个样本有64个特征
metric='euclidean', # 使用欧几里得距离进行相似度计算
spec=ServerlessSpec(cloud='aws', region='us-east-1') # 使用正确的区域和云提供商
)
index = pc.Index(index_name)
# 加载手写数字数据集
digits = load_digits()
X, y = digits.data, digits.target
# 将数据集划分为训练集(80%)和测试集(20%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 将训练数据上传到 Pinecone,分批次进行
batch_size = 1000
for i in range(0, len(X_train), batch_size):
batch_vectors = [{"id": str(i + j), "values": X_train[i + j].tolist(), "metadata": {"label": int(y_train[i + j])}}
for j in range(min(batch_size, len(X_train) - i))]
index.upsert(batch_vectors)
# 使用 Pinecone 查找最近邻居
y_pred = []
k = 11
print("测试过程进度:")
for x in tqdm(X_test, desc="正在测试"):
# 查询时需要确保每个查询向量的维度与创建索引时的维度一致
try:
query_response = index.query(vector=x.tolist(), top_k=k, include_metadata=True)
neighbors = query_response['matches']
# 检查查询返回的邻居数量
if not neighbors: # 如果未找到任何邻居
print("未找到任何邻居,查询向量:", x.tolist())
y_pred.append(-1) # 使用一个占位符值来表示未找到邻居的情况
continue
# 获取最近的k个邻居的标签
neighbor_labels = [match['metadata']['label'] for match in neighbors]
# 使用多数投票法确定最终标签
predicted_label = max(set(neighbor_labels), key=neighbor_labels.count)
y_pred.append(predicted_label)
except Exception as e:
print(f"查询时出错:{e}")
y_pred.append(-1) # 在出现异常时使用占位符值
# 计算准确率(过滤掉未找到邻居的情况)
y_pred_filtered = [pred for pred in y_pred if pred != -1]
y_test_filtered = [y_test[i] for i in range(len(y_pred)) if y_pred[i] != -1]
accuracy = accuracy_score(y_test_filtered, y_pred_filtered)
print(f"使用 Pinecone 和 KNN 模型的准确率为: {accuracy:.4f}")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/XCJLXRH/optimal_knn_mnist_question.git
git@gitee.com:XCJLXRH/optimal_knn_mnist_question.git
XCJLXRH
optimal_knn_mnist_question
OPTIMAL_KNN_MNIST_QUESTION
main

搜索帮助