InferBound Pass
InferBound pass 在 normalize 之后、ScheduleOps build_module.py 之前运行。InferBound 的主要工作是创建 bounds map,为程序中的每个 IterVar 指定一个 Range。接下来这些 bounds 会传递给 ScheduleOps,用于设置 For 循环的范围,参阅 MakeLoopNest,以及设置分配缓冲区的大小(BuildRealize)以及其他用途。
InferBound 的输出是从 IterVar 到 Range 的映射:
Map<IterVar, Range> InferBound(const Schedule& sch);
回顾 Range 和 IterVar 类:
namespace HalideIR {
namespace IR {
class RangeNode : public Node {
public:
Expr min;
Expr extent;
// 剩余部分省略
};
}}
namespace tvm {
class IterVarNode : public Node {
public:
Range dom;
Var var;
// 剩余部分省略
};
}
注意,IterVarNode 还包含一个 Range dom
。这个 dom
的值是否有意义,取决于 IterVar 的创建时间。例如,调用 tvm.compute
时,会为每个 axis 和 reduce axis 创建一个 IterVar ,其中 dom 等于调用 tvm.compute
时提供的 shape。
另一方面,调用 tvm.split
时,会为内轴和外轴 创建 IterVars,但这些 IterVars 没有被赋予有意义的 dom
值。
在任何情况下,IterVar 的 dom
成员在 InferBound 期间都不会被修改。但 IterVar 的 dom
成员有时用作 Range InferBound 计算的默认值。
为了理解 InferBound pass,我们先来看一下 TVM 代码库概念。
InferBound 接收一个参数,即 Schedule。这个 schedule 对象及其成员包含正在编译的程序的所有信息。
TVM schedule 由 stage 组成。每个 stage 只有一个 Operation,例如 ComputeOp 或 TensorComputeOp。每个 Operation 都有一个 root_iter_vars 列表,在 ComputeOp 的情况下,它由 axis IterVar 和 reduce axis IterVar 组成。
每个 Operation 还包含许多其他 IterVar,它们通过 Operation 的 IterVarRelations 列表相关联。每个 IterVarRelation 代表 schedule 中的 split、fuse 或 rebase。例如,在 split 的情况下,IterVarRelation 指定被拆分的父级 IterVar,以及两个子级 IterVar:内部和外部。
namespace tvm {
class ScheduleNode : public Node {
public:
Array<Operation> outputs;
Array<Stage> stages;
Map<Operation, Stage> stage_map;
// 剩余部分省略
};
class StageNode : public Node {
public:
Operation op;
Operation origin_op;
Array<IterVar> all_iter_vars;
Array<IterVar> leaf_iter_vars;
Array<IterVarRelation> relations;
// 剩余部分省略
};
class OperationNode : public Node {
public:
virtual Array<IterVar> root_iter_vars();
virtual Array<Tensor> InputTensors();
// 剩余部分省略
};
class ComputeOpNode : public OperationNode {
public:
Array<IterVar> axis;
Array<IterVar> reduce_axis;
Array<Expr> body;
Array<IterVar> root_iter_vars();
// 剩余部分省略
};
}
在 TVM 的 context 中,张量表示操作的输出。
class TensorNode : public Node {
public:
// 源操作,可以是 None
// 这个 Tensor 是这个 op 输出的
Operation op;
// 源操作的输出索引
int value_index;
};
上面的 Operation 类声明中 ,可以看到每个 operation 还有一个 InputTensor 列表。因此,schedule 的各个 stage 形成了一个 DAG,其中每个 stage 都是图中的一个节点。若 Stage B 的 operation 有一个输入张量,其源操作是 Stage A 的 op,那么图中从 Stage A 到 Stage B 有一个 edge。简而言之,若 B 消耗了一个由 A 产生的张量,则从 A 到 B 会出现一个 edge。参见下图。这个计算图是在 InferBound 开始时调用 CreateReadGraph 创建的。
InferBound 使 pass 遍历计算图,每个 stage 访问一次。InferBound 从输出 stage 开始(即上图中的实心蓝色节点),然后向上移动(在边缘的相反方向上)。这是通过对计算图的节点执行反向拓扑排序来实现的。因此,当 InferBound 访问一个 stage 时,它的每个 consumer stage 都已经被访问过。
InferBound pass 如以下伪代码所示:
Map<IterVar, Range> InferBound(const Schedule& sch) {
Array<Operation> outputs = sch->get_outputs();
G = CreateGraph(outputs);
stage_list = sch->reverse_topological_sort(G);
Map<IterVar, Range> rmap;
for (Stage s in stage_list) {
InferRootBound(s, &rmap);
PassDownDomain(s, &rmap);
}
return rmap;
}
InferBound pass 有两个不是很明显的属性:
- InferBound 访问一个 stage 后,stage 中所有 IterVar 的范围都会在
rmap
中设置。 - 每个 IterVar 的 Range 只在
rmap
中设置一次后就不会再变了。
因此,仍然需要解释 InferBound 在访问 stage 时的主要作用。从上面的伪代码中可以看 出,InferBound 在每个 stage 调用了两个函数:InferRootBound 和 PassDownDomain。InferRootBound 的目的是设置 stage 每个 root_iter_var 的 Range(在 rmap
中)。(注意:InferRootBound 不设置任何其他 IterVar 的 Range,只设置属于 root_iter_vars 的那些)。PassDownDomain 的目的是将此信息传播到 stage 的其余 IterVars。当 PassDownDomain 返回时,stage 的所有 IterVars 在 rmap
中都有已知的 Range。
文档的其余部分将深入探讨 InferRootBound 和 PassDownDomain 的详细信息。由于 PassDownDomain 描述起来更简单,因此首先介绍它。
IterVar Hyper-graph
如上所述,InferBound pass 遍历 stage 计算图。但是,在每个 stage 中都有另一个节点为 IterVars 的计算图。 InferRootBound 和 PassDownDomain 在这些 IterVar 计算图上传递消息。
回想一下,stage 的所有 IterVar 都由 IterVarRelations 关联。一个 stage 的 IterVarRelations 构成一个有向无环 hyper-graph,计算图中每个节点对应一个 IterVar,每条 hyper-edge 对应一个 IterVarRelation。也可以将这个 hyper-graph 表示为 DAG,如下图所示更易于可视化。
上图显示了一个 stage 的 IterVar hyper-graph。该 stage 有一个 root_iter_var i
,它已被拆分,生成的内轴 i.inner
已再次拆分。该 stage 的 leaf_iter_vars 为绿色图示:i.outer
、i.inner.outer
和 i.inner.inner
。
消息传递函数被命名为「PassUp」或「PassDown」,取决于消息是从 DAG 中的子代传递给其 父代(「PassUp」),还是从父代传递给其子代(「PassDown」)。例如,上图左侧的大箭头显示 PassDownDomain 从根 IterVar i
向其子 i.outer
和 i.inner
发送消息。
PassDownDomain
PassDownDomain 的作用是为 root_iter_vars 取 InferRootBound 产生的 Range,并设置 stage 中所有其他 IterVars 的 Range。
PassDownDomain 遍历 stage 的 IterVarRelations。IterVarRelation 有三种可能的类型:split、fuse 和 rebase。最有趣的案例(因为它还有改进空间)是表示 split 的 IterVarRelations。
根据父级 IterVar 的已知 Range,来设置 split 的内部 IterVar 和外部 IterVar 的 Range,如下:
rmap[split->inner] = Range::FromMinExtent(0, split->factor)
rmap[split->outer] = Range::FromMinExtent(0, DivCeil(rmap[split->parent]->extent, split->factor))
当 split->factor
没有平均划分父节点的范围时,就有机会收紧 InferBound 产生的边界。假设 parent 的范围是 20,split 因子是 16。那么在外部循环的第二次迭代中,内部循环只需要进行 4 次迭代,而非 16 次。如果 PassDownDomain 可以设置 split->inner
的范围为 min (split->factor, rmap[split->parent]->extent - (split->outer * split->factor))
,则内部变量的范围将根据正在执行的外部循环的迭代进行适当调整。
对于 Fuse 关系,根据已知的内外 IterVar 的 Range 设置 fuse 后的 IterVar 的 Range,如下:
rmap[fuse->fused] = Range::FromMinExtent(0, rmap[fuse->outer]->extent * rmap[fuse->inner]->extent)
InferRootBound
InferBound 调用 InferRootBound,然后在 stage 计算图中的每个 stage 调用 PassDownDomain。InferRootBound 的目的是设置 Stage 操作的每个 root_iter_var 的 Range。这些 Range 会用 PassDownDomain 传到 Stage 的其余 IterVars。注意,InferRootBound 不会设置任何其他 IterVar 的 Range,仅设置属于 Stage 的 root_iter_vars 的那些。
若 Stage 是输出 Stage 或占位符,InferRootBound 只需将 root_iter_var Range 设置为其默认值。root_iter_var 的默认 Range 取自 IterVar 的 dom
成员(参阅上面的 IterVarNode 类声明)。
否则,InferRootBound 将遍历 stage 的 consumer。为每个 consumer 的 IterVar 创建 IntSet,如下所示。
阶段 1)IntSet 为 consumer 的 leaf_iter_vars 初始化,并通过 PassUpDomain 传到 consumer 的 root_iter_vars(阶段 2)。这些 IntSet 用于创建 consumer stage(阶段 3)的输入张量的 TensorDom。最后,一旦所有 consumer 都处理完毕,InferRootBound 调用 GatherBound,根据 TensorDoms(阶段 4)设置 stage 的 root_iter_vars 的 Range。
这个过程看起来很复杂。原因之一是一个 stage 可以有多个 consumer。每个 consumer 都有不同的要求,且必须以某种方式整合。类似地,该 stage 可能会输出多个张量,并且每个 consumer 只使用这些张量的特定子集。此外,即使 consumer 使用特定的张量,它也可能不会使用张量的所有元素。
如上所述,consumer 可能只需要每个张量中的少量元素。consumer 可以看成是针对输出张量某些区域,向 stage 发出的请求。阶段 1-3 的工作是建立每个 consumer 所需的每个输出张量的区域。
IntSet
在 InferRootBound 期间,Range 被转换为 IntSet,并且在 IntSet 上执行消息传递。因此,了解 Range 和 IntSet 之间的区别很重要。「IntSet」这个名称表明它可以表示任意整数集,例如 A = 13。这肯定比 Range 更具表现力,Range 只表示一组连续的整数,例如 B = 12。
然而,目前 IntSet 只有三种类型:IntervalSets、StrideSets 和 ModularSets。与 Range 类似,IntervalSets 仅表示连续整数的集合。StrideSet 由基本 IntervalSet、步长列表和范围列表定义。StrideSet 未被使用,ModularSet 只用于前端。
因此,目前在 TVM 中并非所有的整数集合都可以用 IntSet 来表示。例如,上例中的集合 A 不能用 IntSet 表示。将来 IntSet 的功能可以扩展为处理更通用的整数集,而无需对 IntSet 的用户进行修改。
对于包含 compute_at 的 schedules而言,InferBound 更为复杂。因此首先针对不包含 compute_at 的 schedules解读InferBound。
阶段 1:为 consumer 的 leaf_iter_vars 初始化 IntSet
/*
* 输入: Map<IterVar, Range> rmap: 包含 consumer stage 的每个 IterVar 的 Range
* 输出: Map<IterVar, IntSet> up_state: 包含 consumer 的每个 leaf_iter_var 的 IntSet
*/
在阶段 1,根据 rmap
中 leaf_iter_vars 的 Range 创建每个 consumer 的 leaf_iter_vars 的 IntSet。consumer 已经被 InferBound 访问过,所以它所有的 IterVar 都知道 rmap
中的 Range。
有以下三种案例:
- 案例 1:leaf var 的 Range 范围为 1。这种情况下,leaf 的 up_state 只是一个点,等于 Range 的最小值。
- 案例 2:不需要释放。这种情况下,leaf 的 up_state 只是一个点,由 leaf var 本身定义。
- 案例 3:需要释放。这种情况下,leaf 的 Range 被简单地转换为 IntSet。
简单起见,假设 schedule 不包含线程轴。这种情况下,仅当 schedule 包含 compute_at 时,才和案例 2 相关。参阅 InferBound 与 compute_at 节来进一步获取更多信息。
阶段 2:将 IntSet 从 consumer 的 leaf 传到 consumer 的 root
/*
* Input: Map<IterVar, IntSet> up_state: consumer leaf -> IntSet
* Output: Map<IterVar, IntSet> dom_map: consumer root -> IntSet
*/
阶段 2 的目的是将 IntSet 信息从 consumer 的 leaf_iter_vars 传到 consumer 的 root_iter_vars。阶段 2 的结果是另一个映射 dom_map
,其中包含每个 consumer 的 root_iter_vars 的 IntSet。
阶段 2 首先调用 PassUpDomain,它访问 consumer stage 的 IterVarRelations。在 Split 关系的情况下,PassUpDomain 根据内部和外部 IntSet 设置父级 IterVar 的 up_state,如下所示:
- 案例 1:外部和内部 IterVar 的范围匹配它们的
up_state
域。在这种情况下,只需将父级的 Range 转换为 IntSet 即可设置父级的up_state
。 - 案例 2:否则,父级的
up_state
是相对于外部和内部的*up_state
通过评估outer*f + inner + rmap[parent]->min
来定义的。这里,TVM 没有使用s**plit 关系的因子,而是用*f = rmap[inner]->extent
。
仅当 schedule 包含 compute_at 时才需要案例 2。参阅下面的 InferBound 与 compute_at 节,进一步了解。
在 PassUpDomain 完成向 consumer 的所有 IterVars 传到 up_state 后,将创建一个从 root_iter_vars 到 IntSet 的新映射。如果 schedule 不包含 compute_at,则 root_iter_var iv 的 IntSet 由以下代码创建:
dom_map[iv->var.get()] = IntSet::range(up_state.at(iv).cover_range(iv->dom));
注意,若 schedule 不包含 compute_at,则实际上不需要阶段 1-2。dom_map 可以直接从 rmap 中的已知 Range 构建。Range 只需要转换为 IntSet,不会丢失信息。
阶段 3:将 IntSet 传到 consumer 的输入张量
/*
* Input: Map<IterVar, IntSet> dom_map: consumer root -> IntSet
* Output: Map<Tensor, TensorDom> tmap: output tensor -> vector<vector<IntSet> >
*/