1 Star 0 Fork 0

XY/LearnTheUseOfAI

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
main_run.py 2.37 KB
一键复制 编辑 原始数据 按行查看 历史
XY0797 提交于 2024-11-27 21:22 . add: 添加跨设备加载模型支持
print("Initializing...")
# pytorch
import torch
import torchvision.transforms as transforms
# 网络结构
from models.net import Net
# WebUI
import gradio as gr
# 图像处理
from PIL import Image
# 命令行参数获取
import utils.parameters
# 读取命令行参数
params_parser = utils.parameters.get_run_args()
print("Loading model...")
# 决定运行设备
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# 实例化网络
model = Net()
# 加载模型参数
model.load_state_dict(
torch.load(
"./checkpoints/" + params_parser.model_name,
weights_only=True,
map_location=torch.device(device),
)
)
# 设置模型加载到设备
model = model.to(device)
model.eval() # 切换到推理模式
print("Model loaded.")
# 定义图像预处理函数
def preprocess_image(image_dict):
# 从字典中获取复合图像
image = image_dict["composite"]
# 创建一个与原图像大小相同的白色背景
white_background = Image.new("RGB", image.size, (255, 255, 255))
# 将原图像粘贴到白色背景上,使用Alpha通道作为掩码
white_background.paste(image, (0, 0), image)
# 将图像转换为二值图像
image = white_background.convert("1")
# 反转颜色,因为MNIST数据集是黑底白字
image = transforms.functional.invert(image)
# 缩放到28x28像素
image = image.resize((28, 28))
# 转换为张量
image = transforms.ToTensor()(image)
# 把第一个批次添加到设备
image = image.unsqueeze(0).to(device)
return image
# 定义预测函数
def predict_digit(image):
# 预处理图像
image = preprocess_image(image)
# 前向传播
with torch.no_grad():
output = model(image)
# 获取预测结果
prediction = torch.argmax(output, dim=1).item()
return prediction
print("Loading Gradio...")
# 创建Gradio界面
with gr.Blocks() as demo:
with gr.Row():
# 创建画布
sketchpad = gr.Sketchpad(
label="Draw a digit",
type="pil",
canvas_size=(280, 280),
)
# 创建输出文本框
output_text = gr.Textbox(label="Prediction")
# 创建按钮
predict_button = gr.Button("Predict")
# 设置按钮点击事件
predict_button.click(fn=predict_digit, inputs=sketchpad, outputs=output_text)
# 启动Web服务
demo.launch(share=params_parser.share)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/eq_software/learn-the-use-of-ai.git
[email protected]:eq_software/learn-the-use-of-ai.git
eq_software
learn-the-use-of-ai
LearnTheUseOfAI
HandwrittenDigitRecognition

搜索帮助