6 Star 29 Fork 4

andy-upp/tensor-calcu-lib

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
easynn_cpp.py 2.79 KB
一键复制 编辑 原始数据 按行查看 历史
andy-upp 提交于 2020-08-11 13:33 . 整理cpp测试文件,完善文件名
import numpy as np
import ctypes
# Load the compiled libeasynn shared library once, at import time; every
# wrapper below calls into this handle.
# NOTE(review): "./libeasynn.so" contains a slash, so the loader resolves it
# relative to the process's current working directory, not this file's
# directory — importing this module from another directory will raise OSError.
# NOTE(review): no argtypes/restype are declared for any libeasynn function in
# the visible code, so ctypes applies its defaults (return values read as C int,
# Python ints passed as C int). If create_program()/build() actually return
# pointers, the handles may be truncated on 64-bit platforms — confirm against
# the C header.
_libeasynn = ctypes.CDLL("./libeasynn.so")
def check_and_raise(ret):
    """Turn a libeasynn status code into an exception.

    Args:
        ret: status value returned by a libeasynn call. Zero is treated as
            success; any other value as failure.

    Raises:
        Exception: for any non-zero ``ret``; the message includes the code.
    """
    if ret == 0:
        return
    raise Exception("libeasynn error: code %d" % ret)
def to_float_or_ndarray(c_dim, c_shape, c_data):
    """Convert a result delivered through ctypes out-parameters into Python.

    Args:
        c_dim: ``ctypes.c_int`` holding the number of dimensions of the result.
        c_shape: ctypes pointer to at least ``c_dim.value`` extents; only read
            when ``c_dim.value`` is non-zero.
        c_data: ctypes pointer to ``c_double`` values, read in row-major (C)
            order.

    Returns:
        If ``c_dim.value`` is 0, the scalar ``c_data[0]`` as a Python float.
        Otherwise a new float64 ``numpy.ndarray`` with the given shape. The
        values are copied, so the array does not alias ``c_data``'s memory.
    """
    ndim = c_dim.value
    if ndim == 0:
        return c_data[0]
    # Slicing a ctypes pointer with an explicit stop yields a plain list of
    # the first ``ndim`` items, without a Python-level index loop.
    shape = c_shape[:ndim]
    count = 1
    for extent in shape:
        count *= extent
    flat = np.array(c_data[:count], dtype=np.float64)
    return flat.reshape(shape)
class Eval:
    """Callable front end for a libeasynn evaluation handle.

    Calling an instance registers its keyword inputs with the library, then
    runs ``execute`` on the evaluation and converts the result to Python.
    """

    def __init__(self, evaluation):
        # Opaque handle (as produced by _libeasynn.build() in Builder.build);
        # it is only ever passed back to the library, never inspected here.
        self.evaluation = evaluation

    @staticmethod
    def _prepare_kwarg(name, value):
        """Convert one keyword input into the form the C calls take.

        Returns:
            ``(key, payload)``. ``key`` is ``name.encode()``. ``payload`` is
            either a ``ctypes.c_double`` for scalar inputs, or a C-contiguous
            float64 ``numpy.ndarray`` for array-like inputs.

        Raises:
            TypeError: if ``value`` is neither a number nor array-like.
        """
        key = name.encode()
        # NumPy scalars such as np.int64 / np.float32 are not Python int/float
        # subclasses; treat them as plain numbers instead of routing them
        # through the ndarray branch.
        if isinstance(value, (int, float, np.integer, np.floating)):
            return key, ctypes.c_double(float(value))
        if hasattr(value, "shape"):
            # One vectorised conversion to contiguous float64, laid out in the
            # same C order that ``flatten()`` would produce.
            return key, np.asarray(value, dtype=np.float64, order="C")
        raise TypeError("%s: kwargs must be float or int or ndarray" % name)

    def __call__(self, **kwargs):
        """Send keyword inputs to the library, execute, and return the result.

        Args:
            **kwargs: input name -> value. Python ints/floats and NumPy
                integer/floating scalars are sent as one double. Any object
                with a ``shape`` attribute is sent as a float64 array together
                with its shape.

        Returns:
            A Python float when the library reports a 0-dimensional result,
            otherwise a float64 ``numpy.ndarray``.

        Raises:
            TypeError: if any keyword value has an unsupported type. This is
                raised before anything is sent to the library.
            Exception: if ``execute`` returns a non-zero status.
        """
        # Convert everything first so one bad input cannot leave the
        # evaluation with only some of this call's kwargs applied.
        prepared = [self._prepare_kwarg(name, value)
                    for name, value in kwargs.items()]
        for key, payload in prepared:
            if isinstance(payload, np.ndarray):
                _libeasynn.add_kwargs_ndarray(
                    self.evaluation, key, payload.ndim, payload.ctypes.shape,
                    payload.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
            else:
                _libeasynn.add_kwargs_double(self.evaluation, key, payload)
        # Out-parameters filled in by execute().
        c_dim = ctypes.c_int()
        c_shape = ctypes.POINTER(ctypes.c_size_t)()
        c_data = ctypes.POINTER(ctypes.c_double)()
        check_and_raise(_libeasynn.execute(
            self.evaluation,
            ctypes.byref(c_dim), ctypes.byref(c_shape), ctypes.byref(c_data)))
        return to_float_or_ndarray(c_dim, c_shape, c_data)
class Builder:
    """Incrementally sends an expression program to libeasynn.

    Call ``append`` once per expression, then ``build`` to obtain an ``Eval``
    wrapping the library's evaluation handle.
    """

    def __init__(self):
        # Opaque program handle from the library; only ever passed back to it.
        self.program = _libeasynn.create_program()

    @staticmethod
    def _prepare_param(expr, name, value):
        """Convert one operator parameter into the form the C calls take.

        Returns:
            ``(key, payload)``. ``key`` is ``name.encode()``. ``payload`` is
            either a ``ctypes.c_double`` for scalar values, or a C-contiguous
            float64 ``numpy.ndarray`` for array-like values.

        Raises:
            TypeError: if ``value`` is neither a number nor array-like; the
                message names both ``expr`` and the parameter.
        """
        key = name.encode()
        # NumPy scalars such as np.int64 / np.float32 are not Python int/float
        # subclasses; treat them as plain numbers.
        if isinstance(value, (int, float, np.integer, np.floating)):
            return key, ctypes.c_double(float(value))
        if hasattr(value, "shape"):
            # Single vectorised copy in C order (same layout as flatten()).
            return key, np.asarray(value, dtype=np.float64, order="C")
        raise TypeError(
            "%s: op params must be float or int or ndarray: %s" % (expr, name))

    def append(self, expr):
        """Register ``expr`` and its operator parameters with the program.

        Args:
            expr: expression object. This method reads ``expr.id`` and
                ``expr.inputs`` (each item exposing an int-like ``id``). It
                also reads ``expr.op``, which must provide ``name`` and
                ``op_type`` strings and a ``parameters`` mapping.

        Raises:
            TypeError: if a parameter value is neither a number nor
                array-like. This is raised before anything is sent to the
                library.
            Exception: if the library returns a non-zero status while adding a
                parameter.
        """
        op = expr.op
        # Validate/convert all parameters first so an unsupported value cannot
        # leave a half-registered expression in the program.
        prepared = [self._prepare_param(expr, name, value)
                    for name, value in op.parameters.items()]
        input_ids = [inp.id for inp in expr.inputs]
        count = len(input_ids)
        _libeasynn.append_expression(
            self.program, expr.id,
            op.name.encode(), op.op_type.encode(),
            (ctypes.c_int * count)(*input_ids), count)
        for key, payload in prepared:
            if isinstance(payload, np.ndarray):
                status = _libeasynn.add_op_param_ndarray(
                    self.program, key, payload.ndim, payload.ctypes.shape,
                    payload.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
            else:
                status = _libeasynn.add_op_param_double(
                    self.program, key, payload)
            check_and_raise(status)

    def build(self):
        """Finalize the program and wrap the returned handle in an ``Eval``."""
        return Eval(_libeasynn.build(self.program))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
C++
1
https://gitee.com/andy-upp/tensor-calcu-lib.git
[email protected]:andy-upp/tensor-calcu-lib.git
andy-upp
tensor-calcu-lib
tensor-calcu-lib
master

搜索帮助