Relay IR 简介
本文介绍 Relay IR—— 第二代 NNVM。学习本节内容,需要具备一定的编程背景,以及熟悉计算图表示的深度学习框架。
学习本节内容,你将了解:
- 支持传统的数据流式编程和转换。
- 支持函数式作用域、let-binding,并使其成为功能齐全的可微分语言。
- 允许用户混合两种编程风格。
使用 Relay 构建计算图
传统的深度学习框架使用计算图作为中间表示(IR)。计算图(或数据流图)是表示计算的有向无环图(DAG)。尽管数据流图由于缺乏控制流,而受限于能表达的计算,但其易用性极大简化了实现自动微分和编译异构执行环境(例如,在专用硬件上执行部分图)。
您可以使用 Relay 构建计算(数据流)图。具体来说,上面的代码展示了如何构造一个简单的两节点图。您会发现该示例的语法与现有的计算图 IR(如 NNVMv1)没有太大区别,唯一的区别在于术语:
- 现有框架通常使用计算图和子图
- Relay 使用函数,例如–
fn (%x)
,表示计算图
每个数据流节点(dataflow node)都是 Relay 中的一个 CallNode。Relay Python DSL 允许快速构建数据流图(dataflow graph)。以上代码强调显式地构造了一个 Add 节点,其两个输入点都指向 %1
。深度学习框架评估上述程序时,会按照拓扑顺序计算节点,%1
只会计算一次。
尽管这种现象对于深度学习框架的构建者来说很常见,但对于 PL 研究人员来说可能并非如此。如果实现一个简单的 visitor 来打印结果,并将结果视为嵌套调用表达式,它会变成 log(%x) + log(%x)
。
当 DAG 中存在共享节点时,这种歧义是由对程序语义的不同解释引起的。在正常的函数式编程 IR 中,嵌套表达式被视为表达式树,其没有考虑到 %1
实际上在 %2
中重复使用了两次这一事实。
Relay IR 关注了这种差异。通常,深度学习框架用户以这种方式构建计算图,经常发生 DAG 节点重用。因此,当以文本格式打印出 Relay 程序时,每行打印一个 CallNode,并为每个 CallNode 分配一个临时 id (%1, %2)
,以便在程序的后面部分可以引用每个公共节点。
模块:支持多个函数(计算图)
到目前为止,已经介绍了如何将数据流图构建为函数。有人自然会问:我们能不能支持多种功能,让它们互相调用?Relay 允许将多个函数组合在一个模块中;下面的代码展示了一个函数调用另一个函数的示例。
def @muladd(%x, %y, %z) {
%1 = mul(%x, %y)
%2 = add(%1, %z)
%2
}
def @myfunc(%x) {
%1 = @muladd(%x, 1, 2)
%2 = @muladd(%1, 2, 3)
%2
}
模块可以看作是一个 Map<GlobalVar, Function>
。这里 GlobalVar 只是一个 id,用于表示模块中的函数。 @muladd
和 @myfunc
在上面的例子中是 GlobalVars。
当一个 CallNode 用于调用另一个函数时,对应的 GlobalVar 存储在 CallNode 的 op 字段中。它包含一个间接级别——要使用相应的 GlobalVar 从模块中查找被调用函数的主体。
在这种特殊情况下,也可以直接将 Function 的引用作为 op 存储在 CallNode 中。那么,为什么要引入 GlobalVar 呢?主要原因是 GlobalVar 将定义/声明解耦,并启用函数的递归和延迟声明。
def @myfunc(%x) {
%1 = equal(%x, 1)
if (%1) {
%x
} else {
%2 = sub(%x, 1)
%3 = @myfunc(%2)
%4 = add(%3, %3)
%4
}
}