1 Star 0 Fork 248

发故宫/OPTIMAL_KNN_MNIST_QUESTION_2

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
pinecone_train.py 3.59 KB
一键复制 编辑 原始数据 按行查看 历史
发故宫 提交于 2024-09-14 15:37 . added a file
"""参考使用Pinecone进行数字判断的实例撰写代码实现
- 用80%的mnist数据创建Pinecone的索引
- 用20%的数据测试当k=11时准确率
- 最终用logging打印:
成功创建索引,并上传了1437条数据
当k=11是,使用Pinecone的准确率
上传数据和测试k的准确率的时候都要有进度条
用logging打印的信息需要有日期"""
from pinecone import Pinecone, ServerlessSpec
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
# 设置 Pinecone API 密钥
api_key = "0faf78dd-2331-43ae-b772-0f2b5986217d"
# 创建 Pinecone 实例
pc = Pinecone(api_key=api_key)
# 创建索引
index_name = "mnist-index"
# 获取现有索引列表
existing_indexes = pc.list_indexes()
# 检查索引是否存在,如果存在就删除
if any(index['name'] == index_name for index in existing_indexes):
print(f"索引 '{index_name}' 已存在,正在删除...")
pc.delete_index(index_name)
print(f"索引 '{index_name}' 已成功删除。")
else:
print(f"索引 '{index_name}' 不存在,将创建新索引。")
# 检查并创建 Pinecone 索引
if index_name not in pc.list_indexes().names():
pc.create_index(
name=index_name,
dimension=64, # 对于手写数字数据集,每个样本有64个特征
metric='euclidean', # 使用欧几里得距离进行相似度计算
spec=ServerlessSpec(cloud='aws', region='us-east-1') # 使用正确的区域和云提供商
)
index = pc.Index(index_name)
# 加载手写数字数据集
digits = load_digits()
X, y = digits.data, digits.target
# 将数据集划分为训练集(80%)和测试集(20%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 将训练数据上传到 Pinecone,分批次进行
batch_size = 1000
for i in range(0, len(X_train), batch_size):
batch_vectors = [{"id": str(i + j), "values": X_train[i + j].tolist(), "metadata": {"label": int(y_train[i + j])}}
for j in range(min(batch_size, len(X_train) - i))]
index.upsert(batch_vectors)
# 使用 Pinecone 查找最近邻居
y_pred = []
k = 11
print("测试过程进度:")
for x in tqdm(X_test, desc="正在测试"):
# 查询时需要确保每个查询向量的维度与创建索引时的维度一致
try:
query_response = index.query(vector=x.tolist(), top_k=k, include_metadata=True)
neighbors = query_response['matches']
# 检查查询返回的邻居数量
if not neighbors: # 如果未找到任何邻居
print("未找到任何邻居,查询向量:", x.tolist())
y_pred.append(-1) # 使用一个占位符值来表示未找到邻居的情况
continue
# 获取最近的k个邻居的标签
neighbor_labels = [match['metadata']['label'] for match in neighbors]
# 使用多数投票法确定最终标签
predicted_label = max(set(neighbor_labels), key=neighbor_labels.count)
y_pred.append(predicted_label)
except Exception as e:
print(f"查询时出错:{e}")
y_pred.append(-1) # 在出现异常时使用占位符值
# 计算准确率(过滤掉未找到邻居的情况)
y_pred_filtered = [pred for pred in y_pred if pred != -1]
y_test_filtered = [y_test[i] for i in range(len(y_pred)) if y_pred[i] != -1]
accuracy = accuracy_score(y_test_filtered, y_pred_filtered)
print(f"使用 Pinecone 和 KNN 模型的准确率为: {accuracy:.4f}")
# 删除 Pinecone 索引
#pc.delete_index(index_name)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/fa-forbidden-city/optimal_knn_mnist_question_1.git
[email protected]:fa-forbidden-city/optimal_knn_mnist_question_1.git
fa-forbidden-city
optimal_knn_mnist_question_1
OPTIMAL_KNN_MNIST_QUESTION_2
main

搜索帮助