一、前言
TVM是通过Relay进行模型计算图IRModule的构建,并且可以使用不同的优化策略(即Pass)对IRModule进行优化,最终编译成特定后端可执行的代码(codegen)。在整个过程中,可视化对IRModule的优化调试非常有用,因此TVM提供了relay_viz的模块接口用于IRModule的可视化。本文将介绍relay_viz的使用方法并简单分析其实现的代码流程,最后通过自定义Plotter将IRModule转换为HDF5格式的文件,用于Netron的显示。
二、IRModule可视化
TVM提供了可视化Relay的模块接口,目前(v0.11.dev0)内置支持了基于graphviz与终端显示两种可视化输出方式,下面是graphviz方式的使用示例:
from tvm.contrib import relay_viz
from tvm.relay.testing import mlp
# 获取vgg网络的IRModule和param
mod, param = mlp.get_workload(batch_size=1, num_classes=10)
# graphviz属性
graph_attr = {"color": "red"}
node_attr = {"color": "blue"}
edge_attr = {"color": "black"}
def get_node_attr(node):
# 将dense节点填充绿色矩形
if "nn.dense" in node.type_name:
return {"fillcolor": "green",
"style": "filled"}
# 将axis=-1的softmax节点填充橙色
if "nn.softmax" in node.type_name and "axis: -1" in node.detail:
return {"fillcolor": "orange",
"style": "filled"}
# 设置Var节点为椭圆形
if "Var" in node.type_name:
return {"shape": "ellipse"}
return {"shape": "box"}
# 创建一下plotter
dot_plotter = relay_viz.DotPlotter(
graph_attr=graph_attr,
node_attr=node_attr,
edge_attr=edge_attr,
get_node_attr=get_node_attr)
viz = relay_viz.RelayVisualizer(
mod,
relay_param=param,
plotter=dot_plotter, # 传入定义的plotter
parser=relay_viz.DotVizParser())
# 渲染生成 pdf
viz.render("mlp")
执行完成后会在当前目录下生成mlp.pdf(模型较大时需要等一会才能看到正确的内容):
三、代码分析
模块中的类图关系如下所示:
代码执行流程大致如下:
(1)RelayVisualizer初始化时调用relay.analysis模块提供的函数post_order_visit(relay_mod[name], traverse_expr),对IRModule中的每个节点进行遍历,并执行traverse_expr()回调函数,为每个节点生成一个唯一的节点id,记录到node_to_id字典中;
(2)通过传入的plotter创建可视化图对象self._plotter.create_graph(name),把node_to_id中的节点与id添加到图中self._add_nodes(graph, node_to_id);
(3)self._add_nodes(graph, node_to_id)遍历每个节点,并将节点传入parser中进行解析,self._parser.get_node_edges(node, self._relay_param, node_to_id)得到可视化图对应的节点与边集合,然后通过graph.node(viz_node)及graph.edge(edge);将这些集合添加进可视化图中;
(4)使用者可以自定义不同的Plotter,不同的Plotter对应不同的VizGraph结构,目前内置实现了TermPlotter与DotPlotter,其中DotPlotter依赖于graphviz包,用于生成渲染与直观的图结构,而TermPlotter则实现了终端输出的文本式的图结构;
(5)为了适应Plotter的需求,VizParser也可以自定义。相应地,模块实现了TermVizParser与DotVizParser=DefaultVizParser。VizParser中的get_node_edges()会根据不同的节点类型进行相应的处理,比如DotVizParser就实现了function,call,var,tuple及constant节点类型的处理。
四、Netron显示Relay
本节将演示如何自定义Plotter,实现Relay的IRModule到HDF5格式的转换,从而适配Netron。
(1)HDF5格式
HDF5(Hierarchical Data Format)是一种常见的跨平台数据储存文件格式,常以.h5或者.hdf5为后缀名,内容主要由两种元素组成:
- Groups:类似于文件夹,用于对每种数据进行分类归置;
- Datasets:类似于NumPy中的数组array,用于存储实际数据。
每个dataset可以分成两部分: 原始数据(raw data values)和元数据(metadata)。
+-- Dataset
| +-- (Raw) Data Values 数据内容
| +-- Metadata
| | +-- Dataspace 原始数据的秩 (Rank) 和维度 (dimension)
| | +-- Datatype 数据类型
| | +-- Properties 分块储存以及压缩情况
| | +-- Attributes 其他自定义属性
(2)Netron需要的HDF5内容
要使HDF5能够在Netron上渲染,在内容上需要符合一定的要求,我们可以通过打开由Keras保存的模型文件内容进行分析:
sudo apt-get install hdf5-tools
h5dump -A keras_model.h
需要关注的两个重要属性是model_config与model_weights中的layer_names:
HDF5 "keras_model.h5" {
GROUP "/" {
...
ATTRIBUTE "model_config" {
DATATYPE H5T_STRING {...}
DATASPACE SCALAR
DATA {
(0): "{模型描述Json}"
}
}
GROUP "model_weights" {
...
ATTRIBUTE "layer_names" {
DATATYPE H5T_STRING {...}
DATASPACE SIMPLE { ( 15 ) / ( 15 ) }
DATA {
(0): "input_1(3)自定义Plotter0import json
import numpy as np
from typing import Dict
from tvm.contrib.relay_viz.interface import (
DefaultVizParser, Plotter,
VizEdge, VizGraph, VizNode,
)
try:
import h5py
HDF5_OBJECT_HEADER_LIMIT = 64512
except ImportError:
# add "from None" to silence
# "During handling of the above exception, another exception occurred"
raise ImportError(
"The h5py package is required. "
"Please install it first. For example, pip3 install h5py"
) from None
Hdf5VizParser = DefaultVizParser
class Hdf5Node:
def __init__(self, viz_node: VizNode):
self.name = viz_node.type_name + '_' + viz_node.identity
self.type = viz_node.type_name
self.params = self._detail_to_params(viz_node.detail)
def _detail_to_params(self, detail: str) -> Dict:
if len(detail) == 0:
return {}
ds = detail.split("\n")
params = {}
for p in ds:
k, v = p.split(":")
params[k] = v
return params
class Hdf5Graph(VizGraph):
"""Hdf5 graph for relay IR.
Parameters
----------
name: str
name of this graph.
"""
def __init__(
self,
name: str
):
self._name = name
self._graph = {}
self._id_to_hf_node = {}
def node(self, viz_node: VizNode) -> None:
"""Add a node to the underlying graph."""
if viz_node.identity not in self._graph:
# Add the node into the graph.
self._graph[viz_node.identity] = []
node = Hdf5Node(viz_node)
self._id_to_hf_node[viz_node.identity] = node
def edge(self, viz_edge: VizEdge) -> None:
"""Add an edge to the underlying graph."""
if viz_edge.end in self._graph:
self._graph[viz_edge.end].append(viz_edge.start)
else:
self._graph[viz_edge.end] = [viz_edge.start]
def get_layers(self):
layers = []
for id, in_ids in self._graph.items():
layer = {}
hf_node = self._id_to_hf_node[id]
layer['name'] = hf_node.name
layer['class_name'] = hf_node.type
layer['inbound_nodes'] = []
dtype_hint = ""
for in_id in in_ids:
in_hf_node = self._id_to_hf_node[in_id]
item = [in_hf_node.name, 0, 0, {}]
layer['inbound_nodes'].append(item)
if 'dtype' in in_hf_node.params.keys():
dtype_hint = in_hf_node.params['dtype']
layer['config'] = {'name': hf_node.params['name_hint'] if 'name_hint' in hf_node.params.keys() else hf_node.name}
for k, v in hf_node.params.items():
if 'out_dtype' in k and v and v == ' ':
layer['config']['dtype'] = dtype_hint
continue
if 'name_hint' in k or not v or len(v) == 0:
continue
layer['config'][k] = v
layers.append(layer)
return layers
class Hdf5Plotter(Plotter):
"""Hdf5 graph plotter"""
def __init__(self):
self._name_to_graph = {}
def _save_attr_to_group(self, group, name, data):
bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]
# Expecting this to never be true.
if bad_attributes:
raise RuntimeError('The following attributes cannot be saved to HDF5 '
'file because they are larger than %d bytes: %s' %
(HDF5_OBJECT_HEADER_LIMIT, ', '.join(bad_attributes)))
data_npy = np.asarray(data)
num_chunks = 1
chunked_data = np.array_split(data_npy, num_chunks)
# This will never loop forever thanks to the test above.
while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
num_chunks += 1
chunked_data = np.array_split(data_npy, num_chunks)
if num_chunks > 1:
for chunk_id, chunk_data in enumerate(chunked_data):
group.attrs['%s%d' % (name, chunk_id)] = chunk_data
else:
group.attrs[name] = data
def create_graph(self, name):
self._name_to_graph[name] = Hdf5Graph(name)
return self._name_to_graph[name]
def render(self, filename: str = None):
for name in self._name_to_graph:
if filename is None:
filename = name
f = h5py.File(filename + '.h5', mode='w')
g = f.create_group('model_weights')
g.attrs['backend'] = 'tvm.relay'.encode('utf8')
g.attrs['tvm_version'] = "0.11".encode('utf8')
layers = self._name_to_graph[name].get_layers()
mod_cfg = {"class_name": "Model", "config": {"name": "model", 'layers': layers}}
f.attrs['model_config'] = json.dumps(mod_cfg).encode('utf8')
self._save_attr_to_group(g, 'layer_names', [layer['name'].encode('utf8') for layer in layers])
f.close()
0(4)Netron的显示结果0from tvm.relay.testing import mlp
from tvm.contrib import relay_viz
from relay_viz_hdf5 import Hdf5Plotter, Hdf5VizParser
mod, param = mlp.get_workload(batch_size=1, num_classes=10)
print("mod:{}\n\n".format(mod))
viz = relay_viz.RelayVisualizer(
mod,
relay_param=param,
plotter=Hdf5Plotter(),
parser=Hdf5VizParser()
)
viz.render('mlp')
0五、总结
0000",
(1): "conv2d000000000",
...
}
}
...
model_config保存了模型的结构与参数配置,整体结构为:
其中input_layers与output_layers是相同的结构,是数组的数组:
而layers是字典的数组,包括名称(name),类型名(class_name),输入节点列表(inbound_nodes)以及参数配置(config)
而model_weights保存了模型的数据,其中layer_names属性保存了每一层的名称。由于我们仅仅是为了显示IRModule的结构,所以只要将IRModule转换成model_config与layer_names所需的信息即可。
新建一个relay_viz_hdf5.py文件,Hdf5VizParser使用默认的DefaultVizParser解析器,实现自定义的Hdf5Plotter
测试代码如下:
生成的mlp.h5在Netron上显示的结果:
本文介绍了TVM可视化模块relay_viz的使用方法,并简单分析了代码的实现流程,最后通过自定义Plotter将IRModule转换为HDF5格式的文件,实现了Netron显示IRModule的功能,代码地址在https://github.com/Oreobird/relay_to_hdf5,感兴趣的读者可以参阅。