代码拉取完成,页面将自动刷新
同步操作将从 mynameisi/OPTIMAL_KNN_MNIST_QUESTION 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import os
from pinecone import Pinecone, ServerlessSpec
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
# 设置 Pinecone API 密钥
api_key="dfe20bc4-6606-43ef-9941-4d9aaeb10d48"
# 创建 Pinecone 实例
pc = Pinecone(api_key=api_key)
# 检查并创建 Pinecone 索引
index_name = "mnist-index"
if index_name not in pc.list_indexes().names():
pc.create_index(
name=index_name,
dimension=64, # 对于手写数字数据集,每个样本有64个特征
metric='euclidean', # 使用欧几里得距离进行相似度计算
spec=ServerlessSpec(cloud='aws', region='us-east-1') # 使用正确的区域和云提供商
)
index = pc.Index(index_name)
# 加载手写数字数据集
digits = load_digits()
X, y = digits.data, digits.target
# 将数据集划分为训练集(80%)和测试集(20%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 将训练数据上传到 Pinecone,分批次进行
batch_size = 1000
for i in range(0, len(X_train), batch_size):
batch_vectors = [{"id": str(i + j), "values": X_train[i + j].tolist(), "metadata": {"label": int(y_train[i + j])}}
for j in range(min(batch_size, len(X_train) - i))]
index.upsert(batch_vectors)
# 使用 Pinecone 查找最近邻居
y_pred = []
k = 11
print("测试过程进度:")
for x in tqdm(X_test, desc="正在测试"):
# 查询时需要确保每个查询向量的维度与创建索引时的维度一致
try:
query_response = index.query(vector=x.tolist(), top_k=k, include_metadata=True)
neighbors = query_response['matches']
# 检查查询返回的邻居数量
if not neighbors: # 如果未找到任何邻居
print("未找到任何邻居,查询向量:", x.tolist())
y_pred.append(-1) # 使用一个占位符值来表示未找到邻居的情况
continue
# 获取最近的k个邻居的标签
neighbor_labels = [match['metadata']['label'] for match in neighbors]
# 使用多数投票法确定最终标签
predicted_label = max(set(neighbor_labels), key=neighbor_labels.count)
y_pred.append(predicted_label)
except Exception as e:
print(f"查询时出错:{e}")
y_pred.append(-1) # 在出现异常时使用占位符值
# 计算准确率(过滤掉未找到邻居的情况)
y_pred_filtered = [pred for pred in y_pred if pred != -1]
y_test_filtered = [y_test[i] for i in range(len(y_pred)) if y_pred[i] != -1]
accuracy = accuracy_score(y_test_filtered, y_pred_filtered)
print(f"使用 Pinecone 和 KNN 模型的准确率为: {accuracy:.4f}")
# 删除 Pinecone 索引
#pc.delete_index(index_name)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。