1 Star 2 Fork 1

wwfu/onnx-helper

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
onnx_run.py 2.53 KB
一键复制 编辑 原始数据 按行查看 历史
wwfu 提交于 2022-09-23 05:25 . first commit
#!/usr/bin/python
# -*- coding: utf-8 -*-
import tqdm
import onnx
import onnxruntime
import numpy as np
from pathlib import Path
def onnx_predict_for_dir(onnx_path, input_dir, output_dir, fp16=False):
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# 检查模型并获取session
#onnx_model = onnx.load(onnx_path)
#onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession(onnx_path)
input_shape = ort_session.get_inputs()[0].shape
input_type = np.float16 if fp16 else np.float32
input_shape[0] = -1
for input_bin in tqdm.tqdm(Path(input_dir).iterdir()):
# 准备输入数据
input_data = np.fromfile(input_bin.__str__(), dtype=input_type)
input_data = input_data.reshape(input_shape)
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
# 推理并保存结果
ort_outs = ort_session.run(None, ort_inputs)
output_bin = output_dir / input_bin.name
ort_outs[0].tofile(output_bin)
print("Done.")
def onnx_perdict_for_bin(onnx_path, input_bin, output_bin, fp16=False):
ort_session = onnxruntime.InferenceSession(onnx_path)
input_shape = ort_session.get_inputs()[0].shape
input_type = ort_session.get_inputs()[0].type
input_type = np.float16 if fp16 else np.float32
input_shape[0] = -1
# 准备输入数据
input_data = np.fromfile(input_bin.__str__(), dtype=input_type)
input_data = input_data.reshape(input_shape)
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
# 推理并保存结果
ort_outs = ort_session.run(None, ort_inputs)
ort_outs[0].tofile(output_bin)
print("Done.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='onnx infer.')
parser.add_argument('--onnx', type=str, metavar='PATH',
help='path to ONNX file.')
parser.add_argument('--input', type=str, metavar='DIR/PATH',
help='path to dataset')
parser.add_argument('--fp16', action='store_true', help='data type of input')
parser.add_argument('--output', type=str, metavar='DIR/PATH', help='')
args = parser.parse_args()
if Path(args.input).is_dir():
onnx_predict_for_dir(args.onnx, args.input, args.output, args.fp16)
else:
output_path = Path(args.output)
if output_path.is_dir():
output_path = output_path/Path(args.input).name.replace('.bin', '_output.bin')
onnx_perdict_for_bin(args.onnx, args.input, output_path, args.fp16)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/wwfu/onnx-helper.git
[email protected]:wwfu/onnx-helper.git
wwfu
onnx-helper
onnx-helper
master

搜索帮助