代码拉取完成,页面将自动刷新
import numpy as np
import ctypes
_libeasynn = ctypes.CDLL("./libeasynn.so")

# Declare foreign-function prototypes.  Without them ctypes assumes every
# function returns a C ``int`` and passes Python ints as C ``int``, which
# truncates 64-bit pointer values.  The program/evaluation handles returned by
# ``create_program`` and ``build`` are only ever passed back opaquely as the
# first argument of the other calls, so they are declared ``void *``
# (presumably pointers on the C side -- TODO confirm against the C header).
# The remaining argument types mirror how this module calls each function.
_libeasynn.create_program.argtypes = []
_libeasynn.create_program.restype = ctypes.c_void_p

_libeasynn.append_expression.argtypes = [
    ctypes.c_void_p,               # program handle
    ctypes.c_int,                  # expression id
    ctypes.c_char_p,               # op name (encoded str)
    ctypes.c_char_p,               # op type (encoded str)
    ctypes.POINTER(ctypes.c_int),  # ids of the input expressions
    ctypes.c_int,                  # number of inputs
]

_libeasynn.add_op_param_double.argtypes = [
    ctypes.c_void_p, ctypes.c_char_p, ctypes.c_double]

# The shape argument is declared ``void *`` so it accepts numpy's
# ``ndarray.ctypes.shape`` (a c_intp array, or None for 0-d arrays).
_libeasynn.add_op_param_ndarray.argtypes = [
    ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int,
    ctypes.c_void_p, ctypes.POINTER(ctypes.c_double)]

_libeasynn.build.argtypes = [ctypes.c_void_p]
_libeasynn.build.restype = ctypes.c_void_p

_libeasynn.add_kwargs_double.argtypes = [
    ctypes.c_void_p, ctypes.c_char_p, ctypes.c_double]

_libeasynn.add_kwargs_ndarray.argtypes = [
    ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int,
    ctypes.c_void_p, ctypes.POINTER(ctypes.c_double)]

_libeasynn.execute.argtypes = [
    ctypes.c_void_p,                                   # evaluation handle
    ctypes.POINTER(ctypes.c_int),                      # out: ndim
    ctypes.POINTER(ctypes.POINTER(ctypes.c_size_t)),   # out: shape
    ctypes.POINTER(ctypes.POINTER(ctypes.c_double)),   # out: data
]
def check_and_raise(ret):
    """Translate a libeasynn status code into a Python exception.

    Args:
        ret: Integer status returned by a libeasynn call; 0 means success.

    Raises:
        Exception: If ``ret`` is non-zero; the message embeds the code.
    """
    if ret == 0:
        return
    message = "libeasynn error: code %d" % ret
    raise Exception(message)
def to_float_or_ndarray(c_dim, c_shape, c_data):
    """Convert a result returned through ctypes out-parameters to Python.

    Args:
        c_dim: ``ctypes.c_int`` holding the number of dimensions.
        c_shape: ctypes ``size_t`` pointer (or array) with ``c_dim.value``
            extents.
        c_data: ctypes ``double`` pointer (or array) with the elements,
            interpreted in row-major (C) order.

    Returns:
        A Python ``float`` when ``c_dim`` is 0, otherwise a float64
        ``numpy.ndarray`` of the given shape.  The array is a copy and does
        not alias the memory behind ``c_data``.
    """
    ndim = c_dim.value
    if ndim == 0:
        return c_data[0]
    # Slicing a ctypes pointer copies the elements into a Python list in a
    # single C-level call (pointer slices need an explicit stop).
    shape = c_shape[:ndim]
    count = 1
    for extent in shape:
        count *= extent
    return np.array(c_data[:count], dtype=np.float64).reshape(shape)
class Eval:
    """Callable wrapper around a libeasynn evaluation handle.

    Instances are produced by ``Builder.build()``.  Calling an instance binds
    its keyword arguments as named inputs and runs the evaluation.
    """

    def __init__(self, evaluation):
        # Opaque handle returned by ``_libeasynn.build``.
        self.evaluation = evaluation

    def __call__(self, **kwargs):
        """Bind ``kwargs`` as inputs, execute, and return the result.

        Each value may be a Python ``int``/``float``, a NumPy integer,
        floating or boolean scalar (all sent as a C double), or an array-like
        object with a ``shape`` attribute (sent as its shape plus a flattened
        float64 buffer).

        Returns:
            The value produced by ``to_float_or_ndarray`` for the execution
            result: a float for a 0-dimensional result, otherwise an ndarray.

        Raises:
            Exception: If a value has an unsupported type, or if ``execute``
                returns a non-zero status.
        """
        for k, v in kwargs.items():
            # NumPy scalars such as np.int64 / np.float32 are not int/float
            # subclasses but do have ``shape``; without this check they fell
            # into the array branch and failed on the missing ``.ctypes``.
            if isinstance(v, (int, float, np.integer, np.floating, np.bool_)):
                _libeasynn.add_kwargs_double(
                    self.evaluation, k.encode(), ctypes.c_double(float(v)))
            elif hasattr(v, "shape"):
                # One vectorized conversion to a contiguous float64 buffer
                # instead of boxing every element in Python.  ``arr`` and
                # ``flat`` stay referenced until the native call returns.
                arr = np.asarray(v, dtype=np.float64)
                flat = arr.ravel()
                _libeasynn.add_kwargs_ndarray(
                    self.evaluation, k.encode(), arr.ndim, arr.ctypes.shape,
                    flat.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
            else:
                raise Exception("%s: kwargs must be float or int or ndarray" % k)
        # Out-parameters filled in by ``execute``.
        c_dim = ctypes.c_int()
        c_shape = ctypes.POINTER(ctypes.c_size_t)()
        c_data = ctypes.POINTER(ctypes.c_double)()
        check_and_raise(_libeasynn.execute(self.evaluation,
            ctypes.byref(c_dim), ctypes.byref(c_shape), ctypes.byref(c_data)))
        return to_float_or_ndarray(c_dim, c_shape, c_data)
class Builder:
    """Incrementally assemble a libeasynn program from expression objects.

    Call ``append`` once per expression, then ``build`` to obtain an ``Eval``.
    """

    def __init__(self):
        # Opaque program handle from the native library.
        self.program = _libeasynn.create_program()

    def append(self, expr):
        """Register ``expr`` and its op parameters with the program.

        ``expr`` must expose an integer-like ``id``, ``inputs`` (objects with
        an integer-like ``id``), and ``op`` with str ``name`` and ``op_type``
        plus a ``parameters`` mapping of str keys to numbers or array-likes
        with a ``shape`` attribute.

        Raises:
            Exception: If a parameter has an unsupported type, or if the
                library returns a non-zero status for a parameter.
        """
        inputs = [ex.id for ex in expr.inputs]
        num_inputs = len(inputs)
        op = expr.op
        _libeasynn.append_expression(
            self.program, expr.id,
            op.name.encode(), op.op_type.encode(),
            (ctypes.c_int*num_inputs)(*inputs), num_inputs)
        for k, v in op.parameters.items():
            # NumPy scalars such as np.int64 / np.float32 are not int/float
            # subclasses but do have ``shape``; without this check they fell
            # into the array branch and failed on the missing ``.ctypes``.
            if isinstance(v, (int, float, np.integer, np.floating, np.bool_)):
                check_and_raise(_libeasynn.add_op_param_double(
                    self.program, k.encode(), ctypes.c_double(float(v))))
            elif hasattr(v, "shape"):
                # One vectorized conversion to a contiguous float64 buffer
                # instead of boxing every element in Python.  ``arr`` and
                # ``flat`` stay referenced until the native call returns.
                arr = np.asarray(v, dtype=np.float64)
                flat = arr.ravel()
                check_and_raise(_libeasynn.add_op_param_ndarray(
                    self.program, k.encode(), arr.ndim, arr.ctypes.shape,
                    flat.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
            else:
                raise Exception("%s: op params must be float or int or ndarray: %s" % (expr, k))

    def build(self):
        """Finalize the program and wrap the resulting handle in an ``Eval``."""
        return Eval(_libeasynn.build(self.program))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。