tvm.tir
张量级 IR 的命名空间。
class tvm.tir.Buffer
TVM 中的符号数据缓冲区。
缓冲区提供了一种在 TVM 中表示数据结构的数据布局特化的方法。
不要直接构造,而是使用 decl_buffer()
。有关更多详细信息,请参阅文档decl_buffer()
。
decl_buffer
声明缓冲区。
access_ptr(access_mask, ptr_type='handle', content_lanes=1, offset=0, extent=None)
获取缓冲区头部的访问指针。
这是与外部函数交互时获取缓冲区数据 ptress 的推荐方法。
- 参数:
示例
# 获取用于读取的访问指针
buffer.access_ptr("r")
# 获取带位掩码的读/写访问指针
buffer.access_ptr(Buffer.READ | Buffer.WRITE)
# 获取带字符串标志的读/写访问指针
buffer.access_ptr("rw")
# 获取带偏移量的读取访问指针
buffer.access_ptr("r", offset = 100)
# 获取带范围(extent)的读取访问指针
buffer.access_ptr("r", extent = 100)
vload(begin, dtype=None, predicate=None)
生成一个从开始索引加载 dtype 的 Expr。
- 参数:
- 返回:load:相应的负载表达式。
- 返回类型: Expr。
vstore(begin, value, predicate=None)
生成一个将值存储到开始索引的 Stmt。
- 参数:
- 返回:store:相应的存储语句。
- 返回类型:Stmt。
scope()
返回与此缓冲区关联的存储作用域。:returns: scope:与此缓冲区关联的存储范围。:rtype: str
get_flattened_buffer()
生成一个该缓冲区的扁平化版本的 Buffer。
- 返回:flattened:相应的平面缓冲区。
- 返回类型:Buffer。
offset_of(indices)
确定扁平缓冲区中提供的索引的偏移量。
- 参数:indices (Union*[PrimExpr,*** List*[PrimExpr]]*): 原始缓冲区中元素的索引。
- 返回:flattened_indices:扁平缓冲区中元素的偏移索引。
- 返回类型: List[PrimExpr]。
tvm.tir.decl_buffer(shape, dtype=None, name='buffer', data=None, strides=None, elem_offset=None, scope='', data_alignment=-1, offset_factor=0, buffer_type='', axis_separators=None, span=None)
声明一个新的符号缓冲区。
通常,缓冲区在下拉和构建过程中会自动创建。只有当用户想要指定自己的缓冲区布局时才需要这样做。
有关缓冲区使用的详细讨论,请参阅下面的注释。
- 参数:
- shape (tupleofExpr):缓冲区的形状。
- dtype (str,optional):缓冲区的数据类型。
- name (str,optional):缓冲区的名称。
- data (tir.Var,optional):缓冲区中的数据指针。
- strides (arrayofExpr):缓冲区*的步幅。
- elem_offset (Expr,optional): 数组到数据的起始偏移量。以 dtype 元素的数量表示。
- scope (str,optional): 缓冲区的存储范围(如果不是全局的)。如果 scope 等于空字符串,则表示它是全局内存。
- data_alignment (int,optional):1,则对齐方式将设置为 TVM 的内部默认值。
- offset_factor (int,optional):elem_offset 字段的因子。设置后,elem_offset 必须是 offset_factor 的倍数。如果传入 0,则对齐方式将设置为 1。如果传入非零值,且 elem_offset 不为 None,我们将为 elem_offset 创建一个 tir.Var。
- buffer_type (str,optional**,** {""**,"auto_broadcast"}):> buffer[i][0][k]。
- axis_separators (listofint,optional):如果传递,则为轴组之间的分隔符列表,每个轴都会展平为一个输出轴。对于平坦的内存空间,应为 None 或空列表。
- span (Optional[Span]):decl_buffer 在源中创建的位置。
- 返回:buffer:创建的缓冲区。
- 返回类型:tvm.tir.Buffer。
缓冲区数据结构反映了 dlpack 中的 DLTensor 结构。虽然 DLTensor 数据结构非常通用,但创建仅处理特定数据结构情况的函数,并使编译后的函数从中受益通常很有帮助。
如果用户在构造函数时传入 strides 并将 elem_offset 设置为 None,则该函数将针对紧凑且对齐的 DLTensor 进行特化。如果用户将完全通用的符号数组传递给 strides,则生成的函数将变为完全通用的函数。
class tvm.tir.DataProducer
class tvm.tir.Layout
布局由大写字母、小写字母和数字组成,其中大写字母表示主轴,对应的小写字母及其因子大小表示从轴。例如,NCHW16c 可以描述一个 5 维张量,其大小为 [batch_size, channel, height, width, channel_block]。其中,从轴 channel_block=16 表示主轴 C(通道)的因子大小。
layout
声明布局。
index_of(axis)
获取轴的索引。
factor_of(axis)
获取从属轴的因子大小。
- 参数:axis (str):轴名称,需要为 [az,AZ]
- 返回:factor:轴的从属轴的大小(如果轴是主轴),或轴本身的大小(如果轴是从属轴)。如果轴不在布局中,则返回 -1。
- 返回类型:int。
class tvm.tir.BijectiveLayout
两种布局(源布局和目标布局)的双射映射。它提供彼此之间的形状和索引转换。
不要直接构造,而是使用 bijective_layout。有关更多详细信息,请参阅文档 bijective_layout。
bijective_layout
声明布局。
forward_index(index)
给定 src-layout 的索引,推断 dst 索引。
backward_index(index)
给定 dst-layout 的索引,推断 src 索引。
forward_shape(shape)
给定 src-layout 的形状,推断 dst 的形状。
backward_shape(shape)
给定 dst-layout 的形状,推断 src 的形状。
tvm.tir.bijective_layout(src_layout:str|Layout, dst_layout:str|Layout) → BijectiveLayout
创建双射布局映射。
- 参数:
- 返回:bijective_layout: 创建的双射布局。
- 返回类型:BijectiveLayout。
tvm.tir.layout(layout_str:str, dtype:str= 'int32') → Layout
从字符串创建布局节点。
- 参数:
- 返回:layout:创建的布局。
- 返回类型:Layout。
class tvm.tir.Var(name:str, dtype:str|Type, span:Span|None= None)】
符号变量。
class tvm.tir.SizeVar(name:str, dtype:str|Type, span:Span|None= None)
表示张量索引大小的符号变量。
大于或等于零。
class tvm.tir.Reduce(combiner:CommReducer, src:List[PrimExpr], rdom:List[IterVar], condition:PrimExpr, value_index:int, init:List[PrimExpr] |None= None, span:Span|None= None)
归约节点。
- 参数:
class tvm.tir.FloatImm(dtype:str, value:float, span:Span|None= None)
浮点常数。
class tvm.tir.IntImm(dtype:str, value:int, span:Span|None= None)
整数常量。
class tvm.tir.StringImm(value:str, span:Span|None= None)
字符串常量。
class tvm.tir.Cast(dtype, value, span:Span|None= None)
转换表达式。
class tvm.tir.Add(a:PrimExpr, b:PrimExpr, span:Span|None= None)
Add 节点。
class tvm.tir.Sub(a:PrimExpr, b:PrimExpr, span:Span|None= None)
Sub 节点。
class tvm.tir.Mul(a:PrimExpr, b:PrimExpr, span:Span|None= None)
Mul 节点。
class tvm.tir.Div(a:PrimExpr, b:PrimExpr, span:Span|None= None)
Div 节点。
class tvm.tir.Mod(a:PrimExpr, b:PrimExpr, span:Span|None= None)
Mod 节点。
class tvm.tir.FloorDiv(a:PrimExpr, b:PrimExpr, span:Span|None= None)
FloorDiv 节点。
class tvm.tir.FloorMod(a:PrimExpr, b:PrimExpr, span:Span|None= None)
FloorMod 节点。
class tvm.tir.Min(a:PrimExpr, b:PrimExpr, span:Span|None= None)
Min 节点。
class tvm.tir.Max(a:PrimExpr, b:PrimExpr, span:Span|None= None)
Max 节点。
class tvm.tir.EQ(a:PrimExpr, b:PrimExpr, span:Span|None= None)
EQ 节点。
class tvm.tir.NE(a:PrimExpr, b:PrimExpr, span:Span|None= None)
NE 节点。
class tvm.tir.LT(a:PrimExpr, b:PrimExpr, span:Span|None= None)
LT 节点。
class tvm.tir.LE(a:PrimExpr, b:PrimExpr, span:Span|None= None)
LE 节点。
class tvm.tir.GT(a:PrimExpr, b:PrimExpr, span:Span|None= None)
GT 节点。
class tvm.tir.GE(a:PrimExpr, b:PrimExpr, span:Span|None= None)
GE 节点。
class tvm.tir.And(a:PrimExpr, b:PrimExpr, span:Span|None= None)
And 节点。
class tvm.tir.Or(a:PrimExpr, b:PrimExpr, span:Span|None= None)
Or 节点。
class tvm.tir.Not(a:PrimExpr, span:Span|None= None)
Not 节点。
class tvm.tir.Select(condition:PrimExpr, true_value:PrimExpr, false_value:PrimExpr, span:Span|None= None)
Select 节点。
Select 可能会同时计算 true_value 和 false_value。如果您只想获取仅评估正确分支的条件表达式,请使用 tvm.tir.if_then_else
。
- 参数:
class tvm.tir.BufferLoad(buffer:Buffer, indices:List[PrimExpr], predicate:PrimExpr|None= None, span:Span|None= None)
BufferLoad 节点。
- 参数:
class tvm.tir.ProducerLoad(producer:DataProducer, indices:List[PrimExpr], span:Span|None= None)
ProducerLoad 节点。
- 参数:
- producer (DataProducer): 要加载的缓冲区。
- indices (List[PrimExpr]):缓冲区索引。
- span (Optional[Span]):此表达式在源代码中的位置。
class tvm.tir.Ramp(base:PrimExpr, stride:PrimExpr, lanes:PrimExpr, span:Span|None= None)
Ramp 节点。
- 参数:
class tvm.tir.Broadcast(value:PrimExpr, lanes:PrimExpr, span:Span|None= None)
Broadcast 节点。
class tvm.tir.Shuffle(vectors:List[PrimExpr], indices:List[PrimExpr], span:Span|None= None)
Shuffle 节点。
class tvm.tir.Call(dtype:str, op:Op|str, args:List[PrimExpr], span:Span|None= None)
tir.Call 节点。
- 参数:
class tvm.tir.CallEffectKind
可能的 tir.Call 效果种类。
class tvm.tir.Let(var:Var, value:PrimExpr, body:PrimExpr, span:Span|None= None)
Let 节点。
- 参数:
class tvm.tir.IterVar(dom:Range, var:Var|str, iter_type:int, thread_tag:str= '', span:Span|None= None)
表示迭代变量。
IterVar 表示计算中的轴迭代。
- 参数:
te.thread_axis
创建线程轴 IterVar。
te.reduce_axis
创建归约轴 IterVar。
class tvm.tir.CommReducer(lhs:List[Var], rhs:List[Var], result:List[PrimExpr], identity_element:List[PrimExpr], span:Span|None= None)
交换约简运算符。
- 参数:
class tvm.tir.Stmt
所有语句的基类。
class tvm.tir.LetStmt(var:Var, value:PrimExpr, body:Stmt, span:Span|None= None)
LetStmt 节点。
- 参数:
class tvm.tir.AssertStmt(condition:PrimExpr, message:PrimExpr, body:Stmt, span:Span|None= None)
AssertStmt 节点。
- 参数:
- condition (PrimExpr):断言条件。
- message (PrimExpr):错误消息。
- body (tvm.tir.Stmt): 主体语句。
- span (Optional[Span]):源代码中 stmt 的位置。
class tvm.tir.ForKind(value)
for 循环的种类。
ForKind 可以改变循环的控制流语义,需要在所有 TIR 传递中考虑它。
class tvm.tir.For(loop_var:Var, min:PrimExpr, extent:PrimExpr, kind:ForKind, body:Stmt, thread_binding:IterVar|None= None, annotations:Mapping[str, Object] |None= None, span:Span|None= None)
For 节点。
- 参数:
class tvm.tir.While(condition:PrimExpr, body:Stmt, span:Span|None= None)
While 节点。
class tvm.tir.BufferStore(buffer:Buffer, value:PrimExpr, indices:List[PrimExpr], predicate:PrimExpr|None= None, span:Span|None= None)
缓冲存储节点。
- 参数:
class tvm.tir.BufferRealize(buffer:Buffer, bounds:List[Range], condition:PrimExpr, body:Stmt, span:Span|None= None)
BufferRealize 节点。
- 参数:
class tvm.tir.Allocate(buffer_var:Var, dtype:str, extents:List[PrimExpr], condition:PrimExpr, body:Stmt, annotations:Mapping[str, Object] |None= None, span:Span|None= None)
Allocate 节点。
- 参数:
class tvm.tir.AllocateConst(buffer_var:Var, dtype:str, extents:List[PrimExpr], data_or_idx:NDArray|int, body:Stmt, annotations:Mapping[str, Object] |None= None, span:Span|None= None)
AllocateConst 节点。
- 参数:
- buffer_var (tir.Var):缓冲区变量。
- dtype (str):缓冲区的数据类型。
- extents (listofExpr): 分配的范围。
- data_or_idx (Union[NDArray,int]):如果是 NDArray,则这是与常量关联的 const 数据。如果是整数,则这是包含 AllocateConst 的 IRModule 的“constants”属性的索引。
- body (Stmt):正文语句。
- annotations (Optional*[Mapping[str,* Object*]]*): 关于分配的附加注解。
- span (Optional[Span]): 源代码中 stmt 的位置。
class tvm.tir.AttrStmt(node: Object, attr_key:str, value:PrimExpr, body:Stmt, span:Span|None= None)
AttrStmt 节点。
- 参数:
class tvm.tir.DeclBuffer(buffer:Buffer, body:Stmt, span:Span|None= None)
DeclBuffer 节点。
- 参数:
class tvm.tir.SeqStmt(seq:List[Stmt], span:Span|None= None)
语句序列。
class tvm.tir.IfThenElse(condition:PrimExpr, then_case:Stmt, else_case:Stmt|None, span:Span|None= None)
IfThenElse 节点。
- 参数:
class tvm.tir.Evaluate(value:PrimExpr, span:Span|None= None)
Evaluate 节点。
tvm.tir.stmt_seq(args:PrimExpr|*Stmt) → SeqStmt
制定语句序列。
tvm.tir.stmt_list(stmt:Stmt) → List[Stmt]
从块中创建 stmt 列表。
class tvm.tir.BufferRegion(buffer:Buffer, region:List[Range])
BufferRegion 节点。
class tvm.tir.MatchBufferRegion(buffer:Buffer, source:BufferRegion)
MatchBufferRegion 节点。
- 参数:
- buffer (Buffer): 目标缓冲区。
- source (BufferRegion): 源缓冲区的区域。
class tvm.tir.Block(iter_vars:List[IterVar], reads:List[BufferRegion], writes:List[BufferRegion], name_hint:str, body:Stmt, init:Stmt|None= None, alloc_buffers:List[Buffer] |None= None, match_buffers:List[MatchBufferRegion] |None= None, annotations:Mapping[str, Object] |None= None, span:Span|None= None)
Block 节点。
- 参数:
- iter_vars (List[IterVar]):块变量。
- reads (List[BufferRegion]):块的读取缓冲区区域。
- writes (List[BufferRegion]):块的写入缓冲区区域。
- name_hint (str):块的 name_hint。
- body (Stmt):块的主体。
- init (Optional[Stmt]):缩减块的 init 块。
- alloc_buffers (Optional*[list[Buffer]]):缓冲区分配。
- match_buffers (Optional*[List[MatchBufferRegion**]***]):子区域缓冲区匹配。
- annotations (Optional*[Mapping[str,* Object*]]*):额外的注解提示。
- span (Optional[Span]):此块在源代码中的位置。
class tvm.tir.BlockRealize(iter_values:List[PrimExpr], predicate:PrimExpr|bool, block:Block, span:Span|None= None)
BlockRealize 节点。
- 参数:
class tvm.tir.PrimFunc(params, body, ret_type=None, buffer_map=None, attrs=None, span=None)
函数声明表达式。
- 参数:
- params (List[Union[tvm.tir.Var, tvm.tir.Buffer]**]):函数的输入参数列表。
- body (tvm.tir.Stmt):函数主体。
- ret_type (tvm.ir.Type):函数的返回类型注解。
- buffer_map (Map[tvm.tir.Var,tvm.tir.Buffer]):缓冲区绑定图。
- attrs (Optional[tvm.Attrs]):函数的属性,可以为 None。
- span (Optional[Span]):此 itervar 在源代码中的位置。
with_body(new_body, span=None)
创建具有相同集合签名但具有新主体的新 PrimFunc。
- 参数:
- 返回:new_func:创建的新函数。
- 返回类型:PrimFunc。
specialize(param_map:Mapping[Var,PrimExpr|Buffer])
PrimFunc 的专门参数。
示例
我们可以定义一个具有符号形状的 Meta TIR 函数:
@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
A = T.match_buffer(a, (m, n), "float32")
B = T.match_buffer(b, (m, n), "float32")
for i, j in T.grid(m, n):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj]
然后我们可以利用给定的形状或缓冲区使其特化。
a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
# 或者
func = mem_copy.specialize({n: 16, m: 16})
专门的函数:
@T.prim_func
def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
B = T.match_buffer(b, (16, 16), "float32")
for i, j in T.grid(16, 16):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj]
- 返回:func: 带有特殊参数的新函数。
- 返回类型:PrimFunc。
class tvm.tir.TensorIntrin(desc, impl)
张量的内在函数。
static register(name:str, desc:PrimFunc, impl:PrimFunc, override:bool= False)
使用其名称注册张量内在函数。
- 参数:
static get(name:str, allow_missing:bool= False) → TensorIntrin | None
通过名称查找张量内在函数。
- 参数:
- 返回:result:具有指定名称的 TensorIntrin,如果未找到则为 None。
- 返回类型: Optional[TensorIntrin]。
class tvm.tir.IndexMap(initial_indices, final_indices, inverse_index_map)
从多维索引到另一组多维索引的映射。
- 参数:
static from_func(mapping_function:Callable, ndim:int|None= None, inverse_index_map:Callable|IndexMap|None= None,*, index_dtype: str = 'int64')
从函数创建索引图。
- 参数:
- mapping_function (Callable):从源索引映射到目标索引的函数。该函数应接受 tir.Var 参数并返回 tir.PrimExpr 或 tir.PrimExpr 列表。返回 tir.PrimExpr 相当于返回包含该 tir.PrimExpr 的长度为 1 的列表。
- ndim (Optional[int]):此转换应应用到的缓冲区的维数。如果 mapping_function 使用可变参数*args,则必须指定 ndim。如果 mapping_function 不使用可变参数,则 ndim 为可选。
- inverse_index_map (Union*[Callable,* Optional*[IndexMap]]*):可选的预定义逆索引图。定义此方法后,IndexMap::Inverse 将返回预定义的逆索引图。否则,逆索引图将即时计算。用户有责任确保预定义逆索引图的正确性。
- 返回:index_map:返回表示 mapping_function 的 IndexMap 。
- 返回类型:IndexMap。
static from_func_with_separators(mapping_function:Callable, ndim:int|None= None, inverse_index_map:Callable|IndexMap|None= None, ***, index_dtype: str = 'int64')
从函数创建索引图。
- 参数:
- mapping_function (Callable):用于从源索引映射到目标索引的函数。该函数应接受 tir.Var 参数并返回 tir.PrimExpr 或列表。返回列表的每个元素应为 tir.PrimExpr 或 IndexMap.AXIS_SEPARATOR 对象。返回 tir.PrimExpr 相当于返回包含该 tir.PrimExpr 的长度为 1 的列表。
- ndim (Optional[int]):此转换应应用到的缓冲区的维数。如果 mapping_function 使用可变参数*args,则必须指定 ndim 。如果 mapping_function 不使用可变参数,则 ndim 为可选。
- inverse_index_map (Union*[Callable,* Optional*[IndexMap]]*): 可选的预定义逆索引图。定义此方法后,IndexMap::Inverse 将返回预定义的逆索引图。否则,逆索引图将即时计算。用户有责任确保预定义逆索引图的正确性。
- index_dtype (str): 映射函数中输入迭代器使用的默认索引 dtype。
- 返回:ret:返回一个元组,其第一个元素是表示**mapping_function 的 IndexMap ,其第二个索引是 IndexMap.AXIS_SEPARATOR 发生的索引列表 。
- 返回类型:Tuple[IndexMap, List[int]]。
is_equivalent_to(other_map:IndexMap) → bool
如果索引图等效,则返回。
- 参数:other_map (IndexMap):应该进行比较的 IndexMap。
- 返回:is_equivalent:如果两个映射表示相同的转换,则为 True,否则为 False。
- 返回类型:bool。
map_indices(indices:List[PrimExpr]) → List[PrimExpr]
将索引图应用于一组索引。
map_shape(shape:List[PrimExpr]) → List[PrimExpr]
将索引图应用于缓冲区形状。
map_ndarray(arr_src:NDArray) → NDArray
将此索引映射应用于输入 NDArray,以转换其布局。
- 参数:arr_src (runtime.NDArray):要转换的 NDArray。
- 返回:arr_dst:转换后的 NDArray。
- 返回类型: runtime.NDArray。
inverse(shape:List[Range|PrimExpr]) → IndexMap
返回该映射的逆映射。
如果该函数不是双射(bijective),则会抛出错误。
- 参数:shape (List*[Union[Range,PrimExpr**]***]):需要确定逆的区域。用于验证映射在此范围内是否为双射。
- 返回:inverse:逆。
- 返回类型:IndexMap。
non_surjective_inverse(shape:List[Range|PrimExpr]) → Tuple[IndexMap, PrimExpr]
返回该映射的逆映射。
可用于引入填充(padding)的变换。
- 参数:shape (List*[Union[Range,PrimExpr**]***]):需要确定逆的区域。用于确定谓词。
- 返回:result:逆,以及逆映射到输入范围中的有效索引的谓词。
- 返回类型:Tuple[IndexMap, PrimExpr]。
示例
index_map = IndexMap.from_func(lambda i: [i//4, i%4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k])
print(predicate) # 打印 "(axis0==3) && (axis2 >= 2)"
tvm.tir.call_packed_lowered(args*, span=None)
调用 packed 的低版本。packed 函数的参数可以是 Expr 或 Buffer。当传入 Expr 时,参数为对应的 POD 类型。当参数为 Buffer 时,对应的 PackedFunc 将收到一个 TVMArrayHandle,其内容在回调期间有效。如果 PackedFunc 是 Python 回调,则对应的参数为 NDArray。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
te.extern
使用外部函数调用创建张量。
tvm.tir.call_cpacked_lowered(args*, span=None)
call c-packed 的低版本。与 call_packed 相同,但第一个参数是函数名(与 call_extern 类似),最后一个参数是资源句柄。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
te.extern
使用外部函数调用创建张量。
tvm.tir.call_tir(global_var:GlobalVar, args*)
调用同一 IRModule 中的另一个 PrimFunc。
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.call_packed(args*, span=None)
通过调用外部打包函数来构建表达式。
打包函数的参数可以是 Expr 或 Buffer。Expr 为参数时,参数为对应的 POD 类型。
当参数为 Buffer 时,对应的 PackedFunc 会收到一个 TVMArrayHandle,其内容在回调期间有效。如果 PackedFunc 是 python 回调,则对应的参数为 NDArray。
- 参数:
- 返回:call: 调用表达式。
- 返回类型:PrimExpr。
te.extern
使用外部函数调用创建张量。
tvm.tir.call_cpacked(args*, span=None)
通过调用外部打包函数来构建表达式。
与 call_packed 相同,但第一个参数是函数名(如 call_extern 中一样),最后一个参数是资源句柄。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
te.extern
使用外部函数调用创建张量。
tvm.tir.call_intrin(dtype, func_name, args*, span=None)
通过调用内部函数来构建表达式。
内部函数可以通过内部转换规则使用多种数据类型进行重载。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.call_pure_extern(dtype, func_name, args*, span=None)
通过调用纯外部函数来构建表达式。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.call_extern(dtype, func_name, args*, span=None)
通过调用外部函数来构建表达式。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.call_llvm_intrin(dtype, name, args*, span=None)
通过调用 llvm 内部函数来构建表达式
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.call_llvm_pure_intrin(dtype, name, args*, span=None)
通过调用纯 llvm 内部函数来构建表达式
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr
tvm.tir.ret(val, span=None)
创建 tir 返回表达式
- 参数:
- val (Expr):返回的 tir 表达式,其数据类型为 int、float 或 void 指针。
- span (Optional[Span]):此运算符在源代码中的位置。
- 返回:ret:返回表达式。
- 返回类型:PrimExpr。
tvm.tir.all(args*, span=None)
创建一个新的表达式,该表达式表示所有参数条件的交集。
tvm.tir.any(args*, span=None)
创建一个新的表达式,表示所有参数条件的并集。
tvm.tir.min_value(dtype, span=None)
dtype 的最小值。
tvm.tir.max_value(dtype:str, span:Span|None= None) → Any
dtype 的最大值
tvm.tir.trace(args, trace_action='tvm.default_trace_action')
在运行时跟踪张量数据。
trace 函数允许在运行时跟踪特定的张量。跟踪值应作为最后一个参数。应指定跟踪操作,默认情况下使用 tvm.default_trace_action。
tvm.tir.call_packed
创建打包函数。
tvm.tir.tvm_stack_alloca(dtype_str, num)
返回堆栈上的新 dtype[num]。
tvm.tir.tvm_stack_make_shape(args*)
在堆栈上分配一个形状元组,返回句柄。
tvm.tir.tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset)
在堆栈上分配一个 NDArray(DLTensor),返回句柄。
- 参数:
- data (Expr): 数组的数据。
- shape (Expr):数组的形状。
- strides (Expr):数组的步幅。
- ndim (Expr): 数组的维度。
- arr_dtype (Expr): 数组的数据类型。
- elem_offse (Expr): 数组的元素偏移量。
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_tuple(value*)
在 AttrStmt 的值字段中创建一个元组结构。
tvm.tir.handle_add_byte_offset(handle, offset)
为句柄添加偏移量。
tvm.tir.tvm_struct_get(arr, index, field, dtype)
获取数组中的结构字段值。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_struct_set(arr, index, field, value)
设置数组中结构字段的值。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.address_of(obj:Buffer|BufferLoad, span:Span|None= None) → PrimExpr
返回缓冲区中元素的地址。
- 参数:
- obj (Union[Buffer,BufferLoad]):缓冲区或缓冲区负载。
- span (Optional[Span]):此运算符在源代码中的位置。
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.lookup_param(param_name, span=None)
按名称返回参数。
tvm.tir.assume(cond=None)
提供可用于简化的真实陈述。
- 参数:cond (Expr):约束条件。
- 返回:call:调用表达式。
- 返回类型:PrimExpr
tvm.tir.undef()
返回已初始化的任意值。
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_thread_allreduce(freduce_args*)
在线程块内执行 allreduce。
- 参数:freduce_args (Expr):参数。
- 返回:call:调用表达式。
- 返回类型:PrimExpr
tvm.tir.type_annotation(dtype)
创建类型注解表达式。
- 参数:dtype (Expr):数据类型。
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
通过内存访问模式信息获取头部访问地址。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_throw_last_error()
抛出 TVMGetLastError()。
- 返回:ret:返回表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)
TVM 张量核心负载运算符的内在函数。
- 参数:
- fragment (tir.Var): wmma 片段。
- m (UIntImm):wmma 片段的形状。
- n (UIntImm):wmma 片段的形状。
- k (UIntImm):wmma 片段的形状。
- index (Expr):片段索引。
- buffer_ptr (Expr):片段缓冲区指针。
- stride (Expr):片段步幅。
- layout (Literal*["row_major", *"column_major"]):片段布局。
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)
TVM 张量核心存储运算符的内在函数。
- 参数:
- fragment (tir.Var): wmma 片段。
- m (UIntImm):wmma 片段的形状。
- n (UIntImm):wmma 片段的形状。
- k (UIntImm):wmma 片段的形状。
- index (Expr):片段索引。
- buffer_ptr (Expr):片段缓冲区指针。
- stride (Expr):片段步幅。
- layout (Literal*["row_major", *"column_major"]):片段布局。
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)
TVM 张量核心 mma_sync 运算符的内在函数。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)
张量核心 bmma_sync 运算符的 TVM 内在函数。
- 参数:
- fragment_d (tir.Var) – bwmma 片段_d。
- index_d (Expr) – 片段 d 的索引。
- fragment_a (tir.Var) – bwmma 片段_a。
- index_a (Expr):fragment_a 索引。
- fragment_b (tir.Var) – bwmma 片段_b。
- index_b (Expr) – 片段_b 的索引。
- fragment_c (tir.Var) – bwmma 片段_c。
- index_c (Expr) – 片段_c 的索引。
- 返回:call : 调用表达式。
- 返回类型:PrimExpr。
tvm.tir.tvm_fill_fragment(fragment, m, n, k, index, value)
TVM 张量核心 fill_fragment 运算符的内在函数。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.ptx_mma(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, saturate, operator=None)
TVM 内在的 ptx 张量核心 mma 指令 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma。
- 参数:
- dtype (str): 结果的数据类型。
- shape (str): mma 片段的形状。
- A_layout (Literal*["row", *"col"]):被乘数片段 A 的布局。
- B_layout (Literal*["row", *"col"]):被乘数片段 B 的布局。
- A_dtype (str):被乘数片段 A 的数据类型。
- B_dtype (str): 被乘数片段 B 的数据类型。
- C_dtype (str):累加器片段 C 的数据类型。
- multiplicand_a (tir.Var):被乘数片段 A 变量。
- a_index (Expr): 被乘数片段 A 的索引。
- multiplicand_b (tir.Var):被乘数片段 B 变量。
- b_index (Expr):被乘数片段 A 的索引。
- accumulator (tir.Var):累加器片段 C 变量。
- c_index (Expr):累加器片段 C 的索引。
- saturate (bool):输出处的可选饱和度。
- operator (Optional*[Literal[****"xor",*** "and"**]]):1 位运算符。
- 返回:call: 调用表达式。
- 返回类型:PrimExpr。
tvm.tir.ptx_mma_sp(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index, metadata, meta_index, sparse_selector, saturate)
TVM 稀疏张量核心 ptx 指令的内在函数 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma。
- 参数:
- dtype (str):结果的数据类型。
- shape (str):mma 片段的形状。
- A_layout (Literal*["row", *"col"]):被乘数片段 A 的布局。
- B_layout (Literal*["row", *"col"]):被乘数片段 B 的布局。
- A_dtype (str):被乘数片段 A 的数据类型。
- B_dtype (str):被乘数片段 B 的数据类型。
- C_dtype (str):被乘数片段 C 的数据类型。
- multiplicand_a (tir.Var):被乘数片段 A 变量。
- a_index (Expr): 被乘数片段 A 的索引。
- multiplicand_b (tir.Var):被乘数片段 B 变量。
- b_index (Expr):被乘数片段 B 的索引。
- accumulator (tir.Var): 累加器片段 C 变量。
- c_index (Expr):累加器片段 C 的索引。
- metadata (Expr):操作数的元数据。
- meta_index (Expr):操作数的元数据索引。
- sparse_selector (Expr):指示存储元数据的线程的稀疏选择器。
- saturate (bool):输出处的可选饱和度。
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride)
TVM 内部函数,用于将 PTX MMA 的结果存储到目标指针中。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.mma_fill(dtype, local_size, local_ptr, offset)
TVM 内在函数,用于对 MMA 累积寄存器进行零初始化。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)
TVM 内部函数,用于从共享内存中加载 ptx 矩阵 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix。
- 参数:
- 返回:call: 调用表达式。
- 返回类型:PrimExpr。
tvm.tir.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)
TVM 内部使用 cp.async 将 ptx 异步复制到从全局到共享内存 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr
tvm.tir.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)
TVM 使用 cp.async.bulk 将 ptx 异步复制到从全局到共享内存的内在 函数 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.ptx_commit_group()
TVM ptx 异步复制提交内在函数 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group。
tvm.tir.ptx_wait_group(num)
TVM 内部用于 ptx 异步复制等待 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group。
tvm.tir.ptx_cp_async_barrier(barrier_id)
TVM 使用 cp.async.mbarrier.arrive 实现 ptx 异步复制屏障的内在机制 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive。
tvm.tir.ptx_init_barrier_thread_count(barrier_id, thread_count)
TVM 使用 mbarrier.init 来初始化线程数的 ptx 屏障 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init。
tvm.tir.ptx_arrive_barrier(barrier_id)
TVM 使用 mbarrier.arrive 实现 ptx 屏障到达的内在机制 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive。
tvm.tir.ptx_arrive_barrier_expect_tx(barrier_id, byte_count)
TVM 内在函数,用于使用 mbarrier.arrive.expect_tx 实现 ptx 屏障到达并期望 tx https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.ptx_wait_barrier(barrier_id)
TVM 使用 mbarrier.try_wait 等待 ptx 屏障的内在机制 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait。
tvm.tir.create_barriers(barrier_count)
TVM 固有创建 N 个屏障。
tvm.tir.make_filled_simdgroup_matrix(d:Var, index:PrimExpr, value:PrimExpr, col:int= 8, row:int= 8)
创建填充的 SIMDGroup 矩阵。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.simdgroup_load(d:Var, index:PrimExpr, ptr:PrimExpr, stride:PrimExpr, col:int= 8, row:int= 8, transpose_matrix:bool= False)
将数据从设备内存或线程组内存加载到 simdgroup。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.simdgroup_multiply_accumulate(d:Var, index_d:PrimExpr, a:Var, index_a:PrimExpr, b:Var, index_b:PrimExpr, c:Var, index_c:PrimExpr)
在 simdgroup 中对两个矩阵进行乘法和累加,即 d = a * b + c。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.simdgroup_store(d:PrimExpr, index:PrimExpr, ptr:PrimExpr, stride:PrimExpr, col:int= 8, row:int= 8, transpose_matrix:bool= False)
将数据从 simdgroup 存储到设备内存或线程组内存。
-
参数:
-
transpose_matrix:bool
是否转置矩阵。
-
返回:call:调用表达式。
-
返回类型:PrimExpr。
tvm.tir.vectorlow(dtype, vec)
获取向量的低位一半。
tvm.tir.vectorhigh(dtype, vec)
获取向量的高位一半。
tvm.tir.vectorcombine(dtype, vec1, vec2)
连接两个向量。
tvm.tir.infinity(dtype:str, span:Span|None= None) → Any
数据类型的无穷大值。
tvm.tir.reinterpret(dtype, value, span:Span|None= None) → Any
数据类型的无穷大值。
- 参数:
- 返回:value:重新解释 dtype 的转换值。
- 返回类型: tvm.Expr。
tvm.tir.exp(x)
取输入 x 的指数。
tvm.tir.exp2(x)
计算 2**x。
tvm.tir.exp10(x)
计算 10**x。
tvm.tir.log(x)
对输入 x 取对数。
tvm.tir.log2(x)
对输入 x 取 log2。
tvm.tir.log10(x)
对输入 x 取 log10。
tvm.tir.log1p(x)
对输入 x 取 log(x + 1)。
tvm.tir.ldexp(x1, x2)
返回 x1 * (2 ** x2)。
tvm.tir.clz(x)
计算整数 x 的前导零位。
tvm.tir.sin(x)
对输入 x 取正弦值。
tvm.tir.sinh(x)
对输入 x 取 sinh。
tvm.tir.asin(x)
取输入 x 的 asin。
tvm.tir.asinh(x)
取输入 x 的正弦值。
tvm.tir.cos(x)
取输入 x 的 cos。
tvm.tir.cosh(x)
对输入 x 取余弦值。
tvm.tir.acos(x)
对输入 x 取余数。
tvm.tir.acosh(x)
对输入 x 取余数。
tvm.tir.tan(x)
对输入 x 取 tan。
tvm.tir.tanh(x)
对输入 x 取双曲 tanh。
tvm.tir.atan(x)
对输入 x 取正切值。
tvm.tir.atan2(x1, x2)
取 arctan2(x1, x2)。
tvm.tir.atanh(x)
对输入 x 进行 atanh 处理。
tvm.tir.bitwise_and(x, y, span=None)
对两个值进行按位与运算。
- 参数:
- 返回:res:结果。
- 返回类型:PrimExpr。
tvm.tir.bitwise_not(x, span=None)
对输入值进行按位非。
tvm.tir.bitwise_or(x, y, span=None)
对两个值进行按位或操作。
- 参数:
- 返回:res:结果。
- 返回类型:PrimExpr。
tvm.tir.bitwise_xor(x, y, span=None)
对两个值进行按位异或。
- 参数:
- 返回:res:结果。
- 返回类型:PrimExpr。
tvm.tir.erf(x)
取输入 x 的高斯误差函数。
tvm.tir.sigmoid(x)
快速获取 S 形函数。
tvm.tir.sqrt(x)
对输入 x 取平方根。
tvm.tir.rsqrt(x)
取输入 x 的平方根的倒数。
tvm.tir.floor(x: PrimExprWithOp, span=None)
取浮点输入 x 的下限。
tvm.tir.ceil(x, span=None)
对浮点输入 x 取上限。
tvm.tir.hypot(x1, x2)
相当于 sqrt(x12 + x22),逐个元素。
tvm.tir.trunc(x, span=None)
获取输入的截断值。
标量 x 的截断值是最接近的整数 i,它比 x 更接近零。
tvm.tir.abs(x, span=None)
逐个获取输入元素的绝对值。
tvm.tir.round(x, span=None)
将数组元素四舍五入为最接近的整数。
tvm.tir.nextafter(x1, x2)
返回在 x1 和 x2 之间,比 x1 更接近 x2 的下一个浮点数。
tvm.tir.nearbyint(x, span=None)
将数组元素四舍五入为最接近的整数。此内在函数使用 llvm.nearbyint 而不是 llvm.round,后者速度更快,但结果与 te.round 不同。值得注意的是,nearbyint 根据舍入模式进行舍入,而 te.round (llvm.round) 则忽略该模式。有关两者之间的差异,请参阅: https: //en.cppreference.com/w/cpp/numeric/math/round https://en.cppreference.com/w/cpp/numeric/math/nearbyint。
tvm.tir.power(x, y, span=None)
x 次方 y。
tvm.tir.pow(x, y, span=None)
x 次方 y。
- 参数:
- 返回:z:结果。
- 返回类型:PrimExpr。
tvm.tir.popcount(x)
计算输入 x 中设置位的数量。
tvm.tir.fmod(x, y)
返回 x 除以 y 后的余数,其符号与 x 相同。
tvm.tir.if_then_else(cond, t, f, span=None)
条件选择表达式。
- 参数:
- 返回:result:条件表达式的结果。
- 返回类型:Node
与 Select 不同,if_then_else 不会执行不满足条件的分支。您可以使用它来防止越界访问。与 Select 不同,如果向量中某些通道的条件不同,则 if_then_else 无法进行向量化。
tvm.tir.likely(cond, span=None)
将情况标记为可能。
tvm.tir.isnan(x, span=None)
检查输入值是否为 Nan。
tvm.tir.isnullptr(x, span=None)
检查输入值是否为 nullptr。
tvm.tir.isfinite(x, span=None)
检查输入值是否有限。
tvm.tir.isinf(x, span=None)
检查输入值是否无限。
tvm.tir.copysign(x1, x2)
逐个元素地将 x1 的符号更改为 x2 的符号。
tvm.tir.div(a, b, span=None)
按照 C/C++ 语义计算 a / b。
- 参数:
- 返回:res:结果表达式。
- 返回类型:PrimExpr。
当操作数为整数时,返回 truncdiv(a, b, span)。
tvm.tir.indexdiv(a, b, span=None)
计算 floor(a / b),其中 a 和 b 为非负数。
- 参数:
- 返回:res : 结果表达式。
- 返回类型:PrimExpr。
使用此函数拆分非负索引。此函数可以利用操作数的非负性。
tvm.tir.indexmod(a, b, span=None)
计算 indexdiv 的余数。a 和 b 非负。
- 参数:
- 返回:res:结果表达式。
- 返回类型:PrimExpr。
使用此函数拆分非负索引。此函数可以利用操作数的非负性。
tvm.tir.truncdiv(a, b, span=None)
计算两个表达式的 truncdiv。
- 参数:
- 返回:res:结果表达式。
- 返回类型:PrimExpr。
这是 C 语言中的默认整数除法行为。
tvm.tir.truncmod(a, b, span=None)
计算两个表达式的 truncmod。
- 参数:
- 返回:res:结果表达式。
- 返回类型:PrimExpr。
这是 C 语言中的默认整数除法行为。
tvm.tir.floordiv(a, b, span=None)
计算两个表达式的 floordiv。
- 参数:
- 返回:res:结果表达式。
- 返回类型:PrimExpr。
tvm.tir.floormod(a, b, span=None)
计算两个表达式的 floormod。
- 参数:
- 返回:res:结果表达式。
- 返回类型:PrimExpr。
tvm.tir.ceildiv(lhs, rhs, span=None)
通用 ceildiv 运算符。
- 参数:
- 返回:op:ceildiv 运算的结果 Expr。
- 返回类型: tvm.Expr。
tvm.tir.logaddexp(a, b, span=None)
计算两个表达式的 logaddexp。
- 参数:
- 返回:res:结果表达式。
- 返回类型:PrimExpr。
tvm.tir.comm_reducer(fcombine, fidentity, name='reduce')
创建一个交换减速器用于减速。
- 参数:
- fcombine (function*(***Expr -> Expr -> Expr)):一个二元函数,以两个 Expr 作为输入并返回一个 Expr。
- fidentity (function*(***str -> Expr)):以字符串类型作为输入并返回 const Expr 的函数。
返回:reducer:在 axis 上创建 Reduce 表达式的函数。有两种使用方法:
- accept (expr, axis, where) 来在指定轴上生成一个 Reduce Expr。
- 直接使用多个 Exprs。
- 返回类型: function。
示例
n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x+y,
lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
tvm.tir.min(expr, axis, where=None, init=None, args*)
在轴上创建最小表达式。
- 参数:
- 返回:value:结果值。
- 返回类型:PrimExpr。
示例
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
# 使用这个最小值归约器(min reducer)有两种方式:
# 方式 1:接受 (expr, axis, where) 参数来生成一个归约表达式(Reduce Expr)
# tvm.min 表示 tvm.te.min 或 tvm.tir.min。
B = te.compute((m,), lambda i: tvm.min(A[i, k], axis=k), name="B")
# 方式 2:直接用于多个表达式:
min_res = tvm.min(m, n)
tvm.tir.max(expr, axis, where=None, init=None, args*)
在轴上创建最大表达式。
- 参数:
- 返回:value:结果值。
- 返回类型:PrimExpr。
示例
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
# 使用这个最大值归约器(max reducer)有两种方式:
# 方式 1:接受 (expr, axis, where) 参数来生成一个归约表达式(Reduce Expr)
# tvm.max 表示 tvm.te.max 或 tvm.tir.max。
B = te.compute((m,), lambda i: tvm.max(A[i, k], axis=k), name="B")
# 方式 2:直接用于多个表达式:
max_res = tvm.max(m, n)
tvm.tir.sum(expr, axis, where=None, init=None, args*)
在轴上创建一个求和表达式。
- 参数:
- 返回:value:结果值。
- 返回类型:PrimExpr。
示例
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
# 使用这个求和归约器(sum reducer)有两种方式:
# 方式 1:接受 (expr, axis, where) 参数来生成一个归约表达式(Reduce Expr)
# tvm.sum 表示 tvm.te.sum 或 tvm.tir.sum。
B = te.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
# 方式 2:直接用于多个表达式:
sum_res = tvm.sum(m, n)
tvm.tir.q_multiply_shift(x, y, q, s)
对两个 Q 数 x 和 y 执行乘法,然后右移 s。数学表达式为:
out = round(xy2^-s) 输出 = 舍入(xy2^-s) 。
有关 Q 数的更多信息,请参见:https://en.wikipedia.org/wiki/Q_(number_format 舍入规则是舍入到最接近的值,将一半向上舍入(即,round(x.1) = x 和 round (x.5) = x+1)。
- 参数:
- 返回:y:结果。
- 返回类型:PrimExpr。
tvm.tir.q_multiply_shift_per_axis(x:PrimExpr, y:PrimExpr, ls:PrimExpr, rs:PrimExpr, q:IntImm, is_lshift_required:IntImm, is_rshift_required:IntImm)
执行两个 Q 数字 x 和 y 之间的乘法。
tvm.tir.shift_left(x, y, span=None)
返回 x 左移 y 位的结果。
tvm.tir.shift_right(x, y, span=None)
返回 x 右移 y 位的结果。
tvm.tir.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)
分配临时工作空间的后端函数。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.TVMBackendFreeWorkspace(device_type, device_id, ptr)
后端函数用于释放临时工作空间。
- 参数:
- 返回:call:调用表达式。
- 返回类型:PrimExpr。
tvm.tir.start_profile_intrinsic(id)
启动配置文件内在。:param id:内在 id。:type id:int。
- 返回:call: 调用表达式。
- 返回类型:PrimExpr。
tvm.tir.end_profile_intrinsic(id)
结束配置文件内在:param id:内在 id:type id:int。
- 返回:call*:*调用表达式。
- 返回类型:PrimExpr。
tvm.tir.vscale()
获取目标的 vscale 值。它将被降低到 llvm.vscale 内部函数 ( https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic ) :returns: call – tir.Call 到 vscale 内部函数 :rtype: PrimExpr。
tvm.tir.get_active_lane_mask(dtype, base, limit)
给定上限(限制)和当前值(基数)计算谓词掩码。
它将被降低到 llvm.get.active.lane.mask 内在函数。(https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)。
tvm.tir.get_vscale_expr(dtype:str| dtype, min_size:int= 128) → PrimExpr
创建依赖于数据类型的可扩展表达式。
tvm.tir.dp4a(vec1, vec2, acc=0)
两个 int8x4 向量的点积并添加一个可选累加器。
tvm.tir.ignore_loop_partition(predicate) → PrimExpr
注解谓词不被视为循环分区的目标条件。
tvm.tir.add(lhs, rhs, span=None)
通用加法运算符。
- 参数:
- 返回:op:加法运算的结果 Expr。
- 返回类型: tvm.Expr。
tvm.tir.subtract(lhs, rhs, span=None)
通用减法运算符。
- 参数:
- 返回:op:减法运算的结果 Expr。
- 返回类型: tvm.Expr。
tvm.tir.multiply(lhs, rhs, span=None)
通用乘法运算符。
- 参数:
- 返回:op:乘法运算的结果 Expr。
- 返回类型: tvm.Expr。
class tvm.tir.BlockDependenceInfo(mod:IRModule|PrimFunc)
使用两个核心对象 BlockScope 和 StmtSRef 帮助构建和查询块级依赖关系的对象。
公开的数据结构包括:1)sref2scope:从 srefs 映射到其对应的 BlockScope 2)stmt2ref:从块映射到对应的 StmtSRefs。
请注意,此对象不存储循环的 SRef,因为其目的仅用于公开块级依赖关系。这带来了一个优势:给定块 sref 的作用域块(父块)可以直接通过 sref->parent 进行访问。
get_sref(block:Block) → StmtSRef | None
返回指向该块的相应 sref
get_block_scope(block_sref:StmtSRef) → BlockScope
获取与块 sref 对应的 BlockScope/。
tvm.tir.build(mod:PrimFunc|IRModule, target:str|Target|None= None, pipeline:None|str|Pass= 'default')
构建一个带有签名的函数,为结合目标信息的设备生成代码。
- 参数:
- 返回: 结合主机和设备代码的模块。
- 返回类型: tvm.runtime.Module。
tvm.tir.get_tir_pipeline(name:str= 'default', **kwargs) → Pass
按名称获取预构建管道。
- 参数:name (Optional[str]):管道的名称。
tvm.tir.get_default_tir_pipeline(target:Target) → Pass
获取给定目标的默认 TIR 管道。
class tvm.tir.PyStmtExprVisitor
Python StmtExprVisitor 用于为 Stmt 和 PrimExpr 定义自定义访问者。
用户可以自定义任意的访问函数。
visit_stmt(stmt:Stmt) → None
访问AttrStmt。
- 参数:stmt (Stmt):要访问的 Stmt。
visit_expr(expr:PrimExpr) → None
访问 PrimExpr。
- 参数:expr (PrimExpr):要访问的 PrimExpr。
visit_attr_stmt_(op:AttrStmt) → None
访问 AttrStmt,用户可以自定义该函数,在 C++ 端覆盖 VisitStmt_(const AttrStmtNode* op)。
- 参数:op (AttrStmt):要访问的 AttrStmt。
visit_if_then_else_(op:IfThenElse) → None
访问 IfThenElse,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const IfThenElseNode* op)。
- 参数:op (IfThenElse):要访问的 IfThenElse。
visit_let_stmt_(op:LetStmt) → None
访问 LetStmt。用户可以自定义此函数,在 C++ 端覆盖 VisitStmt_(const LetStmtNode* op)。
- 参数:op (LetStmt):要访问的 LetStmt。
visit_for_(op:For) → None
访问 For,用户可以自定义该函数,在 C++ 端覆盖 VisitStmt_(const ForNode* op)。
- 参数:op (For):要访问的 For。
visit_while_(op:While) → None
访问 While。用户可以自定义此函数,在 C++ 端覆盖 VisitStmt_(const WhileNode* op)。
- 参数:op (While):需要访问的 While 部分。
visit_allocate_(op:Allocate) → None
Visit Allocate,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const AllocateNode* op)。
- 参数:op (Allocate):要访问的分配。
visit_allocate_const_(op:AllocateConst) → None
访问 AllocateConst,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const AllocateConstNode* op)。
- 参数:op (AllocateConst):要访问的 AllocateConst。
visit_decl_buffer_(op:DeclBuffer) → None
访问 DeclBuffer,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const DeclBufferNode* op)。
- 参数:op (DeclBuffer):要访问的 DeclBuffer。
visit_buffer_store_(op:BufferStore) → None
访问 BufferStore,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const BufferStoreNode* op)。
- 参数:op (BufferStore):要访问的 BufferStore。
visit_buffer_realize_(op:BufferRealize) → None
访问 BufferRealize,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const BufferRealizeNode* op)。
- 参数:op (BufferRealize):要访问的 BufferRealize。
visit_assert_stmt_(op:AssertStmt) → None
访问 AssertStmt,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const AssertStmtNode* op)。
- 参数:op (AssertStmt):要访问的 AssertStmt。
visit_seq_stmt_(op:SeqStmt) → None
访问 SeqStmt,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const SeqStmtNode* op)。
- 参数:op (SeqStmt):要访问的 SeqStmt。
visit_evaluate_(op:Evaluate) → None
Visit Evaluate,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const EvaluateNode* op)。
- 参数:op (Evaluate):要访问的评估。
visit_block_(op:Block) → None
访问区块。用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const BlockNode* op)。
visit_block_realize_(op:BlockRealize) → None
访问 BlockRealize。用户可以自定义此函数,在 C++ 端覆盖 VisitStmt_(const BlockRealizeNode* op)。
- 参数:op (BlockRealize):要访问的 BlockRealize。
visit_var_(op:Var) → None
参观 Var。
用户可以自定义该函数,在 C++端覆盖 VisitVar_(const VarNode* op)。
- 参数:op (tir.Var):要访问的 tir.Var。
visit_size_var_(op:SizeVar) → None
访问 SizeVar。
用户可以自定义该函数,在 C++端覆盖 VisitSizeVar_(const SizeVarNode* op)。
- 参数:op (SizeVar):要访问的 SizeVar。
visit_buffer_load_(op:BufferLoad) → None
访问 BufferLoad。
用户可以自定义该函数,在 C++端覆盖 VisitBufferLoad_(const BufferLoadNode* op)。
- 参数:op (BufferLoad):要访问的 BufferLoad。
visit_producer_load_(op:ProducerLoad) → None
访问 ProducerLoad。
用户可以自定义该函数,在 C++端覆盖 VisitProducerLoad_(const ProducerLoadNode* op)。
- 参数:op (ProducerLoad):要访问的 ProducerLoad。
visit_let_(op:Let) → None
访问 Let。
用户可以自定义该函数,在 C++端覆盖 VisitLet_(const LetNode* op)。
- 参数:op (Let):要访问的 Let。
visit_call_(op:Call) → None
访问调用。
用户可以自定义该函数,在 C++端覆盖 VisitCall_(const CallNode* op)。
- 参数:op (tir.Call):要访问的 tir.Call。
visit_add_(op:Add) → None
访问添加。
用户可以自定义该函数,在 C++端覆盖 VisitAdd_(const AddNode* op)。
- 参数:op (Add):要访问的 Add。
visit_sub_(op:Sub) → None
访问 Sub。
用户可以自定义该函数,在 C++端覆盖 VisitSub_(const SubNode* op)。
- 参数:op (Sub):要访问的 Sub。
visit_mul_(op:Mul) → None
参观 Mul。
用户可以自定义该函数,在 C++端覆盖 VisitMul_(const MulNode* op)。
- 参数:op (Mul):要访问的 Mul。
visit_div_(op:Div) → None
访问 Div。
用户可以自定义该函数,在 C++端覆盖 VisitDiv_(const DivNode* op)。
- 参数:op (Div):要访问的 Div。
visit_mod_(op:Mod) → None
访问 Mod。
用户可以自定义该函数,在 C++端覆盖 VisitMod_(const ModNode* op)。
- 参数:op (Mod):要访问的 Mod。
visit_floor_div_(op:FloorDiv) → None
访问 FloorDiv。
用户可以自定义该函数,在 C++端覆盖 VisitFloorDiv_(const FloorDivNode* op)。
- 参数:op (FloorDiv):要访问的 FloorDiv。
visit_floor_mod_(op:FloorMod) → None
访问 FloorMod。
用户可以自定义该函数,在 C++端覆盖 VisitFloorMod_(const FloorModNode* op)。
- 参数:op (FloorMod):要访问的 FloorMod。
visit_min_(op:Min) → None
访问 Min。
用户可以自定义该函数,在 C++端覆盖 VisitMin_(const MinNode* op)。
- 参数:op (Min):要访问的 Min。
visit_max_(op:Max) → None
访问 Max。
用户可以自定义该函数,在 C++端覆盖 VisitMax_(const MaxNode* op)。
- 参数:op (Max):要访问的最大值。
visit_eq_(op:EQ) → None
访问 EQ。
用户可以自定义该函数,在 C++端覆盖 VisitEQ_(const EQNode* op)。
- 参数:op (EQ):要访问的 EQ。
visit_ne_(op:NE) → None
访问 NE。
用户可以自定义该函数,在 C++端覆盖 VisitNE_(const NENode* op)。
- 参数:op (NE):要访问的 NE。
visit_lt_(op:LT) → None
访问 LT。
用户可以自定义该函数,在 C++端覆盖 VisitLT_(const LTNode* op)。
- 参数:op (LT): 要访问的 LT。
visit_le_(op:LE) → None
访问 LE。
用户可以自定义该函数,在 C++端覆盖 VisitLE_(const LENode* op)。
- 参数:op (LE): 要访问的 LE。
visit_gt_(op:GT) → None
访问 GT。
用户可以自定义该函数,在 C++端覆盖 VisitGT_(const GTNode* op)。
- 参数:op (GT):要访问的 GT。
visit_ge_(op:GE) → None
访问 GE。
用户可以自定义该函数,在 C++端覆盖 VisitGE_(const GENode* op)。
- 参数:op (GE):要访问的 GE。
visit_and_(op:And) → None
访问 And。
用户可以自定义该函数,在 C++端覆盖 VisitAnd_(const AndNode* op)。
- 参数:op (And):要访问的 And。
visit_or_(op:Or) → None
访问 Or。
用户可以自定义该函数,在 C++端覆盖 VisitOr_(const OrNode* op)。
- 参数:op (Or):要访问的 Or。
visit_reduce_(op:Reduce) → None
访问 Reduce。
用户可以自定义该函数,在 C++端覆盖 VisitReduce_(const ReduceNode* op)。
- 参数:op (Reduce):要访问的 Reduce。
visit_cast_(op:Cast) → None
访问 Cast。
用户可以自定义该函数,在 C++端覆盖 VisitCast_(const CastNode* op)。
- 参数:op (Cast):要访问的 Cast。
visit_not_(op:Not) → None
不访问。
用户可以自定义该函数,在 C++端覆盖 VisitNot_(const NotNode* op)。
- 参数:op (Not):不可访问。
visit_select_(op:Select) → None
访问选择。
用户可以自定义该函数,在 C++端覆盖 VisitSelect_(const SelectNode* op)。
- 参数:op (Select):要访问的选择。
visit_ramp_(op:Ramp) → None
参观 Ramp。
用户可以自定义该函数,在 C++端覆盖 VisitRamp_(const RampNode* op)。
- 参数:op (Ramp):要访问的坡道。
visit_broadcast_(op:Broadcast) → None
访问广播。
用户可以自定义该函数,在 C++端覆盖 VisitBroadcast_(const BroadcastNode* op)。
visit_shuffle_(op:Shuffle) → None
访问 Shuffle。
用户可以自定义该函数,在 C++端覆盖 VisitShuffle_(const ShuffleNode* op)。
visit_int_imm_(op:IntImm) → None
访问 IntImm。
用户可以自定义该函数,在 C++端覆盖 VisitIntImm_(const IntImmNode* op)。
- 参数:op (IntImm):要访问的 IntImm。
visit_float_imm_(op:FloatImm) → None
访问 FloatImm。
用户可以自定义该函数,在 C++端覆盖 VisitFloatImm_(const FloatImmNode* op)。
- 参数:op (FloatImm): 要访问的 FloatImm。
visit_string_imm_(op:StringImm) → None
访问 StringImm。
用户可以自定义该函数,在 C++端覆盖 VisitStringImm_(const StringImmNode* op)。
class tvm.tir.PyStmtExprMutator
Python StmtExprMutator 用于为 Stmt 和 PrimExpr 定义自定义变量。
用户可以自定义任意的访问函数。
visit_expr(expr:PrimExpr) → PrimExpr
访问 PrimExpr。用户可以自定义此函数,在 C++ 端覆盖 VisitExpr(const PrimExpr& expr)。
visit_stmt(stmt:Stmt) → Stmt
Visit Stmt,用户可以自定义该函数,在 C++端覆盖 VisitStmt(const Stmt& stmt)。
visit_attr_stmt_(op:AttrStmt) → Stmt
访问 AttrStmt,用户可以自定义该函数,在 C++ 端覆盖 VisitStmt_(const AttrStmtNode* op)。
visit_if_then_else_(op:IfThenElse) → Stmt
访问 IfThenElse,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const IfThenElseNode* op)。
- 参数:op (IfThenElse): 要访问的 IfThenElse。
- 返回:result: 变异的 Stmt。
- 返回类型:Stmt。
visit_let_stmt_(op:LetStmt) → Stmt
访问 LetStmt。用户可以自定义此函数,在 C++ 端覆盖 VisitStmt_(const LetStmtNode* op)。
visit_for_(op:For) → Stmt
访问 For。用户可以自定义该函数,在 C++ 端覆盖 VisitStmt_(const ForNode* op)。
visit_while_(op:While) → Stmt
访问 While。用户可以自定义此函数,在 C++ 端覆盖 VisitStmt_(const WhileNode* op)。
visit_allocate_(op:Allocate) → Stmt
Visit Allocate,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const AllocateNode* op)。
visit_allocate_const_(op:AllocateConst) → Stmt
访问 AllocateConst,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const AllocateConstNode* op)。
- 参数:op (AllocateConst):要访问的 AllocateConst。
- 返回:result:变异的 Stmt。
- 返回类型:Stmt。
visit_decl_buffer_(op:DeclBuffer) → Stmt
访问 DeclBuffer,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const DeclBufferNode* op)。
- 参数:op (DeclBuffer):要访问的 DeclBuffer。
- 返回:result:变异的 Stmt。
- 返回类型:Stmt。
visit_buffer_store_(op:BufferStore) → Stmt
访问 BufferStore,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const BufferStoreNode* op)。
- 参数:op (BufferStore):要访问的 BufferStore。
- 返回:result:变异的 Stmt。
- 返回类型:Stmt。
visit_buffer_realize_(op:BufferRealize) → Stmt
访问 BufferRealize,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const BufferRealizeNode* op)。
- 参数:op (BufferRealize):要访问的 BufferRealize。
- 返回:result:变异的 Stmt。
- 返回类型:Stmt。
visit_assert_stmt_(op:AssertStmt) → Stmt
访问 AssertStmt,用户可以自定义该函数,在 C++ 端覆盖 VisitStmt_(const AssertStmtNode* op)。
- 参数:op (AssertStmt): 要访问的 AssertStmt。
- 返回:result:变异的 Stmt。
- 返回类型:Stmt。
visit_seq_stmt_(op:SeqStmt) → Stmt
访问 SeqStmt,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const SeqStmtNode* op)。
visit_evaluate_(op:Evaluate) → Stmt
Visit Evaluate,用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const EvaluateNode* op)。
visit_block_(op:Block) → Stmt
访问区块。用户可以自定义该函数,在 C++端覆盖 VisitStmt_(const BlockNode* op)。
visit_block_realize_(op:BlockRealize) → Stmt
访问 BlockRealize。用户可以自定义此函数,在 C++ 端覆盖 VisitStmt_(const BlockRealizeNode* op)。
- 参数:op (BlockRealize):要访问的 BlockRealize。
- 返回:result: 变异的 Stmt。
- 返回类型:Stmt。
visit_var_(op:Var) → PrimExpr
参观 Var。
用户可以自定义该函数,在 C++端覆盖 VisitVar_(const VarNode* op)。
visit_size_var_(op:SizeVar) → PrimExpr
访问 SizeVar。
用户可以自定义该函数,在 C++端覆盖 VisitSizeVar_(const SizeVarNode* op)。
visit_buffer_load_(op:BufferLoad) → PrimExpr
访问 BufferLoad。
用户可以自定义该函数,在 C++端覆盖 VisitBufferLoad_(const BufferLoadNode* op)。
- 参数:op (BufferLoad):要访问的 BufferLoad。
- 返回:result:变异的 PrimExpr。
- 返回类型:PrimExpr。
visit_producer_load_(op:ProducerLoad) → PrimExpr
访问 ProducerLoad。
用户可以自定义该函数,在 C++端覆盖 VisitProducerLoad_(const ProducerLoadNode* op)。
- 参数:op (ProducerLoad):要访问的 ProducerLoad。
- 返回:result:变异的 PrimExpr。
- 返回类型:PrimExpr。
visit_let_(op:Let) → PrimExpr
访问 Let.
用户可以自定义该函数,在 C++端覆盖 VisitLet_(const LetNode* op)。
visit_call_(op:Call) → PrimExpr
访问调用。
用户可以自定义该函数,在 C++端覆盖 VisitCall_(const CallNode* op)。
visit_add_(op:Add) → PrimExpr
访问添加。
用户可以自定义该函数,在 C++端覆盖 VisitAdd_(const AddNode* op)。
visit_sub_(op:Sub) → PrimExpr
访问 Sub。
用户可以自定义该函数,在 C++端覆盖 VisitSub_(const SubNode* op)。
visit_mul_(op:Mul) → PrimExpr
参观 Mul。
用户可以自定义该函数,在 C++端覆盖 VisitMul_(const MulNode* op)。
visit_div_(op:Div) → PrimExpr
访问 Div。
用户可以自定义该函数,在 C++端覆盖 VisitDiv_(const DivNode* op)。
visit_mod_(op:Mod) → PrimExpr
访问 Mod。
用户可以自定义该函数,在 C++端覆盖 VisitMod_(const ModNode* op)。
visit_floor_div_(op:FloorDiv) → PrimExpr
访问 FloorDiv。
用户可以自定义该函数,在 C++端覆盖 VisitFloorDiv_(const FloorDivNode* op)。
visit_floor_mod_(op:FloorMod) → PrimExpr
访问 FloorMod。
用户可以自定义该函数,在 C++端覆盖 VisitFloorMod_(const FloorModNode* op)。
visit_min_(op:Min) → PrimExpr
访问 Min。
用户可以自定义该函数,在 C++端覆盖 VisitMin_(const MinNode* op)。
visit_max_(op:Max) → PrimExpr
拜访 Max。
用户可以自定义该函数,在 C++端覆盖 VisitMax_(const MaxNode* op)。
visit_eq_(op:EQ) → PrimExpr
访问 EQ。
用户可以自定义该函数,在 C++端覆盖 VisitEQ_(const EQNode* op)。
visit_ne_(op:NE) → PrimExpr
访问NE。
用户可以自定义该函数,在 C++端覆盖 VisitNE_(const NENode* op)。
visit_lt_(op:LT) → PrimExpr
访问 LT。
用户可以自定义该函数,在 C++端覆盖 VisitLT_(const LTNode* op)。
visit_le_(op:LE) → PrimExpr
访问 LE。
用户可以自定义该函数,在 C++端覆盖 VisitLE_(const LENode* op)。
visit_gt_(op:GT) → PrimExpr
访问 GT。
用户可以自定义该函数,在 C++端覆盖 VisitGT_(const GTNode* op)。
visit_ge_(op:GE) → PrimExpr
参观 GE。
用户可以自定义该函数,在 C++端覆盖 VisitGE_(const GENode* op)。
visit_and_(op:And) → PrimExpr
访问 And。
用户可以自定义该函数,在 C++端覆盖 VisitAnd_(const AndNode* op)。
visit_or_(op:Or) → PrimExpr
访问 Or。
用户可以自定义该函数,在 C++端覆盖 VisitOr_(const OrNode* op)。
visit_reduce_(op:Reduce) → PrimExpr
访问 Reduce。
用户可以自定义该函数,在 C++端覆盖 VisitReduce_(const ReduceNode* op)。
visit_cast_(op:Cast) → PrimExpr
访问 Cast。
用户可以自定义该函数,在 C++端覆盖 VisitCast_(const CastNode* op)。
visit_not_(op:Not) → PrimExpr
不访问。
用户可以自定义该函数,在 C++端覆盖 VisitNot_(const NotNode* op)。
visit_select_(op:Select) → PrimExpr
访问选择。
用户可以自定义该函数,在 C++端覆盖 VisitSelect_(const SelectNode* op)。
visit_ramp_(op:Ramp) → PrimExpr
参观 Ramp。
用户可以自定义该函数,在 C++端覆盖 VisitRamp_(const RampNode* op)。
visit_broadcast_(op:Broadcast) → PrimExpr
访问广播。
用户可以自定义该函数,在 C++端覆盖 VisitBroadcast_(const BroadcastNode* op)。
visit_shuffle_(op:Shuffle) → PrimExpr
访问 Shuffle。
用户可以自定义该函数,在 C++端覆盖 VisitShuffle_(const ShuffleNode* op)。
visit_int_imm_(op:IntImm) → PrimExpr
访问 IntImm。
用户可以自定义该函数,在 C++端覆盖 VisitIntImm_(const IntImmNode* op)。
visit_float_imm_(op:FloatImm) → PrimExpr
访问 FloatImm。
用户可以自定义该函数,在 C++端覆盖 VisitFloatImm_(const FloatImmNode* op)。
visit_string_imm_(op:StringImm) → PrimExpr
访问 StringImm。
用户可以自定义该函数,在 C++端覆盖 VisitStringImm_(const StringImmNode* op)。