代码拉取完成,页面将自动刷新
print("Initializing...")
# pytorch
import torch
import torchvision.transforms as transforms
# 网络结构
from models.net import Net
# WebUI
import gradio as gr
# 图像处理
from PIL import Image
# 命令行参数获取
import utils.parameters
# 读取命令行参数
params_parser = utils.parameters.get_run_args()
print("Loading model...")
# 决定运行设备
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# 实例化网络
model = Net()
# 加载模型参数
model.load_state_dict(
torch.load(
"./checkpoints/" + params_parser.model_name,
weights_only=True,
map_location=torch.device(device),
)
)
# 设置模型加载到设备
model = model.to(device)
model.eval() # 切换到推理模式
print("Model loaded.")
# 定义图像预处理函数
def preprocess_image(image_dict):
# 从字典中获取复合图像
image = image_dict["composite"]
# 创建一个与原图像大小相同的白色背景
white_background = Image.new("RGB", image.size, (255, 255, 255))
# 将原图像粘贴到白色背景上,使用Alpha通道作为掩码
white_background.paste(image, (0, 0), image)
# 将图像转换为二值图像
image = white_background.convert("1")
# 反转颜色,因为MNIST数据集是黑底白字
image = transforms.functional.invert(image)
# 缩放到28x28像素
image = image.resize((28, 28))
# 转换为张量
image = transforms.ToTensor()(image)
# 把第一个批次添加到设备
image = image.unsqueeze(0).to(device)
return image
# 定义预测函数
def predict_digit(image):
# 预处理图像
image = preprocess_image(image)
# 前向传播
with torch.no_grad():
output = model(image)
# 获取预测结果
prediction = torch.argmax(output, dim=1).item()
return prediction
print("Loading Gradio...")
# 创建Gradio界面
with gr.Blocks() as demo:
with gr.Row():
# 创建画布
sketchpad = gr.Sketchpad(
label="Draw a digit",
type="pil",
canvas_size=(280, 280),
)
# 创建输出文本框
output_text = gr.Textbox(label="Prediction")
# 创建按钮
predict_button = gr.Button("Predict")
# 设置按钮点击事件
predict_button.click(fn=predict_digit, inputs=sketchpad, outputs=output_text)
# 启动Web服务
demo.launch(share=params_parser.share)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。