代码拉取完成,页面将自动刷新
import tensorflow as tf
import gradio as gr
import numpy as np
from PIL import Image
# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(_, _), (X_test, y_test) = mnist.load_data()
# 加载保存的模型
try:
model = tf.keras.models.load_model('best_model.h5')
except FileNotFoundError:
print("无法找到保存的模型文件")
exit()
except Exception as e:
print("加载模型时出错:", str(e))
exit()
# 定义预处理函数
def preprocess(image):
image = Image.fromarray(image) # 将数组转换为图像对象
image = image.resize((28, 28)).convert('L') # 调整图像大小并转换为灰度图像
image_array = np.array(image) # 将图像转换为NumPy数组
normalized_image = image_array / 255.0 # 对图像像素进行归一化
reshaped_image = normalized_image.reshape((1, 28, 28, 1)) # 调整图像形状以适应模型的输入
return reshaped_image
# 定义预测函数
def predict(image):
preprocessed_image = preprocess(image) # 预处理输入图像
predicted_digit = np.argmax(model.predict(preprocessed_image)) # 使用模型进行预测
return str(predicted_digit)
# 创建Gradio界面
iface = gr.Interface(fn=predict, inputs='sketchpad', outputs='label')
# 启动Gradio界面
iface.launch()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。