1 Star 0 Fork 248

孔凡兵/OPTIMAL_KNN_MNIST_QUESTION_1

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Pinecone_train.py 3.96 KB
一键复制 编辑 原始数据 按行查看 历史
孔凡兵 提交于 2024-09-15 20:43 . added a file
import logging
import os
import time
from collections import Counter

import numpy as np
from pinecone import Pinecone, ServerlessSpec
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
# Configure logging: timestamped records at INFO level and above go to a local file.
logging.basicConfig(filename='pinecone_mnist.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Initialise the Pinecone client.
# SECURITY: the API key is a secret and must not live in source control, so it
# is read from the PINECONE_API_KEY environment variable. The key that used to
# be hard-coded here was published with the repository and should be revoked.
api_key = os.environ.get("PINECONE_API_KEY")
if not api_key:
    raise RuntimeError(
        "PINECONE_API_KEY is not set; export your Pinecone API key before running this script."
    )
pinecone = Pinecone(api_key=api_key)
index_name = "mnist-index"

# Start from a clean slate: drop the index if an earlier run left one behind.
existing_indexes = pinecone.list_indexes()
if any(existing['name'] == index_name for existing in existing_indexes):
    logging.info(f"索引 '{index_name}' 已存在,正在删除...")
    pinecone.delete_index(index_name)
    logging.info(f"索引 '{index_name}' 已成功删除。")

# Create a fresh serverless index. The dimension must equal the length of the
# vectors uploaded later in this script (64 PCA components), and the Euclidean
# metric makes Pinecone rank neighbours by straight-line distance.
logging.info(f"正在创建新索引 '{index_name}'...")
pinecone.create_index(
    name=index_name,
    dimension=64,
    metric="euclidean",
    spec=ServerlessSpec(
        cloud="aws",
        region="us-east-1",
    ),
)
logging.info(f"索引 '{index_name}' 创建成功。")

# Obtain a handle on the index for the upserts and queries that follow.
index = pinecone.Index(index_name)
logging.info(f"已成功连接到索引 '{index_name}'。")
# Load the 10-class handwritten-digits dataset that ships with scikit-learn.
digits = load_digits(n_class=10)
X, y = digits.data, digits.target

# Hold out 20% of the samples for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Standardise the features using statistics learned from the training split only.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# Fit a 64-component PCA on the training split and project both splits with it.
pca = PCA(n_components=64)
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)
# Data "augmentation" by replication
def data_augmentation(X, y, num_samples=100, random_state=None):
    """Replicate a dataset ``num_samples`` times and shuffle the copies.

    No new samples are synthesised: the result is ``num_samples`` stacked
    copies of ``(X, y)`` whose rows are then permuted together, so every
    feature row stays paired with its own label.

    Args:
        X: 2-D array-like of shape ``(n, d)`` holding the feature vectors.
        y: 1-D array-like of length ``n`` holding the labels.
        num_samples: Number of copies of the dataset to produce.
        random_state: Optional integer seed for a reproducible shuffle. When
            ``None`` the global NumPy random state is used.

    Returns:
        A tuple ``(X_augmented, y_augmented)`` with ``num_samples * n`` rows.
        Both arrays keep the dtype of their input, so integer labels stay
        integers instead of being coerced to float.

    Raises:
        ValueError: If ``num_samples`` is negative, ``X`` is not 2-D, or ``X``
            and ``y`` do not contain the same number of samples.
    """
    if num_samples < 0:
        raise ValueError("num_samples must be non-negative")

    X = np.asarray(X)
    y = np.asarray(y)
    if X.ndim != 2:
        raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
    if X.shape[0] != y.shape[0]:
        raise ValueError("X and y must contain the same number of samples")

    # np.tile stacks whole copies one after another along the sample axis.
    X_augmented = np.tile(X, (num_samples, 1))
    y_augmented = np.tile(y, num_samples)

    # One shared permutation keeps every row aligned with its label.
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    order = rng.permutation(X_augmented.shape[0])
    return X_augmented[order], y_augmented[order]
# Replicate the PCA-projected training set five times (shuffled) before indexing.
X_train_augmented, y_train_augmented = data_augmentation(X_train_pca, y_train, num_samples=5)

# Build one (id, values, metadata) record per training vector. The id is the
# row position as a string, and the class label travels as metadata so that a
# query can read it back from its nearest neighbours.
vectors = []
for i, (row, label) in enumerate(
    tqdm(
        zip(X_train_augmented, y_train_augmented),
        total=len(X_train_augmented),
        desc="上传数据进度",
    )
):
    vector_id = str(i)
    vector_values = row.tolist()
    metadata = {"label": int(label)}
    vectors.append((vector_id, vector_values, metadata))

# Upload the records to Pinecone in fixed-size batches.
batch_size = 1000
for start in tqdm(range(0, len(vectors), batch_size), desc="批量上传进度"):
    batch = vectors[start:start + batch_size]
    index.upsert(batch)

# Pinecone is eventually consistent: freshly upserted records may not be
# visible to queries straight away. Poll the index statistics until every
# record is counted, giving up after a deadline so the script cannot hang.
freshness_deadline = time.monotonic() + 120
while index.describe_index_stats()['total_vector_count'] < len(vectors):
    if time.monotonic() >= freshness_deadline:
        logging.warning(
            "Timed out waiting for all %d vectors to become queryable; continuing anyway.",
            len(vectors),
        )
        break
    time.sleep(2)
logging.info(f"成功创建索引,并上传了 {len(vectors)} 条数据。")
# Measure k-NN classification accuracy on the held-out test vectors.
top_k = 11  # number of nearest neighbours that vote on each prediction
correct = 0
total = len(X_test)
# Query ids continue after the training ids; they are used for log messages only.
id_offset = len(X_train_augmented)
for i, (query, expected) in enumerate(zip(tqdm(X_test_pca, desc="测试进度"), y_test)):
    vector_id = str(i + id_offset)
    vector_values = query.tolist()
    results = index.query(
        vector=vector_values,
        top_k=top_k,
        include_metadata=True,
    )
    # Debugging logs (emitted only if the logging level is lowered to DEBUG).
    logging.debug(f"查询向量ID: {vector_id}")
    logging.debug(f"查询结果: {results}")
    # A query without matches is skipped; it still counts towards `total`,
    # so it is scored as a wrong prediction.
    matches = results['matches']
    if not matches:
        logging.warning(f"没有找到匹配项,向量ID: {vector_id}")
        continue
    # Majority vote over the neighbours' labels. `matches` is non-empty here,
    # so `labels` is too and no separate emptiness check is needed.
    labels = [match['metadata']['label'] for match in matches]
    final_prediction = Counter(labels).most_common(1)[0][0]
    if final_prediction == expected:
        correct += 1
# Guard against an empty test set rather than dividing by zero.
accuracy = correct / total if total else 0.0
logging.info(f"当 k={top_k} 时,使用 Pinecone 的准确率为 {accuracy:.2%}。")
print(f"当 k={top_k} 时,使用 Pinecone 的准确率为 {accuracy:.2%}。")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/FanBing-Kong/optimal_knn_mnist_question_1.git
[email protected]:FanBing-Kong/optimal_knn_mnist_question_1.git
FanBing-Kong
optimal_knn_mnist_question_1
OPTIMAL_KNN_MNIST_QUESTION_1
main

搜索帮助