跳到主要内容
版本:0.12.0

编译 Keras 模型

备注

注意:单击 此处 下载完整的示例代码

作者:Yuwei Hu

本文介绍如何用 Relay 部署 Keras 模型。

首先安装 Keras 和 TensorFlow,可通过 pip 快速安装:

pip install -U keras --user
pip install -U tensorflow --user

或参考官网:https://keras.io/#installation

import tvm
from tvm import te
import tvm.relay as relay
from tvm.contrib.download import download_testdata
import keras
import tensorflow as tf
import numpy as np

加载预训练的 Keras 模型

加载 Keras 提供的预训练 resnet-50 分类模型:

# Pick the ResNet-50 weight file that matches the installed Keras release.
# The version components are compared as integers: comparing the raw strings
# orders "10" before "4", which would send Keras 2.10+ down the legacy branch.
keras_major_minor = tuple(int(part) for part in keras.__version__.split(".")[:2])
if keras_major_minor < (2, 4):
    weights_url = "".join(
        [
            "https://github.com/fchollet/deep-learning-models/releases/",
            "download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5",
        ]
    )
    weights_file = "resnet50_keras_old.h5"
else:
    # The URL must not start with whitespace; it is handed to the downloader as-is.
    weights_url = "".join(
        [
            "https://storage.googleapis.com/tensorflow/keras-applications/",
            "resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5",
        ]
    )
    weights_file = "resnet50_keras_new.h5"

weights_path = download_testdata(weights_url, weights_file, module="keras")
# Build the network with weights=None, then load the downloaded weight file.
keras_resnet50 = tf.keras.applications.resnet50.ResNet50(
    include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000
)
keras_resnet50.load_weights(weights_path)

加载测试图像

这里使用的还是先前猫咪的图像:

from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras.applications.resnet50 import preprocess_input

# Fetch the sample cat picture that serves as the test input.
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
# Resize to 224x224, the spatial size given as input_shape when the model was built.
img = Image.open(img_path).resize((224, 224))
# Display the image.
plt.imshow(img)
plt.show()
# Preprocess the input: add a leading axis and convert to float32.
data = np.array(img)[np.newaxis, :].astype("float32")
data = preprocess_input(data