1 Star 0 Fork 0

BG/飞桨第一案例-手写数字识别

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
a.py 1.82 KB
一键复制 编辑 原始数据 按行查看 历史
gyluo 提交于 2023-08-14 15:38 . update a.py.
import paddle
import numpy as np
from PIL import Image
from paddle.vision.transforms import Normalize
import os
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 下载数据集并初始化 DataSet
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
# 模型组网并初始化网络
lenet = paddle.vision.models.LeNet(num_classes=10)
model = paddle.Model(lenet)
# 模型训练的配置准备,准备损失函数,优化器和评价指标
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy())
# # 模型训练
model.fit(train_dataset, epochs=1, batch_size=64, verbose=1)
# # 模型评估
# model.evaluate(test_dataset, batch_size=64, verbose=1)
# 保存模型
model.save('./output/mnist')
# 加载模型
model.load('output/mnist')
# 从测试集中取出一张图片
# img, label = test_dataset[0]
# 将图片shape从1*28*28变为1*1*28*28,增加一个batch维度,以匹配模型输入格式要求
# img_batch = np.expand_dims(img.astype('float32'), axis=0)
path = "E:\\code\\python\\1"
file_list = os.listdir(path)
for file_name in file_list:
print(file_name)
image = Image.open(os.path.join(path, file_name))
img = transform(image)
img_batch = np.expand_dims(img.astype('float32'), axis=0)
# 执行推理并打印结果,此处predict_batch返回的是一个list,取出其中数据获得预测结果
out = model.predict_batch(img_batch)[0]
pred_label = out.argmax()
print('pred label: {}'.format(pred_label))
# 可视化图片
# from matplotlib import pyplot as plt
# import pylab
# plt.imshow(img[0])
# pylab.show()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/bg_gyluo/scanhandover.git
[email protected]:bg_gyluo/scanhandover.git
bg_gyluo
scanhandover
飞桨第一案例-手写数字识别
master

搜索帮助