张量程序抽象
在深入了解 TensorIR 之前,我们先介绍什么是原始张量函数(primitive tensor function)。原始张量函数指的是代表单个「计算单元」的函数。例如,一个卷积操作可以是一个原始张量函数,一个融合了卷积与 relu 的操作也可以是一个原始张量函数。通常,原始张量函数的抽象包含以下几个要素:多维缓冲区(multi-dimensional buffers)、驱动张量计算的循环嵌套(loop nests)以及计算语句本身。
from tvm.script import tir as T
@T.prim_func
def main(
A: T.Buffer((128,), "float32"),
B: T.Buffer((128,), "float32"),
C: T.Buffer((128,), "float32"),
) -> None:
for i in range(128):
with T.block("C"):
vi = T.axis.spatial(128, i)
C[vi] = A[vi] + B[vi]
张量程序的关键元素
上面展示的原始张量函数计算了两个向量的逐元素加法。这个函数:
- 接收三个多维缓冲区作为参数,并生成一个多维缓冲区作为输出。
- 包含一个简单的循环嵌套
i
来驱动计算过程。 - 包含一个计算语句,用于计算两个向量的逐元素加和。
TensorIR 中的额外结构
需要注意的是,我们无法对程序执行任意的变换,因为某些计算依赖于循环的顺序。幸运的是,我们关注的大多数原始张量函数具有良好的性质,例如循环迭代之间相互独立。例如,前面的程序中包含了块(block)和迭代(iteration)的注解信息:
- 块注解
with T.block("C")
表示该块是调度中指定的基本计算单元。一个 block 可以只包含一条计算语句,也可以包含带有循环的多条语句,甚至是一些无法透视的内在指令(如 Tensor Core 指令)。 - 迭代注解
T.axis.spatial
表示变量vi
映射到i
,并且所有迭代是独立的。
虽然这些信息对于执行特定程序并非关键,但在变换程序时却非常有用。因此,只要我们访问了索引从 0 到 128 的所有元素,就可以放心地并行化或重排与 vi
相关的循环。