代码非常简单,相关代码都有注释。关于 representative_dataset 的作用,可以参考这篇文章:《Tensor Flow量化里representative_dataset参数是什么意思?》(17岁boy的博客,CSDN博客)
import tensorflow as tf
import io
import PIL
import numpy as np
def rep():
    """Yield calibration samples for TFLite post-training quantization.

    Reads serialized ``tf.train.Example`` records from a TFRecord file,
    decodes the image stored under the ``image/encoded`` feature, and
    converts it to the tensor layout fed to the converter.

    Yields:
        A one-element list holding a ``float32`` array of shape
        ``(1, 96, 96, 1)`` with values scaled to ``[-1.0, 1.0]``.
        At most 300 samples are produced.
    """
    # A bare `import PIL` does not load the PIL.Image submodule, so import it
    # explicitly here instead of relying on another module having done so.
    from PIL import Image

    max_samples = 300

    # The calibration data should ideally come from the validation set.
    # NOTE(review): the file name below looks like the training split - confirm.
    record_iterator = tf.python_io.tf_record_iterator(path='/home/zhihao/models/research/slim/ci_data/cifar10_train.tfrecord')

    # Pull each image out of its protobuf record and turn it into an array.
    for count, string_record in enumerate(record_iterator):
        # Stop once exactly `max_samples` samples have been yielded.
        if count >= max_samples:
            break

        example = tf.train.Example()
        example.ParseFromString(string_record)

        # 'image/encoded' is the feature key under which the image bytes
        # were stored when the TFRecord was written.
        image_stream = io.BytesIO(example.features.feature['image/encoded'].bytes_list.value[0])
        image = Image.open(image_stream)

        # Resize every image to a fixed 96x96 input size.
        image = image.resize((96, 96))
        # Mode 'L' is 8-bit single-channel grayscale.
        image = image.convert('L')

        array = np.array(image)
        # (96, 96) -> (96, 96, 1): add the channel axis.
        array = np.expand_dims(array, axis=2)
        # (96, 96, 1) -> (1, 96, 96, 1): add the batch axis.
        array = np.expand_dims(array, axis=0)
        # Map uint8 pixel values [0, 255] onto [-1.0, 1.0].
        array = ((array / 127.5) - 1.0).astype(np.float32)

        yield [array]
# Frozen GraphDef (.pb) to convert; it must already contain the network
# weights. The two lists name the graph's input and output tensors.
converter = tf.lite.TFLiteConverter.from_frozen_graph('/home/zhihao/work/freezed_cifarnet.pb',['input'],['MobilenetV1/Predictions/Reshape_1'])
# Request int8 tensors at the model's input and output.
converter.inference_input_type = tf.lite.constants.INT8
converter.inference_output_type = tf.lite.constants.INT8
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Generator that supplies the calibration samples for quantization.
converter.representative_dataset = rep
# Quantize the model and save it.
tflite_quant_model = converter.convert()
# Use a context manager so the output file is flushed and closed
# deterministically instead of leaking the handle.
with open("test.tflite","wb") as tflite_file:
    tflite_file.write(tflite_quant_model)