[TVM Tutorial] Adding a Custom Relay Operator
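This article is a translation of the TVM tutorial on adding a new Relay operator to TVM, using a cumulative product (cumprod) operator as the worked example.
Adding a Relay operator generally involves the following steps: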
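Define an attribute node for the new operator, declaring the fixed parameters that are known at compile time
Write a type relation for the new operator so that it integrates with Relay's type system
Use the C++ RELAY_REGISTER_OP macro to register the operator's arity (number of arguments), type information, and other hints
Write the operator's compute
Register the operator's compute and schedule
Define a C++ function that produces a call node for the new operator, and register a Python API hook for that function
Wrap the Python API hook in a cleaner calling interface
Write tests for the new Relay operator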
Defining the attribute node for the new operator
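Operator attributes are parameters known at compile time. For a convolution operator, for example, stride and dilation are attributes. Attributes of this kind are defined under include/tvm/relay/attrs/. Ultimately, we want an operator with the attributes described below, whose Python-side interface looks like this: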
def cumprod(data, axis=None, dtype=None, exclusive=None):
    """Numpy style cumprod op. Return the cumulative inclusive product of the elements along
    a given axis.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.
    axis : int, optional
        Axis along which the cumulative product is computed. The default (None) is to compute
        the cumprod over the flattened array.
    dtype : string, optional
        Type of the returned array and of the accumulator in which the elements are multiplied.
        If dtype is not specified, it defaults to the dtype of data.
    exclusive : bool, optional
        If true will return exclusive product in which the first element is not
        included. In other terms, if true, the j-th output element would be
        the product of the first (j-1) elements. Otherwise, it would be the product of
        the first j elements. The product of zero elements will be 1.

    Returns
    -------
    result : relay.Expr
        The result has the same size as data, and the same shape as data if axis is not None.
        If axis is None, the result is a 1-d array.
    """
A similar interface exists for .cumsum().
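To make the inclusive and exclusive semantics in the docstring concrete, here is a minimal pure-Python sketch for the 1-D case. The names cumprod_reference and flatten are made up for illustration and are not part of TVM; axis=None is modelled by flattening the input first.

```python
from typing import List, Sequence


def cumprod_reference(values: Sequence[float], exclusive: bool = False) -> List[float]:
    """Cumulative product of a 1-D sequence (illustrative stand-in, not the TVM operator)."""
    out: List[float] = []
    running = 1  # the product of zero elements is 1
    for v in values:
        if exclusive:
            out.append(running)  # product of the elements strictly before v
            running *= v
        else:
            running *= v
            out.append(running)  # product up to and including v
    return out


def flatten(nested) -> list:
    """Flatten nested lists or tuples, mimicking what axis=None does to the input."""
    flat = []
    for item in nested:
        if isinstance(item, (list, tuple)):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat


inclusive = cumprod_reference([1, 2, 3, 4])
exclusive = cumprod_reference([1, 2, 3, 4], exclusive=True)
flattened = cumprod_reference(flatten([[1, 2], [3, 4]]))
```

In exclusive mode the output is shifted by one position and starts from 1, which is why the last input element never contributes to the result.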
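When defining the attributes of the new operator (cumprod), we therefore need the axis to operate over, the data type, and the exclusive flag as attribute fields. ScanopAttrs in include/tvm/relay/attrs/transform.h already defines the attributes for scan operations such as cumulative sum and cumulative product, so cumprod needs no additional attribute definition.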
/*! \brief Attributes used in cumsum and cumprod operator */
struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
  Integer axis;
  DataType dtype;
  Bool exclusive = Bool(false);
  TVM_DECLARE_ATTRS(ScanopAttrs, "relay.attrs.ScanopAttrs") {
    TVM_ATTR_FIELD(axis).describe("The axis to operate over").set_default(NullValue<Integer>());
    TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
    TVM_ATTR_FIELD(exclusive)
        .describe("The first element is not included")
        .set_default(Bool(false));
  }
};
Other operators, however, need their own attribute node. BiasAdd, for example, defines one separately:
struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {
  int axis;
  TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") {
    TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1);
  }
};
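As a rough Python analogy for what an attribute node carries (field names, defaults, and descriptions), the sketch below mirrors the ScanopAttrs fields with a dataclass. ScanopAttrsSketch and describe are hypothetical names; the real attribute node is a C++ struct.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass(frozen=True)
class ScanopAttrsSketch:
    """Hypothetical Python mirror of the ScanopAttrs fields; not part of TVM."""

    axis: Optional[int] = field(default=None, metadata={"doc": "The axis to operate over"})
    dtype: Optional[str] = field(default=None, metadata={"doc": "Output data type"})
    exclusive: bool = field(
        default=False, metadata={"doc": "The first element is not included"}
    )


def describe(attrs_cls) -> dict:
    """Collect the per-field descriptions, similar in spirit to .describe(...)."""
    return {f.name: f.metadata["doc"] for f in fields(attrs_cls)}


defaults = ScanopAttrsSketch()
docs = describe(ScanopAttrsSketch)
```

The frozen dataclass reflects the idea that attributes are fixed once the operator call has been built.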
Writing a type relation
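To keep operator registration flexible and to let Relay operators generalize better, Relay operators are typed through relations between their input and output types.
These relations are expressed as functions that take a list of the operator's input and output types and return a list of input and output types that satisfies the relation.
This includes shape information for inputs and outputs that is known at compile time.
In essence, besides inferring the output type, an operator's relation can also enforce typing rules by checking the input types.
The example from the official tutorial lives under src/relay/op/ and again uses ScanopAttrs: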
TVM_REGISTER_NODE_TYPE(ScanopAttrs);

bool ScanopRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) {
  // types: [data, output]
  ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";
  const auto* data = types[0].as<TensorTypeNode>();  // tensor type of the input
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "Scanop: expect input type to be TensorType but get " << types[0];
    return false;
  }

  const auto* param = attrs.as<ScanopAttrs>();  // operator attributes

  // Fall back to the input dtype when no output dtype was requested.
  auto dtype = param->dtype;
  if (dtype.is_void()) {
    dtype = data->dtype;
  }

  // Set the type of the output tensor.
  if (param->axis.defined()) {
    reporter->Assign(types[1], TensorType(data->shape, dtype));
  } else {
    // No axis given: the output is 1-D with as many elements as the input.
    auto prod = data->shape[0];
    for (size_t i = 1; i < data->shape.size(); ++i) {
      prod = prod * data->shape[i];
    }
    reporter->Assign(types[1], TensorType({prod}, dtype));
  }

  return true;
}
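As the example shows, the main job of an XXXOpRel function is to determine the output type from the input types. In particular, the TensorType constructor requires the output shape to be specified, so this step effectively performs both shape inference and type inference.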
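The inference rule that ScanopRel applies can be restated in a few lines of plain Python. This is only a sketch under simplifying assumptions: shapes are tuples of ints, dtypes are strings, and scanop_rel_sketch is a hypothetical name rather than a TVM function.

```python
from typing import Optional, Tuple


def scanop_rel_sketch(
    in_shape: Tuple[int, ...],
    in_dtype: str,
    axis: Optional[int] = None,
    dtype: Optional[str] = None,
) -> Tuple[Tuple[int, ...], str]:
    """Return (output shape, output dtype) following the same rule as ScanopRel."""
    # An unspecified output dtype falls back to the input dtype.
    out_dtype = dtype if dtype else in_dtype

    # With an explicit axis the shape is unchanged.
    if axis is not None:
        return tuple(in_shape), out_dtype

    # Without an axis the input is treated as flattened, so the output is 1-D.
    # Like the C++ code, this assumes the input has at least one dimension.
    total = in_shape[0]
    for extent in in_shape[1:]:
        total *= extent
    return (total,), out_dtype


with_axis = scanop_rel_sketch((2, 3, 4), "float32", axis=1)
without_axis = scanop_rel_sketch((2, 3, 4), "float32")
```

Unlike the real relation, this sketch does not model incomplete input types, which is the case where ScanopRel returns false and waits for more type information.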
Relating the arity and attributes to the operator
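This step registers the name of the custom operator and attaches further annotations to it through the registration interface. It relies on the C++ macro RELAY_REGISTER_OP.
The registered information is:
Arity (number of arguments)
Names and descriptions of the positional arguments
Support level (1 indicates an internal intrinsic; higher numbers indicate operators that are less integral or externally supported)
The operator's type relation
Other annotations that are useful when optimizing the operator
The registrations for cumsum and cumprod are under src/relay/op/: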
RELAY_REGISTER_OP("cumsum")
    .describe(
        R"doc(Return the cumulative sum of the elements along a given axis.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(3)
    .add_type_rel("Cumsum", ScanopRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);

RELAY_REGISTER_OP("cumprod")
    .describe(
        R"doc(Return the cumulative product of the elements along a given axis.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(3)
    .add_type_rel("Cumprod", ScanopRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);  // not fused
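Note: the TOpPattern attribute set through set_attr<TOpPattern> controls how operator fusion treats this operator. The kOpaque value used in the registrations above makes fusion skip it.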
Writing the operator's compute
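At this point the operator's interface is in place, but its compute logic is still missing. Writing the compute is beyond the scope of this tutorial.
For cumprod and cumsum, see python/tvm/topi/scan.py for the CPU implementation and python/tvm/topi/cuda/scan.py for the GPU implementation.
Both are written directly in TIR.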
def scanop(
    data: Tensor,
    binop: Callable[["tvm.Expr", "tvm.Expr"], "tvm.Expr"],
    identity_value: "tvm.Expr",
    op_name: str,
    axis: Optional[int] = None,
    dtype: Optional[str] = None,
    exclusive: Optional[bool] = None,
) -> Tensor:
    if dtype is None or dtype == "":
        dtype = data.dtype

    if exclusive is None:
        exclusive = False

    def maybe_cast(x):
        if dtype != data.dtype:
            return cast(x, dtype)
        return x

    # Product of the extents before and after the scan axis.
    axis_mul_before = 1
    axis_mul_after = 1

    if axis is None:
        # No axis: scan over the flattened input.
        axis = 0
        cumsum_axis_len = prod(data.shape)
        shape = (cumsum_axis_len,)
    else:
        if not isinstance(axis, int):
            axis = get_const_int(axis)

        shape = data.shape
        cumsum_axis_len = shape[axis]

        if axis < 0:
            axis = len(shape) + axis

        for i, value in enumerate(shape, 0):
            if i < axis:
                axis_mul_before *= value
            elif i > axis:
                axis_mul_after *= value

    def gen_ir(data_buf, out_buf):
        ib = ir_builder.create()
        data_buf = ib.buffer_ptr(data_buf)
        out_buf = ib.buffer_ptr(out_buf)

        # One parallel iteration per 1-D lane along the scan axis.
        with ib.for_range(0, axis_mul_before * axis_mul_after, "fused", kind="parallel") as fused:
            i = fused // axis_mul_after
            j = fused % axis_mul_after
            base_idx = i * cumsum_axis_len * axis_mul_after + j
            if exclusive:
                out_buf[base_idx] = cast(identity_value, dtype)
            else:
                out_buf[base_idx] = maybe_cast(data_buf[base_idx])
            with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k:
                k = _k + 1
                cur_idx = base_idx + k * axis_mul_after
                prev_idx = base_idx + (k - 1) * axis_mul_after
                if exclusive:
                    out_buf[cur_idx] = binop(out_buf[prev_idx], maybe_cast(data_buf[prev_idx]))
                else:
                    out_buf[cur_idx] = binop(out_buf[prev_idx], maybe_cast(data_buf[cur_idx]))

        return ib.get()

    out_buf = decl_buffer(shape, dtype, "out_buf")

    return extern(
        [shape],
        [data],
        lambda ins, outs: gen_ir(ins[0], outs[0]),
        dtype=dtype,
        out_buffers=[out_buf],
        name=op_name,
        tag=op_name,
    )
def cumsum(
    data: Tensor,
    axis: Optional[int] = None,
    dtype: Optional[str] = None,
    exclusive: Optional[bool] = None,
) -> Tensor:
    return scanop(
        data=data,
        binop=generic.add,
        identity_value=0,
        op_name="cumsum_generic",
        axis=axis,
        dtype=dtype,
        exclusive=exclusive,
    )
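The index arithmetic in gen_ir is easier to follow on a plain list. The sketch below is a pure-Python model, not the TIR implementation: it assumes a row-major flat buffer with a static integer shape, and scan_flat is a hypothetical name. Each value of fused picks one 1-D lane along the scan axis, and consecutive elements of that lane sit axis_mul_after entries apart in memory.

```python
from typing import Callable, List, Sequence


def scan_flat(
    data: Sequence[float],
    shape: Sequence[int],
    axis: int,
    binop: Callable[[float, float], float],
    identity: float,
    exclusive: bool = False,
) -> List[float]:
    """Scan a row-major flat buffer along `axis` with the same index math as gen_ir."""
    if not data:
        return []
    if axis < 0:
        axis += len(shape)

    axis_len = shape[axis]
    before = 1  # product of the extents left of the scan axis
    after = 1   # product of the extents right of the scan axis
    for i, extent in enumerate(shape):
        if i < axis:
            before *= extent
        elif i > axis:
            after *= extent

    out = [identity] * len(data)
    for fused in range(before * after):
        i, j = divmod(fused, after)
        base = i * axis_len * after + j  # first element of this lane
        out[base] = identity if exclusive else data[base]
        for k in range(1, axis_len):
            cur = base + k * after
            prev = base + (k - 1) * after
            # Exclusive mode consumes the previous input, inclusive the current one.
            operand = data[prev] if exclusive else data[cur]
            out[cur] = binop(out[prev], operand)
    return out


flat = [1, 2, 3, 4, 5, 6]  # a 2x3 array stored row by row
along_rows = scan_flat(flat, (2, 3), 1, lambda a, b: a * b, 1)
along_cols = scan_flat(flat, (2, 3), 0, lambda a, b: a * b, 1)
exclusive_sum = scan_flat(flat, (2, 3), 1, lambda a, b: a + b, 0, exclusive=True)
```

Passing the binary operator and its identity value as parameters is what lets a single scan routine serve both cumsum (addition, 0) and cumprod (multiplication, 1).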
Registering the operator's compute and schedule
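Once the compute logic is written, it has to be bound to the operator interface defined earlier. In TVM this means providing not only the operator's compute but also a matching schedule. A strategy is the mechanism that picks a suitable schedule for a compute.
For a convolution operator, for example, compilation may discover that the convolution is depthwise and then select a more efficient schedule for it.
In most cases it is enough to consider only the CPU and GPU versions.
The strategies below are defined in python/tvm/relay/op/strategy/generic.py and python/tvm/relay/op/strategy/cuda.py: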
def wrap_compute_scanop(topi_compute):
    """Wrap scanop style topi compute"""

    def _compute_scanop(attrs, inputs, _):
        return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]

    return _compute_scanop


@override_native_generic_func("cumsum_strategy")
def cumsum_strategy(attrs, inputs, out_type, target):
    """cumsum generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scanop(topi.cumsum),  # the compute written above
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="cumsum.generic",
    )
    return strategy


@cumsum_strategy.register(["cuda", "gpu"])
def cumsum_strategy_cuda(attrs, inputs, out_type, target):
    """cumsum cuda strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scanop(topi.cuda.cumsum),
        wrap_topi_schedule(topi.cuda.schedule_scan),
        name="cumsum.cuda",
    )
    return strategy
Within each strategy, add_implementation ties the strategy to its compute and schedule.
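The dispatch idea behind a generic strategy with per-target overrides can be sketched in plain Python. Everything here is hypothetical (GenericStrategySketch, ImplSketch, and the placeholder compute and schedule); it only models "use the override registered for this target key, otherwise fall back to the generic version" and does not reproduce TVM's actual target resolution.

```python
import itertools
from typing import Callable, Dict, List, NamedTuple


class ImplSketch(NamedTuple):
    compute: Callable
    schedule: Callable
    name: str


class GenericStrategySketch:
    """Toy target-dispatched function (hypothetical; not TVM's generic function)."""

    def __init__(self, default: Callable[[], List[ImplSketch]]):
        self._default = default
        self._overrides: Dict[str, Callable[[], List[ImplSketch]]] = {}

    def register(self, target_keys: List[str]):
        """Decorator that overrides the default for the given target keys."""

        def decorator(fn):
            for key in target_keys:
                self._overrides[key] = fn
            return fn

        return decorator

    def __call__(self, target: str) -> List[ImplSketch]:
        return self._overrides.get(target, self._default)()


def cumsum_compute(values):
    """Placeholder compute: a plain running sum."""
    return list(itertools.accumulate(values))


def noop_schedule(outs):
    """Placeholder schedule: returns its input untouched."""
    return outs


@GenericStrategySketch
def cumsum_strategy_sketch():
    return [ImplSketch(cumsum_compute, noop_schedule, "cumsum.generic")]


@cumsum_strategy_sketch.register(["cuda", "gpu"])
def cumsum_strategy_cuda_sketch():
    return [ImplSketch(cumsum_compute, noop_schedule, "cumsum.cuda")]


generic_impl = cumsum_strategy_sketch("llvm")[0]
cuda_impl = cumsum_strategy_sketch("cuda")[0]
```

Keeping compute and schedule together in one implementation record makes it possible for a strategy to offer several candidates and let the compiler pick among them.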
The shape_func registered below is used to infer the output shape when the input has a dynamic shape.
# cumsum
@_reg.register_compute("cumsum")
def compute_cumsum(attrs, inputs, output_type):
    """Compute definition of cumsum"""
    return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]


_reg.register_strategy("cumsum", strategy.cumsum_strategy)
_reg.register_shape_func("cumsum", False, elemwise_shape_func)
Defining a C++ function that produces a call node and registering a Python API hook for it
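We now have a callable Relay operator. The next step is to make it invocable through a Relay call node. This requires a function that passes the arguments on to the corresponding Relay operator and returns the operator's Call node (the node that ends up in the AST of the Relay expression).
Calling with Attrs and arguments directly is not currently supported, so the function has to construct the corresponding AttrsNode and pass it to the Call node.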
Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) {
  auto attrs = make_object<ScanopAttrs>();
  attrs->dtype = dtype;
  attrs->axis = axis;
  attrs->exclusive = exclusive;
  static const Op& op = Op::Get("cumsum");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.cumsum").set_body_typed(MakeCumsum);
Op::Get("cumsum") is implemented as follows. How an operator ends up in the OpRegistry is left as a TODO.
const Op& Op::Get(const String& name) {
  const OpRegEntry* reg = OpRegistry::Global()->Get(name);
  ICHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered";
  return reg->op();
}
Here is the implementation of Call. It builds a CallNode that stores the operator along with its arguments and attribute information.
Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) {
  ObjectPtr<CallNode> n = make_object<CallNode>();
  n->op = std::move(op);
  n->args = std::move(args);
  n->attrs = std::move(attrs);
  n->type_args = std::move(type_args);
  n->span = std::move(span);
  data_ = std::move(n);
}
MakeCumsum and MakeCumprod are defined under src/relay/op/. TVM_REGISTER_GLOBAL exposes them to Python as relay.op._make.cumsum(...) and relay.op._make.cumprod(...). The details are left as a TODO.
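Conceptually, the Python hook is a name-to-function table plus a function that packages its arguments into a call node. The sketch below models that idea in plain Python; register_global_sketch, get_global_sketch, CallSketch, and make_cumsum_sketch are hypothetical names, and the real mechanism goes through TVM's FFI rather than a Python dictionary.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

# Toy global function table (a stand-in for the registry behind the Python hook).
_GLOBAL_FUNCS: Dict[str, Callable] = {}


def register_global_sketch(name: str):
    """Decorator that stores a function under a dotted global name."""

    def decorator(fn):
        _GLOBAL_FUNCS[name] = fn
        return fn

    return decorator


def get_global_sketch(name: str) -> Callable:
    """Look a function up by name, failing loudly when it is missing."""
    if name not in _GLOBAL_FUNCS:
        raise AttributeError(f"Function {name} is not registered")
    return _GLOBAL_FUNCS[name]


@dataclass
class CallSketch:
    """Minimal call node: the operator name, its arguments, and its attributes."""

    op: str
    args: List[Any]
    attrs: Dict[str, Any] = field(default_factory=dict)


@register_global_sketch("relay.op._make.cumsum")
def make_cumsum_sketch(data, axis, dtype, exclusive):
    # Build the attributes first, then hand them to the call node.
    attrs = {"axis": axis, "dtype": dtype, "exclusive": exclusive}
    return CallSketch(op="cumsum", args=[data], attrs=attrs)


call = get_global_sketch("relay.op._make.cumsum")("x", 0, "float32", False)
```

Only the tensor input becomes a call argument; axis, dtype, and exclusive travel as attributes, matching the split between arguments and compile-time parameters.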
Wrapping the Python API hook in a cleaner interface
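For convenience, the usual practice is to build a separate function around the hook, so it is best to wrap the function exported through TVM_REGISTER_GLOBAL in a cleaner Python interface. The tutorial's example is defined in python/tvm/relay/op/transform.py: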
def cumsum(data, axis=None, dtype=None, exclusive=None):
    return _make.cumsum(data, axis, dtype, exclusive)


def cumprod(data, axis=None, dtype=None, exclusive=None):
    return _make.cumprod(data, axis, dtype, exclusive)
In particular, an operator that takes a variable number of arguments has to pack them into a Tuple before passing them on:
def concat(*args):
    """Concatenate the input tensors along the zero axis.

    Parameters
    ----------
    args: list of Tensor

    Returns
    -------
    tensor: The concatenated tensor.
    """
    tup = Tuple(list(args))
    return _make.concat(tup)
Writing tests for the new Relay operator
See tests/python/relay/test_op_level3.py for reference.
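A test for an operator like this usually runs it on a handful of inputs and compares the result with an independent reference. The sketch below shows that structure in plain Python without TVM: cumprod_under_test is a hypothetical stand-in for the compiled operator, and the reference is built from itertools.accumulate, where a real test would more likely use a NumPy reference.

```python
import itertools
import operator


def reference_cumprod(values, exclusive=False):
    """Reference result built from itertools.accumulate."""
    inclusive = list(itertools.accumulate(values, operator.mul))
    if not exclusive:
        return inclusive
    # Shift right by one and start from 1, the product of zero elements.
    return ([1] + inclusive[:-1]) if inclusive else []


def cumprod_under_test(values, exclusive=False):
    """Hypothetical stand-in for the operator being tested."""
    out, running = [], 1
    for v in values:
        out.append(running if exclusive else running * v)
        running *= v
    return out


# Cover both modes plus edge cases: a single element, empty input, a zero.
CASES = [
    ([1, 2, 3, 4], False),
    ([1, 2, 3, 4], True),
    ([5], False),
    ([5], True),
    ([], False),
    ([2, 0, 7], False),
]


def run_cases():
    """Check every case against the reference and return the number of cases run."""
    for values, exclusive in CASES:
        got = cumprod_under_test(values, exclusive)
        want = reference_cumprod(values, exclusive)
        assert got == want, f"mismatch for {values}, exclusive={exclusive}: {got} != {want}"
    return len(CASES)


checked = run_cases()
```

Deriving the expected values from a separate implementation, rather than hard-coding them, keeps the cases easy to extend to new axes, dtypes, and shapes.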