跳到主要内容

TensorFlow 前端

TensorFlow 前端有助于将 TensorFlow 模型导入 TVM。

支持的版本:

  • 1.12 及以下

测试模型:

  • Inception(V1/V2/V3/V4)
  • Resnet(全部)
  • Mobilenet(V1/V2 全部)
  • Vgg(16/19)
  • BERT(基础/3 层)

为推理准备模型

删除不需要的节点

导出过程将删除许多不需要的推理节点,但还是会留下一些。需要手动删除的节点有:

将 None Dimensions 转换为常量

TVM 对动态张量 shape 的支持最少。为 None 的维度应替换为常量。例如,模型可以接受 shape 为 (None,20) 的输入。这应该转换为类似 (1,20) 的 shape。应相应地修改模型以确保这些 shape 在整个计算图中匹配。

导出

TensorFlow 前端需要一个冻结的 protobuf(.pb)或保存的模型作为输入。它目前不支持检查点(.ckpt)。TensorFlow 前端所需的 graphdef 可以从活动 session 中提取,或者使用 TFParser 辅助类。

导出模型时应进行一些转换,以准备模型进行推理。设置 add_shapes=True 也很重要,因为这会将每个节点的输出 shape 嵌入到图中。这是一个将模型导出为给定会话的 protobuf 的函数:

import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

def export_pb(session):
with tf.gfile.GFile("myexportedmodel.pb", "wb") as f:
inputs = ["myinput1", "myinput2"] # 替换为你的输入名称

outputs = ["myoutput1"] # 替换为你的输出名称
graph_def = session.graph.as_graph_def(add_shapes=True)
graph_def = tf.graph.util.convert_variables_to_constants(session, graph_def, outputs)
graph_def = TransformGraph(
graph_def,
inputs,
outputs,
[
"remove_nodes(op=Identity, op=CheckNumerics, op=StopGradient)",
"sort_by_execution_order", # 每次转换后按执行顺序排序,以确保正确的节点排序
"remove_attribute(attribute_name=_XlaSeparateCompiledGradients)",
"remove_attribute(attribute_name=_XlaCompile)",
"remove_attribute(attribute_name=_XlaScope)",
"sort_by_execution_order",
"remove_device",
"sort_by_execution_order",
"fold_batch_norms",
"sort_by_execution_order",
"fold_old_batch_norms",
"sort_by_execution_order"
]
)
f.write(graph_def.SerializeToString())

另一种方法是 导出并冻结计算图

导入模型

显式 shape:

为确保可以在整个计算图中了解 shape,将 shape 参数传递给 from_tensorflow。该字典将输入名称映射到输入 shape。有关示例,参阅这些 测试用例

数据布局

大多数 TensorFlow 模型都是使用 NHWC 布局发布的。NCHW 布局通常提供更好的性能,尤其是在 GPU 上。TensorFlow 前端可以通过将参数 layout='NCHW' 传递给 from_tensorflow 来自动转换模型的数据布局。

最佳实践

  • 使用静态张量 shape,而非动态 shape(删除「None」尺寸)。
  • 使用静态 RNN,而非动态 RNN,因为尚不支持 TensorArray

支持的算子

  • Abs
  • Add
  • AddN
  • All
  • Any
  • ArgMax
  • ArgMin
  • AvgPool
  • BatchMatMul
  • BatchMatMulV2
  • BatchNormWithGlobalNormalization
  • BatchToSpaceND
  • BiasAdd
  • BroadcastTo
  • Cast
  • Ceil
  • CheckNumerics
  • ClipByValue
  • Concat
  • ConcatV2
  • Conv2D
  • Cos
  • Tan
  • CropAndResize
  • DecodeJpeg
  • DepthwiseConv2dNative
  • DepthToSpace
  • Dilation2D
  • Equal
  • Elu
  • Enter
  • Erf
  • Exit
  • Exp
  • ExpandDims
  • Fill
  • Floor
  • FloorDiv
  • FloorMod
  • FusedBatchNorm
  • FusedBatchNormV2
  • Gather
  • GatherNd
  • GatherV2
  • Greater
  • GreaterEqual
  • Identity
  • IsFinite
  • IsInf
  • IsNan
  • LeakyRelu
  • LeftShift
  • Less
  • LessEqual
  • Log
  • Log1p
  • LoopCond
  • LogicalAnd
  • LogicalOr
  • LogicalNot
  • LogSoftmax
  • LRN
  • LSTMBlockCell
  • MatMul
  • Max
  • MaxPool
  • Maximum
  • Mean
  • Merge
  • Min
  • Minimum
  • MirrorPad
  • Mod
  • Mul
  • Neg
  • NextIteration
  • NotEqual
  • OneHot
  • Pack
  • Pad
  • PadV2
  • Pow
  • Prod
  • Range
  • Rank
  • RealDiv
  • Relu
  • Relu6
  • Reshape
  • ResizeBilinear
  • ResizeBicubic
  • ResizeNearestNeighbor
  • ReverseV2
  • RightShift
  • Round
  • Rsqrt
  • Select
  • Selu
  • Shape
  • Sigmoid
  • Sign
  • Sin
  • Size
  • Slice
  • Softmax
  • Softplus
  • SpaceToBatchND
  • SpaceToDepth,
  • Split
  • SplitV
  • Sqrt
  • Square
  • SquareDifference
  • Squeeze
  • StridedSlice
  • Sub
  • Sum
  • Switch
  • Tanh
  • TensorArrayV3
  • TensorArrayScatterV3
  • TensorArrayGatherV3
  • TensorArraySizeV3
  • TensorArrayWriteV3
  • TensorArrayReadV3
  • TensorArraySplitV3
  • TensorArrayConcatV3
  • Tile
  • TopKV2
  • Transpose
  • TruncateMod
  • Unpack
  • UnravelIndex
  • Where
  • ZerosLike