0
点赞
收藏
分享

微信扫一扫

python解析tflite模型文件

嚯霍嚯 2022-03-30 阅读 116

目的

解析 FlatBuffer 格式的 TFLite 模型文件，并将其转换为可读的 Python dict。

方法

        

#/tensorflow/lite/tools/visualize.py
import re

import numpy as np
import tensorflow as tf
from tensorflow.lite.python import schema_py_generated as schema_fb
 
def BuiltinCodeToName(code, enum_cls=None):
    """Convert a builtin op code enum value to its readable name.

    Args:
      code: Integer value to look up, e.g. the ``builtin_code`` of an
        operator-code entry.
      enum_cls: Class whose public class attributes act as the enum members.
        Defaults to ``schema_fb.BuiltinOperator``; any other generated enum
        class with the same shape can be passed to reuse the lookup.

    Returns:
      The first public attribute name (in definition order) whose value
      equals ``code``, or None if no such attribute exists.
    """
    if enum_cls is None:
        enum_cls = schema_fb.BuiltinOperator
    for name, value in vars(enum_cls).items():
        # The class namespace also holds machinery such as __module__ and
        # __doc__; those are not enum members and must never be returned
        # (a None __doc__ would otherwise "match" code=None).
        if name.startswith("_"):
            continue
        if value == code:
            return name
    return None
def CamelCaseToSnakeCase(camel_case_input):
    """Return ``camel_case_input`` rewritten in snake_case.

    Underscores are inserted in two regex passes and the result is then
    lower-cased:
      1. before every capitalised word that follows any character
         ("opcodeIndex" -> "opcode_Index", "HTTPResponse" -> "HTTP_Response");
      2. between a lowercase letter or digit and a following capital
         ("Conv2D" -> "Conv2_D").
    """
    word_start = re.compile("(.)([A-Z][a-z]+)")
    lower_to_upper = re.compile("([a-z0-9])([A-Z])")
    split_words = word_start.sub(r"\1_\2", camel_case_input)
    return lower_to_upper.sub(r"\1_\2", split_words).lower()
def FlatbufferToDict(fb, preserve_as_numpy):
    """Recursively convert a flatc object-API value into plain Python data.

    Args:
      fb: Value to convert: a scalar, a generated object (anything with a
        ``__dict__``), a numpy array, or a sized sequence of such values.
      preserve_as_numpy: If True, numpy arrays are returned unchanged; if
        False they are converted with ``tolist()``. Everything reached through
        an attribute named ``buffers`` is always kept as numpy (presumably the
        model's raw data buffers, which would be bulky as Python lists).

    Returns:
      ints, floats and strs unchanged; objects as dicts keyed by the
      snake_case form of each public, non-callable attribute name; numpy
      arrays per ``preserve_as_numpy``; other values with ``__len__`` as
      lists of converted entries; anything else (e.g. None) as-is.
    """
    if isinstance(fb, (int, float, str)):
        return fb
    elif hasattr(fb, "__dict__"):
        result = {}
        for attribute_name in dir(fb):
            # Private/dunder names are never part of the output, so skip
            # them before fetching the attribute at all.
            if attribute_name.startswith("_"):
                continue
            attribute = getattr(fb, attribute_name)
            # Methods of the generated class (e.g. InitFromObj) are not data.
            if callable(attribute):
                continue
            snake_name = CamelCaseToSnakeCase(attribute_name)
            preserve = attribute_name == "buffers" or preserve_as_numpy
            result[snake_name] = FlatbufferToDict(attribute, preserve)
        return result
    elif isinstance(fb, np.ndarray):
        return fb if preserve_as_numpy else fb.tolist()
    elif hasattr(fb, "__len__"):
        return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
    else:
        return fb
def CreateDictFromFlatbuffer(buffer_data):
    """Parse a serialized TFLite model into nested Python dicts and lists.

    Args:
      buffer_data: The raw bytes of a .tflite file.

    Returns:
      The ``ModelT`` object for the model, flattened by ``FlatbufferToDict``
      with ``preserve_as_numpy=False``.
    """
    # Read the root Model table (starting at offset 0), unpack it into the
    # generated ModelT object, then flatten that into plain containers.
    unpacked_model = schema_fb.ModelT.InitFromObj(
        schema_fb.Model.GetRootAsModel(buffer_data, 0)
    )
    return FlatbufferToDict(unpacked_model, preserve_as_numpy=False)

转换

# Read the model.
with open('xxx.tflite', 'rb') as f:
    model_buffer = f.read()


data = CreateDictFromFlatbuffer(model_buffer)
op_codes = data['operator_codes']  # op codes used (registered) by the model
subg = data['subgraphs'][0]  # first subgraph: the operators making up the model
tensors = subg['tensors']  # tensor descriptions (shapes etc.); weights live here

# get_tensor() below needs an interpreter built from the same model bytes;
# the weight values are read back through it as numpy arrays.
interpreter = tf.lite.Interpreter(model_content=model_buffer)
interpreter.allocate_tensors()

# Ops whose inputs are laid out as [input, weights, bias]. Other ops may have
# fewer inputs (RESHAPE, SOFTMAX, ...) or a second input that is just another
# activation (ADD, MUL, ...), so weights are only read for these.
WEIGHTED_OPS = {"CONV_2D", "DEPTHWISE_CONV_2D", "FULLY_CONNECTED"}

for layer in subg['operators']:
    # layer name
    op_code_entry = op_codes[layer['opcode_index']]
    # Newer schemas store op codes > 127 only in `builtin_code`, while models
    # written with older schemas keep the code in `deprecated_builtin_code`
    # (leaving `builtin_code` at 0 == ADD). TFLite takes the max of the two.
    # Older schema modules lack the deprecated field entirely, hence .get().
    op_code = max(op_code_entry['builtin_code'],
                  op_code_entry.get('deprecated_builtin_code', 0))
    layer_name = BuiltinCodeToName(op_code)

    # layer param (presumably None for ops without builtin options)
    layer_param = layer['builtin_options']

    # layer input/output tensor indices; `or []` guards an absent list
    input_tensor_idx = layer['inputs'] or []
    output_tensor_idx = layer['outputs'] or []

    # input
    input_idx = input_tensor_idx[0] if input_tensor_idx else -1

    weight_idx = -1
    weight = None
    filters = None
    bias_idx = -1
    if layer_name in WEIGHTED_OPS and len(input_tensor_idx) >= 2:
        # filter weight
        weight_idx = input_tensor_idx[1]
        weight = interpreter.get_tensor(weight_idx)
        # First dimension of the weight shape: the number of output
        # filters/units for CONV_2D and FULLY_CONNECTED (1 for
        # DEPTHWISE_CONV_2D) -- not the kernel size.
        filters = tensors[weight_idx]['shape'][0]

        # filter bias (an index of -1 means the op has no bias tensor)
        if len(input_tensor_idx) >= 3:
            bias_idx = input_tensor_idx[2]

举报

相关推荐

0 条评论