Optimizing MobileDet for Mobile Deployments

Learn about the critical aspects of effectively optimizing MobileDet object detectors for mobile deployments.
tflite
model-optimization
mobiledet
Published

September 29, 2020

This year, researchers from the University of Wisconsin-Madison and Google published their work on MobileDet. MobileDet presents an architectural philosophy for designing object detectors specifically targeted at mobile accelerators like DSPs, EdgeTPUs, and so on. MobileDet yields significant improvements over architectures like MobileNetV2+SSDLite and MobileNetV3+SSDLite on the COCO object detection task at the same accelerated inference time. Long story short, if you are planning to use object detection models in mobile applications, MobileDets may be an extremely good choice.

One fantastic thing about modern-day research is that, most of the time, the code and essential artifacts (like the trained models) are available publicly. MobileDet is no exception; the authors released their code and pre-trained models in the TensorFlow Object Detection (TFOD) API. The model files come in three different variants -

  • CPU
  • EdgeTPU
  • DSP

Each of these variants includes the pre-trained checkpoints, a TensorFlow Lite (TFLite) compatible model graph, a TFLite model file, a configuration file, and a graph proto. The models were pre-trained on the COCO dataset.

In this post, I am going to revisit the TFLite conversion from the pre-trained model checkpoints along with some of the non-trivial things that come up during the process. It is basically an extension of the findings Khanh LeViet and I shared over this GitHub thread.

The code discussed throughout this post is available here as a Colab Notebook.

Important

If you want to train MobileDet models on your own dataset you may find these notebooks useful. They show you how to prepare the dataset, fine-tune a MobileDet model with the dataset, and optimize the fine-tuned model with TFLite.

Why yet another post on model conversion?

Fair question. After all, there are so many great examples and tutorials that show how to use the post-training quantization APIs in TFLite to perform the model conversion. However, the MobileDet models in the TFOD API repository were trained in TensorFlow (TF) 1, so using the latest TFLite converter on them is not immediately straightforward.

Besides, there are certain caveats to the EdgeTPU and DSP variants. They come in two precision formats - uint8 and float32. The models in uint8 precision were trained using quantization aware training (QAT) while the float32 models were not. During QAT fake quantization nodes get inserted into a model’s computation graph. So, the models trained using QAT usually require some extra care during the TFLite conversion process as we’ll see in a moment.

If we want to convert a single shot detector (SSD) based model to TFLite, we first need to generate a frozen graph that is compatible with the TFLite operator set (as per these guides - TF1 and TF2). The TFOD API team provides stock scripts (TF1, TF2) for this. Both of these scripts add optimized postprocessing operations to the model graph. Now, these operations are not yet supported in int8 precision. So, if you ever wanted to convert these pre-trained checkpoints using full integer quantization, what would your approach be?

By now, hopefully, I have been able to convince you that this post is not just about regular model conversion in TFLite. The situations we’ll be going through over the next sections may be helpful for your production TFLite models as well.

The hassle-free conversions

Before we build our way toward the fun stuff, let’s start with the conversions that won’t cost us a night’s sleep. Conversions based on dynamic-range and float16 quantization would come under this category.

Important

The EdgeTPU and DSP variants of MobileDet are meant to run on the respective hardware accelerators. These accelerators need a model to be in full integer precision. So converting the EdgeTPU and DSP variants with dynamic-range and float16 quantization does not have any practical use.

So, for dynamic-range and float16 quantization based conversions, we will be using the CPU variant only. This variant is available here as ssd_mobiledet_cpu_coco. Once the model bundle is untar’d we get the following files -

├── model.ckpt-400000.data-00000-of-00001
├── model.ckpt-400000.index
├── model.ckpt-400000.meta
├── model.tflite
├── pipeline.config
├── tflite_graph.pb
└── tflite_graph.pbtxt

The model.ckpt-* files are the pre-trained checkpoints on the COCO dataset. If you train a MobileDet object detection model on your own dataset, you will have your own model checkpoint files. The tflite_graph.pb file is a frozen inference graph compatible with the TFLite operator set, exported from the pre-trained model checkpoints. The model.tflite file is a TFLite model that was converted from the tflite_graph.pb frozen graph.

If you ever train a MobileDet model on your own dataset, here's how you'd get the TFLite-compatible frozen graph file (based on the guide mentioned above) -

$ PIPELINE_CONFIG="checkpoint_name/pipeline.config"
$ CKPT_PREFIX="checkpoint_name/model.ckpt-400000"
$ OUTPUT_DIR="tflite_graph"
 
$ python models/research/object_detection/export_tflite_ssd_graph.py \
   --pipeline_config_path=$PIPELINE_CONFIG \
   --trained_checkpoint_prefix=$CKPT_PREFIX \
   --output_directory=$OUTPUT_DIR \
   --add_postprocessing_op=true

You can see a fully worked out example in the Colab Notebook mentioned above. If everything goes well, then you should have the frozen graph file exported in OUTPUT_DIR. Let’s now proceed to the TFLite model conversion part.

Here's what dynamic-range quantization looks like in TensorFlow 2 -

import tensorflow as tf

# Path to the TFLite-compatible frozen graph exported above
# (adjust to wherever you exported it).
model_to_be_quantized = "tflite_graph/tflite_graph.pb"

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=model_to_be_quantized,
    input_arrays=['normalized_input_image_tensor'],
    output_arrays=['TFLite_Detection_PostProcess',
        'TFLite_Detection_PostProcess:1',
        'TFLite_Detection_PostProcess:2',
        'TFLite_Detection_PostProcess:3'],
    input_shapes={'normalized_input_image_tensor': [1, 320, 320, 3]}
)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

A note about some of the parameters and their values from the above code listing -

  • model_to_be_quantized corresponds to the frozen graph file.

  • input_arrays and input_shapes are set according to the frozen graph file. As we can see in the figure below, these values have been set correctly.

  • output_arrays is set according to the instructions provided in this guide. Those operations represent four arrays: detection_boxes, detection_classes, detection_scores, and num_detections, which are the standard outputs expected of virtually any object detector.

The rest of the code listing should be familiar to you if you already know about the typical post-training quantization process in TFLite. For float16 quantization, everything remains the same; we just need to add this line before calling convert() - converter.target_spec.supported_types = [tf.float16].
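For reference, the float16 version would look like this; it reuses the converter set up in the dynamic-range listing above (tflite_fp16_model is just an illustrative name) -

# Same from_frozen_graph(...) setup as in the dynamic-range listing.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Quantize the weights to float16.
converter.target_spec.supported_types = [tf.float16]
tflite_fp16_model = converter.convert()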

The dynamic-range quantized model is 4.3 MB in size and the float16 one is 8.2 MB. Later, we will see how fast these models run on actual mobile devices with and without different accelerators.
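To check those sizes yourself, you can write the serialized model returned by convert() to disk and inspect the file size (the file name here is just illustrative) -

import os

# Persist the flatbuffer produced by convert().
with open("mobiledet_cpu_dynamic_range.tflite", "wb") as f:
    f.write(tflite_model)

size_mb = os.path.getsize("mobiledet_cpu_dynamic_range.tflite") / (1024 * 1024)
print(f"Model size: {size_mb:.1f} MB")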

The trickier TFLite conversions for MobileDet

In this section, we will be dealing with the full integer quantization for the three different variants of MobileDet. Full integer quantization is usually more involved than the other quantization formats supported by TFLite.

Representative dataset

Our first step toward full integer quantization is preparing a representative dataset. It is required to calibrate the activation ranges so that the quantized model retains the original model's performance as much as possible. For the purpose of this post, I sampled 100 images from the COCO training dataset (train2014 split). In my experience, 100 samples have always been sufficient as a representative dataset. I have hosted these images here in case you are interested in using them.
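If you'd prefer to draw your own sample, here's a minimal sketch; it assumes the train2014 images live in a local train2014/ directory (the directory name is an assumption) and copies 100 of them into the train_samples/ folder used below -

import glob
import os
import random
import shutil

random.seed(42)  # make the sample reproducible

os.makedirs("train_samples", exist_ok=True)
all_images = glob.glob("train2014/*.jpg")

# Randomly pick 100 images to use for calibration.
for image_path in random.sample(all_images, 100):
    shutil.copy(image_path, "train_samples")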

The following code listing shows a generator function that yields preprocessed images to the TFLite converter -

rep_ds = tf.data.Dataset.list_files("train_samples/*.jpg")
HEIGHT, WIDTH = 320, 320

def representative_dataset_gen():
    for image_path in rep_ds:
        img = tf.io.read_file(image_path)
        img = tf.io.decode_image(img, channels=3)
        # Scale pixel values to [0, 1] and resize to the model's input resolution.
        img = tf.image.convert_image_dtype(img, tf.float32)
        resized_img = tf.image.resize(img, (HEIGHT, WIDTH))
        # Add a batch dimension since the converter expects batched inputs.
        resized_img = resized_img[tf.newaxis, :]
        yield [resized_img]

Note that these preprocessing steps should be in sync with the preprocessing that will be applied before running inference with your TFLite model. In case you are interested in more complex representative dataset generators, you may find this notebook useful.

Also, note that dynamic-range and float16 quantization of the EdgeTPU and DSP variants don't have much practical use. The next section is solely about full integer quantization of these different variants and the nitty-gritty to take into consideration during the conversion process.

Dealing with fake quantization nodes during conversion

The figure below represents a portion of the uint8 EdgeTPU model computation graph. The nodes highlighted in red are inserted by the QAT mechanism. You would notice the same kind of nodes in the uint8 DSP model computation graph as well.

Now, these nodes have some important implications that we need to consider during the conversion process -

  • During QAT, the activation ranges are already approximated, i.e. QAT simulates quantization during training and adjusts the activation ranges accordingly. So, we don't need to provide a representative dataset for a full integer quantization based conversion.
  • These fake quantization nodes are generally in integer precision, so setting an optimization option (converter.optimizations) might lead to inconsistencies.
  • In order to convert the uint8 models with full integer quantization, we need to set the input and output data types of the TFLite models to integer precision (typically uint8 or int8). As per this documentation, we also need to specify the quantized_input_stats parameter during conversion so that the converted TFLite model can map the quantized input values back to real values. More details are available here.

So, how do we realize all of these in code?

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=model_to_be_quantized,
    input_arrays=['normalized_input_image_tensor'],
    output_arrays=['TFLite_Detection_PostProcess',
        'TFLite_Detection_PostProcess:1',
        'TFLite_Detection_PostProcess:2',
        'TFLite_Detection_PostProcess:3'],
    input_shapes={'normalized_input_image_tensor': [1, 320, 320, 3]}
)
# The model was trained with QAT, so no representative dataset
# and no converter.optimizations are needed.
converter.inference_input_type = tf.uint8
# (mean, std_dev) of the inputs; maps the quantized uint8 inputs to real values.
converter.quantized_input_stats = {"normalized_input_image_tensor": (128, 128)}
tflite_model = converter.convert()

If you're thinking this does not look all that gory compared to the code listings above - it doesn't have to be! The tooling should help you do these things seamlessly. But catching these details during your project development may not be trivial. Note that we don't specify converter.inference_output_type. Hold your breath, we will come to this in a moment.

After successful execution, we get two full integer quantized models - the EdgeTPU one is 4.2 MB and the DSP one is 7.0 MB.

Integer quantization for CPU variants and float32 precision models

The variants that don't contain fake quantization nodes (the CPU variant and all the models in float32 precision) have a relatively simpler conversion process. Recollect that the EdgeTPU and DSP variants come in two different precisions - uint8 and float32. For example, here's how the conversion would look for the float32 precision models -

# Same from_frozen_graph(...) setup as before.
converter.representative_dataset = representative_dataset_gen
converter.inference_input_type = tf.uint8
converter.optimizations = [tf.lite.Optimize.DEFAULT]

Note that we are specifying a representative dataset here because the float32 precision models weren’t trained using QAT. For the CPU variant model, the lines of code would slightly change -

converter.inference_input_type = tf.uint8
converter.quantized_input_stats = {"normalized_input_image_tensor": (128, 128)}
converter.optimizations = [tf.lite.Optimize.DEFAULT]

Honestly, I found this configuration by trial and error. I observed that if I specified a representative dataset, it hurt the predictions of the converted model. I also found that specifying converter.quantized_input_stats helped improve the predictions of the converted model.

We don’t specify converter.inference_output_type in this case as well. Let’s get to it now.

Dealing with non-integer postprocessing ops during conversion

Remember that the frozen graph exporter scripts provided by the TFOD API team add optimized postprocessing operations to the graph. These operations are not supported in integer precision yet. So, even if you try to set converter.inference_output_type to tf.uint8, you'll likely get the following error -

RuntimeError: Unsupported output type UINT8 for output tensor 'TFLite_Detection_PostProcess' of type FLOAT32.

This is why we did not set the converter.inference_output_type parameter.
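If you want to verify this behavior on a converted model, you can load it with the TFLite interpreter and inspect the tensor types (the file name below is illustrative) -

interpreter = tf.lite.Interpreter(model_path="mobiledet_edgetpu_uint8.tflite")
interpreter.allocate_tensors()

# The input tensor is uint8, while the postprocessing outputs remain float32.
print(interpreter.get_input_details()[0]["dtype"])
for detail in interpreter.get_output_details():
    print(detail["name"], detail["dtype"])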

This should resolve all the problems you may run into if you ever want to convert the MobileDet models offered by the TFOD API team. In the last two sections, we'll see these converted models in action and how fast they can perform on the respective hardware accelerators.

Show me some results

The float16 quantized TFLite model of the CPU variant provided decent results -

On Colab, the inference time is about 92.36 ms for this particular model. I experimented with different threshold values for filtering out the weak predictions and a threshold of 0.3 yielded the best results. These results are pretty consistent across the several different models we talked about.
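For context, this is roughly how the timing and threshold filtering can be done with the TFLite interpreter; a minimal sketch that assumes the (illustrative) file name mobiledet_cpu_fp16.tflite and the usual boxes/classes/scores output ordering described earlier -

import time
import numpy as np

interpreter = tf.lite.Interpreter(model_path="mobiledet_cpu_fp16.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()

# Placeholder input; replace with a real image preprocessed the same way as
# during training (for these models, pixels scaled to roughly [-1, 1]).
image = np.random.uniform(-1, 1, (1, 320, 320, 3)).astype(np.float32)
interpreter.set_tensor(input_details["index"], image)

start = time.time()
interpreter.invoke()
print(f"Inference time: {(time.time() - start) * 1000:.2f} ms")

# The SSD postprocessing op emits boxes, classes, scores, and num_detections.
boxes = interpreter.get_tensor(output_details[0]["index"])[0]
classes = interpreter.get_tensor(output_details[1]["index"])[0]
scores = interpreter.get_tensor(output_details[2]["index"])[0]

# Filter out the weak predictions with a score threshold.
keep = scores > 0.3
boxes, classes, scores = boxes[keep], classes[keep], scores[keep]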

A major point to note here: the converted EdgeTPU and DSP variants will be much slower on Colab since they were specifically optimized for their respective hardware accelerators.

You are encouraged to play with the different converted models using the Colab Notebook mentioned above and see these results for yourself.

Model benchmarks

In this section, we’ll address the question - “So, how do I choose one among these many models?” Well, you could manually try them all out and see which performs the best on the runtime of your choice. But a more practical approach to this would be to first benchmark these models on a set of devices using the TFLite Benchmark Tool and then decide accordingly.

The following table summarizes the important runtime statistics of the different TFLite MobileDet models. These results were generated using the TFLite Benchmark Tool mentioned above.

* Device used - Pixel 4 (Inference timings are reported in milliseconds)
** As reported here

We can see that with the proper hardware accelerators, the DSP and EdgeTPU variants can really shine. For the CPU variant, on a GPU-accelerated runtime, the float16 quantized TFLite model can bring additional speed boosts.

A catch here is that Pixel devices don't allow third-party applications to use the Hexagon DSP. Therefore, even if we instruct the Benchmark Tool to make use of it, the model will fall back to the CPU for execution. This is why, for fair benchmarking of the DSP variants, we should run the Benchmark Tool on a device (such as the Samsung Galaxy S9+) that has a Hexagon DSP and also allows third-party applications to use it.

* Device used - Samsung Galaxy S9+ (Inference timings are reported in milliseconds)
Note

To train a custom MobileDet-based object detector you can refer to these notebooks.

Conclusion

In this post, we discussed some of the intricate problems one may run into while converting the different variants of MobileDet to TFLite. One aspect of TFLite that I really like is how it provides the tooling needed to deal with practical problems like these.

I am thankful to Khanh for thoroughly guiding me while writing this post. Thanks to Martin Andrews for suggesting textual edits.