目的
解析flatbuffer格式的tflite文件,转成可读的python dict格式
方法
#/tensorflow/lite/tools/visualize.py
import re

import numpy as np
from tensorflow.lite.python import schema_py_generated as schema_fb
def BuiltinCodeToName(code):
    """Converts a builtin op code enum to a readable name.

    Scans the BuiltinOperator enum namespace for the first member whose
    value equals `code`; returns the member name, or None when no member
    matches.
    """
    matching_names = (
        name
        for name, value in schema_fb.BuiltinOperator.__dict__.items()
        if value == code
    )
    return next(matching_names, None)
def CamelCaseToSnakeCase(camel_case_input):
    """Converts an identifier in CamelCase to snake_case."""
    # Pass 1: insert "_" before each capitalized word that follows any
    # character (e.g. "HTTPRequest" -> "HTTP_Request").
    partially_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
    # Pass 2: split remaining lower/digit-to-upper boundaries, then lowercase.
    fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", partially_split)
    return fully_split.lower()
def FlatbufferToDict(fb, preserve_as_numpy):
    """Recursively converts a flatbuffer object tree to plain Python types.

    Fix: the original file used `np.ndarray` without importing numpy, so any
    list/array input raised NameError; `numpy` is now imported at file level.

    Args:
      fb: a flatbuffer object (e.g. a schema `*T` instance), a numpy array,
        an int/float/str scalar, or a sequence of any of these.
      preserve_as_numpy: when False, numpy arrays are converted to Python
        lists. The "buffers" attribute is always kept as numpy regardless,
        since raw weight data can be very large.

    Returns:
      A nested structure of dicts, lists, and scalars mirroring `fb`, with
      attribute names converted from CamelCase to snake_case.
    """
    if isinstance(fb, (int, float, str)):
        return fb
    if hasattr(fb, "__dict__"):
        result = {}
        for attribute_name in dir(fb):
            attribute = getattr(fb, attribute_name)
            # Keep only public data attributes; skip methods and dunders.
            if not callable(attribute) and attribute_name[0] != "_":
                snake_name = CamelCaseToSnakeCase(attribute_name)
                # "buffers" holds raw tensor bytes: force numpy to avoid
                # materializing huge Python lists.
                preserve = True if attribute_name == "buffers" else preserve_as_numpy
                result[snake_name] = FlatbufferToDict(attribute, preserve)
        return result
    if isinstance(fb, np.ndarray):
        return fb if preserve_as_numpy else fb.tolist()
    if hasattr(fb, "__len__"):
        return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
    return fb
def CreateDictFromFlatbuffer(buffer_data):
    """Parses serialized tflite model bytes into a nested Python dict."""
    root = schema_fb.Model.GetRootAsModel(buffer_data, 0)
    model_object = schema_fb.ModelT.InitFromObj(root)
    return FlatbufferToDict(model_object, preserve_as_numpy=False)
转换
# Read the model.
with open('xxx.tflite', 'rb') as f:
    model_buffer = f.read()
data = CreateDictFromFlatbuffer(model_buffer)
op_codes = data['operator_codes']  # Ops registered/used by this model.
subg = data['subgraphs'][0]  # Graph structure: the concrete op sequence.
tensors = subg['tensors']  # Tensor descriptors: layer params and weights.
for layer in subg['operators']:
    # Layer name: resolve opcode_index into a readable builtin op name.
    op_idx = layer['opcode_index']
    op_code = op_codes[op_idx]['builtin_code']
    layer_name = BuiltinCodeToName(op_code)
    # Layer parameters (op-specific builtin options).
    layer_param = layer['builtin_options']
    # Indices into `tensors` for this layer's inputs/outputs.
    # NOTE(review): indexing inputs [0]/[1]/[2] below assumes a conv-like op
    # with (input, weight, bias); other ops have different input counts.
    input_tensor_idx = layer['inputs']
    output_tensor_idx = layer['outputs']
    # Input activation tensor index.
    input_idx = input_tensor_idx[0]
    # Filter weight tensor index.
    weight_idx = input_tensor_idx[1]
    # NOTE(review): `interpreter` is not defined anywhere in this snippet —
    # presumably a tf.lite.Interpreter created elsewhere; confirm.
    weight = interpreter.get_tensor(weight_idx)
    # Original note called this the kernel size; for conv weights shape[0]
    # is typically the number of filters (output channels) — TODO confirm.
    filters = tensors[weight_idx]['shape'][0]
    # Filter bias tensor index.
    bias_idx = input_tensor_idx[2]