1 Star 0 Fork 246

荆传智/faiss_dog_cat_question

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
FaissKNeighbors.py 1.26 KB
一键复制 编辑 原始数据 按行查看 历史
荆传智 提交于 2024-10-13 21:35 +08:00 . added a file
import numpy as np
import faiss
class FaissKNeighbors:
def __init__(self, k=1, res=None):
self.index = None # FAISS索引,用于存储训练数据
self.y = None # 训练数据的标签
self.k = k # 最近邻个数
self.res = res # FAISS GPU资源对象
def fit(self, X, y):
self.index = faiss.IndexFlatL2(X.shape[1]) # 初始化FAISS索引,使用L2距离
# 如果提供了GPU资源,则将索引转移到GPU上
if self.res is not None:
self.index = faiss.index_cpu_to_gpu(self.res, 0, self.index)
self.index.add(X.astype(np.float32)) # 将训练数据添加到索引
self.y = y
def predict(self, X):
# 搜索X中每个向量的k个最近邻
distances, indices = self.index.search(X.astype(np.float32), self.k)
votes = self.y[indices] # 获取最近邻的标签
# 通过投票机制确定最终的预测标签
predictions = np.array([np.argmax(np.bincount(vote)) for vote in votes])
return predictions
def score(self, X, y):
predictions = self.predict(X) # 获取预测结果
accuracy = np.mean(predictions == y) # 计算准确率
return accuracy
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/jing-chuanzhi/faiss_dog_cat_question.git
[email protected]:jing-chuanzhi/faiss_dog_cat_question.git
jing-chuanzhi
faiss_dog_cat_question
faiss_dog_cat_question
main

搜索帮助