import os
import joblib
import gradio as gr
import numpy as np
from PIL import Image

# 指定模型文件的路径
model_filename = r'D:/下载/optimal_knn_mnist_question/best_knn_model.pkl'

# 检查模型文件是否存在
if not os.path.isfile(model_filename):
    print(f"Error: {model_filename} not found.")
    # 如果模型文件不存在,退出程序
    exit()
else:
    # 加载模型
    best_knn_model = joblib.load(model_filename)


# 定义预测函数,这个函数将用于Gradio接口进行预测
def predict_digit(drawing):
    if not drawing:
        return "Please draw a digit."

    # 将输入的图像转换为模型需要的格式
    image_array = np.array(drawing, dtype=np.float32)  # 直接使用drawing
    image_array = image_array.reshape(28, 28)  # 调整数组形状
    image_array = (image_array > 0).astype(int)  # 将图像转换为二值图像
    prediction = best_knn_model.predict(image_array.reshape(1, -1))
    return prediction[0]


# 创建Gradio接口,这个接口将用于用户输入和显示预测结果
iface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(label="Draw a Digit", type="numpy"),
    outputs="label",
    title="Digit Prediction with KNN",
    description="Draw a digit in the box below to get a prediction."
)

# 启动Gradio接口,用户可以通过这个接口进行交互
iface.launch()