import os import joblib import gradio as gr import numpy as np from PIL import Image # 指定模型文件的路径 model_filename = r'D:/下载/optimal_knn_mnist_question/best_knn_model.pkl' # 检查模型文件是否存在 if not os.path.isfile(model_filename): print(f"Error: {model_filename} not found.") # 如果模型文件不存在,退出程序 exit() else: # 加载模型 best_knn_model = joblib.load(model_filename) # 定义预测函数,这个函数将用于Gradio接口进行预测 def predict_digit(drawing): if not drawing: return "Please draw a digit." # 将输入的图像转换为模型需要的格式 image_array = np.array(drawing, dtype=np.float32) # 直接使用drawing image_array = image_array.reshape(28, 28) # 调整数组形状 image_array = (image_array > 0).astype(int) # 将图像转换为二值图像 prediction = best_knn_model.predict(image_array.reshape(1, -1)) return prediction[0] # 创建Gradio接口,这个接口将用于用户输入和显示预测结果 iface = gr.Interface( fn=predict_digit, inputs=gr.Sketchpad(label="Draw a Digit", type="numpy"), outputs="label", title="Digit Prediction with KNN", description="Draw a digit in the box below to get a prediction." ) # 启动Gradio接口,用户可以通过这个接口进行交互 iface.launch()