Examining distributed training of Keras Models
Creating deep learning models with Keras is straightforward, which is why Keras is often used for prototyping and proof-of-concept products. But when training bigger models or using very large datasets, we need to split either the dataset or the model and distribute the training and/or inference across multiple devices, possibly on multiple machines. The standalone Keras library (“Keras.io”) supported this only partially. One option is to feed the model from a Python generator, but that can be slow: the whole data-ingestion process may be bottlenecked by limited I/O throughput, because data is read from disk or possibly from a bucket in a cloud environment.
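To make the generator option concrete, here is a minimal, dependency-free sketch of the kind of batch generator the paragraph refers to. The helper names (write_shards, batch_generator) and the "f1,f2,label" CSV layout are hypothetical, invented for illustration. In practice Keras would receive NumPy arrays rather than plain lists. The point to notice is that every batch is read and parsed sequentially in the Python process, so file (or bucket) I/O sits directly on the training loop's critical path.

```python
import os
import tempfile


def write_shards(directory, num_shards, rows_per_shard):
    """Write hypothetical CSV shards of 'f1,f2,label' rows to stand in for data on disk."""
    paths = []
    for s in range(num_shards):
        path = os.path.join(directory, f"shard_{s}.csv")
        with open(path, "w", encoding="utf-8") as f:
            for r in range(rows_per_shard):
                value = s * rows_per_shard + r
                f.write(f"{value},{value * 2},{value % 2}\n")
        paths.append(path)
    return paths


def batch_generator(paths, batch_size):
    """Lazily read the shards line by line and yield (features, labels) batches.

    Everything runs sequentially in this Python process: each batch waits on
    file I/O and parsing before a training step can consume it.
    """
    features, labels = [], []
    for path in paths:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                *x, y = line.strip().split(",")
                features.append([float(v) for v in x])
                labels.append(int(y))
                if len(labels) == batch_size:
                    yield features, labels
                    features, labels = [], []
    if labels:  # emit the final, possibly smaller, batch
        yield features, labels


with tempfile.TemporaryDirectory() as tmp:
    shard_paths = write_shards(tmp, num_shards=3, rows_per_shard=5)
    # Consume the generator while the temporary files still exist.
    batches = list(batch_generator(shard_paths, batch_size=4))

print([len(labels) for _, labels in batches])
```

Batches may span file boundaries because the buffers persist across shards; nothing here overlaps reading with computation, which is what a tf.data pipeline with prefetching is designed to do.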
What is distributed training?
Previously, after creating a Keras model, you had to convert it to an Estimator (tf.keras.estimator.model_to_estimator) and adapt the data-feeding mechanism to the one required by the Estimator API, which differs from the way you feed data to a Keras model. Since Keras was adopted as the official high-level API for TensorFlow, it has become tightly integrated with the rest of the framework. This includes the ability to train a Keras model on multiple GPUs, on TPUs, on multiple machines (each with its own GPUs), and even on TPU pods. Together with the adoption of the TF Dataset API, this lets us use Keras with large datasets that do not fit in memory, through distributed training and without converting the model to the Estimator API.
The Dataset API also makes training on TPUs more efficient: it lets the TPU ingest data directly from Google Cloud Storage, instead of the data first being transferred to the machine where training was initiated (where the Python script is running).
The Estimator API is still supported, and is still required in some cases. For example, if you use TensorFlow Extended (TFX), your model needs to be an Estimator.
To distribute the training of a Keras model, we can use one of the distribution strategies explained below. Bear in mind that, at the time of writing, not all of these strategies are officially supported, and some may only be available in TensorFlow nightly builds. For these tests I used TF 2.0 alpha.
- Mirrored Strategy
- MultiWorker Mirrored Strategy
- Parameter Server Strategy
- TPU Strategy
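A minimal sketch of distributed Keras training without the Estimator conversion, using the first strategy in the list. It assumes a recent TensorFlow 2.x release rather than the TF 2.0 alpha used for the tests in this article (some names differed back then, e.g. tf.data.experimental.AUTOTUNE instead of tf.data.AUTOTUNE). The toy data, layer sizes, and batch size are hypothetical placeholders; with a real dataset the tf.data pipeline would stream from files instead of in-memory arrays.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy sizes standing in for a real (possibly very large) dataset.
NUM_SAMPLES = 1024
NUM_FEATURES = 20
GLOBAL_BATCH_SIZE = 64


def make_dataset():
    """Build a tf.data pipeline over synthetic data."""
    rng = np.random.default_rng(0)
    features = rng.normal(size=(NUM_SAMPLES, NUM_FEATURES)).astype("float32")
    labels = (features.sum(axis=1) > 0).astype("float32").reshape(-1, 1)
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    # This is the *global* batch size; the strategy splits each batch across replicas.
    return ds.shuffle(NUM_SAMPLES).batch(GLOBAL_BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


# MirroredStrategy replicates the model on every visible GPU of this machine
# and falls back to a single CPU replica when no GPU is available.
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

# Model weights and optimizer state must be created inside the strategy scope.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(NUM_FEATURES,)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# A plain Keras fit() call on a tf.data.Dataset: no Estimator conversion needed.
history = model.fit(make_dataset(), epochs=2, verbose=0)
print("Final loss:", history.history["loss"][-1])
```

Only the strategy object and the scope are distribution-specific; the model definition and the fit() call are ordinary Keras code, which is what makes switching strategies relatively cheap.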
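The idea behind Mirrored Strategy (and its multi-worker variant) is synchronous data parallelism: each replica holds a copy of the weights, computes gradients on its own slice of the global batch, and an all-reduce averages those gradients so every copy applies the same update. The pure-Python toy below simulates this for a one-parameter linear model. The function names are hypothetical, and this is a conceptual illustration, not how TensorFlow implements the strategy internally. Parameter Server Strategy, by contrast, keeps the variables on dedicated parameter-server tasks that workers read from and update; it is not modeled here.

```python
def gradient(w, xs, ys):
    """Gradient of the loss 0.5 * mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Average one value per replica; every replica receives the same mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def mirrored_step(replica_weights, xs, ys, lr):
    """One synchronous data-parallel SGD step over the global batch (xs, ys)."""
    num_replicas = len(replica_weights)
    if len(xs) % num_replicas != 0:
        raise ValueError("global batch must split evenly across replicas")
    shard = len(xs) // num_replicas
    # Each replica computes a gradient on its own slice of the global batch.
    local_grads = [
        gradient(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r, w in enumerate(replica_weights)
    ]
    # All-reduce: average the gradients so every replica applies the same update.
    reduced = all_reduce_mean(local_grads)
    return [w - lr * g for w, g in zip(replica_weights, reduced)]


xs = [float(x) for x in range(1, 9)]
ys = [3.0 * x for x in xs]  # toy target: y = 3x
lr = 0.01

replicas = [0.0, 0.0]  # two replicas start from identical weights
single = 0.0           # reference run: one device on the whole batch
for _ in range(50):
    replicas = mirrored_step(replicas, xs, ys, lr)
    single = single - lr * gradient(single, xs, ys)

print(replicas, single)
```

Because the shards are equal in size, the average of the per-shard mean gradients equals the full-batch mean gradient, so the replicated run tracks the single-device run while the replicas stay identical to each other.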