Train your first MLPMixer model for classifying images

Computer Vision Feb 26, 2022

Whenever we think about solving a problem in Computer Vision, we think of Convolutional Neural Networks. But then the Vision Transformer (ViT) came along and things changed. ViT applies a Transformer-like architecture to patches of the image: an image is split into fixed-size patches, each patch is linearly embedded, position embeddings are added, and the resulting sequence of vectors is fed to a standard Transformer encoder. Pretty straightforward, right?

Photo by Jeffery Ho / Unsplash [Pun Intended] 

But then in May 2021, the MLPMixer paper came out and showed that you can reach performance comparable to state-of-the-art models such as ViT and ResNet-like networks on downstream CV tasks using only Multilayer Perceptrons (MLPs). This came as a breakthrough because, given a large amount of data and compute power, the team at Google Brain reached performance close to that of the state-of-the-art Vision Transformer while running consistently faster.

Paper's Experiment Results:

When pre-trained on ImageNet-21k data with extra regularization:

  1. Mixer Large reaches 93.63% average top-1 accuracy across the five downstream tasks: ImageNet, CIFAR-10, CIFAR-100, Pets, and Flowers. In comparison, ViT reaches 94.49% top-1 in the same category.
  2. Mixer reaches 74.95% accuracy compared to ViT's 72.72% on the Visual Task Adaptation Benchmark (VTAB-1k), which consists of 19 diverse datasets, each with 1k training examples.
  3. The real thing to notice is throughput: Mixer Large processes 105 img/sec/core, while ViT Large processes 32 img/sec/core. This is what makes it so exciting.

Performance Comparison of MLPMixer Model

Now, let's look at how the MLPMixer works!

MLPMixer architecture

MLPMixer consists of per-patch linear embeddings, Mixer layers, and a classifier head. The input to the network is a sequence of image patches. Each patch is projected linearly into an H-dimensional latent representation (H is the hidden dimension) and passed on to the Mixer layers. One thing to note here is that H is chosen independently of the number of patches and of the patch size, so the network's computational cost grows linearly with the number of patches, instead of quadratically as in ViT. This reduces the computational cost and gives the higher throughput reported above (105 img/sec/core for Mixer Large).
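
To make the linear-versus-quadratic point concrete, here is a rough back-of-the-envelope sketch. It counts only the multiply-adds of the part that moves information between patches: the token-mixing MLP in Mixer versus the patch-to-patch products in self-attention. The function names, the fixed token-MLP width `d_s`, the example sizes, and the simplified cost formulas are all assumptions made for illustration; they are not figures from the paper.

```python
def token_mixing_cost(num_patches, channels, d_s):
    """Approximate multiply-adds for one token-mixing MLP.

    Two dense layers (num_patches -> d_s -> num_patches), applied to every channel.
    """
    return 2 * num_patches * d_s * channels


def attention_mixing_cost(num_patches, channels):
    """Approximate multiply-adds for the patch-to-patch part of self-attention.

    Q @ K^T and attention @ V each cost num_patches^2 * channels
    (the Q/K/V/output projections are ignored in this sketch).
    """
    return 2 * num_patches ** 2 * channels


channels, d_s = 128, 256  # hypothetical sizes, chosen only for illustration
for num_patches in (64, 128, 256):
    mixer = token_mixing_cost(num_patches, channels, d_s)
    attn = attention_mixing_cost(num_patches, channels)
    print(f"patches={num_patches:4d}  mixer={mixer:>12,}  attention={attn:>12,}")
```

Under these assumptions, doubling the number of patches doubles the Mixer cost but quadruples the attention cost.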

Mixer layers contain one token-mixing MLP and one channel-mixing MLP, each consisting of two fully-connected layers and a GELU nonlinearity. Other components include: skip-connections, dropout, and layer norm on the channels.

  1. The channel-mixing MLPs allow communication between different channels; they operate on each token independently and take individual rows of the table as inputs.
  2. The token-mixing MLPs allow communication between different spatial locations (tokens); they operate on each channel independently and take individual columns of the table as inputs.

These two types of layers are interleaved to enable interaction of both input dimensions.
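
A tiny NumPy sketch can show which axis each kind of mixing touches. Here the "table" has one row per patch (token) and one column per channel, and each MLP is replaced by a single random weight matrix, which is a simplification assumed only for this illustration (a real Mixer MLP has two dense layers and a GELU). The names `channel_mix` and `token_mix` are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, channels = 4, 3

# The "table": one row per patch (token), one column per channel.
table = rng.normal(size=(num_patches, channels))

# Stand-ins for the two MLPs: one weight matrix each (no nonlinearity),
# which is enough to show which axis gets mixed.
w_channel = rng.normal(size=(channels, channels))
w_token = rng.normal(size=(num_patches, num_patches))


def channel_mix(t):
    # Every row (patch) is transformed on its own.
    return t @ w_channel


def token_mix(t):
    # Every column (channel) is transformed on its own.
    return w_token @ t


# Perturb a single patch (row 0) and check which output rows change.
perturbed = table.copy()
perturbed[0] += 1.0

channel_diff = np.abs(channel_mix(perturbed) - channel_mix(table)).sum(axis=1)
token_diff = np.abs(token_mix(perturbed) - token_mix(table)).sum(axis=1)

print("rows changed by channel mixing:", np.nonzero(channel_diff > 1e-12)[0])
print("rows changed by token mixing:  ", np.nonzero(token_diff > 1e-12)[0])
```

Channel mixing only changes the row that was perturbed, while token mixing spreads the change to the other patches. That is why the two are interleaved: neither one alone lets information flow along both dimensions.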

Let's get into the tutorial!

In this tutorial, we are going to train the MLPMixer model on the CIFAR-100 dataset. We will build the model using Keras and train it on Kaggle with GPU acceleration, where each epoch takes around 15 seconds.

Step 1: Set up the dataset and helper code

  1. Download the dataset. It comes with predefined train and test sets.
  2. Set up a data_augmentation layer to increase the diversity of the available data; this acts as a regularizer and helps reduce overfitting during training.
  3. Define a Patches layer, since the model is trained on patches.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa  # Provides the AdamW optimizer used in Step 4.

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()

weight_decay = 0.0001
batch_size = 128
num_epochs = 100
dropout_rate = 0.2
image_size = 32  # We'll resize input images to this size.
patch_size = 4  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2  # Number of patches per image.
embedding_dim = 128  # Number of hidden units.
num_blocks = 4  # Number of blocks.

# The layer list follows the original Keras example credited at the end of this post.
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
class Patches(layers.Layer):
    def __init__(self, patch_size, num_patches):
        super(Patches, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # Cut each image into non-overlapping patch_size x patch_size patches.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Flatten the patch grid into a sequence: [batch_size, num_patches, patch_dims].
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, self.num_patches, patch_dims])
        return patches
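
To get a feel for the shapes the Patches layer produces, here is a small NumPy sketch of non-overlapping patch extraction. It is an illustration only: the `extract_patches` helper below is a hypothetical stand-in written for this example, not the TensorFlow op, and the input is a dummy array rather than CIFAR-100 data.

```python
import numpy as np

image_size, patch_size, channels = 32, 4, 3
num_patches = (image_size // patch_size) ** 2      # 8 * 8 = 64
patch_dims = patch_size * patch_size * channels    # 4 * 4 * 3 = 48


def extract_patches(images, patch_size):
    """Split [batch, H, W, C] images into flattened non-overlapping patches."""
    batch, height, width, chans = images.shape
    rows, cols = height // patch_size, width // patch_size
    # Carve the height and width axes into (grid index, offset inside the patch).
    x = images.reshape(batch, rows, patch_size, cols, patch_size, chans)
    # Put the two grid axes first, followed by the pixels of each patch.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(batch, rows * cols, patch_size * patch_size * chans)


# Dummy batch of two "images" filled with increasing numbers.
images = np.arange(2 * image_size * image_size * channels, dtype=np.float32)
images = images.reshape(2, image_size, image_size, channels)

patches = extract_patches(images, patch_size)
print(patches.shape)  # (2, 64, 48)
```

So with 32x32 RGB images and a patch size of 4, every image becomes a sequence of 64 patches with 48 values each, which is what the Dense embedding layer in Step 3 receives.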

Step 2: Define the MLPMixer model

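class MLPMixerLayer(layers.Layer):
    def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
        super(MLPMixerLayer, self).__init__(*args, **kwargs)

        # Token-mixing MLP: runs along the patch axis, so it must output num_patches units.
        # Layer sizes follow the original Keras example credited at the end of this post.
        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        # Channel-mixing MLP: runs along the channel axis, so it must output hidden_units units.
        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=hidden_units),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.normalize = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize(inputs)
        # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].
        x_channels = tf.linalg.matrix_transpose(x)
        # Apply mlp1 on each channel independently.
        mlp1_outputs = self.mlp1(x_channels)
        # Transpose mlp1_outputs from [num_batches, hidden_units, num_patches] to [num_batches, num_patches, hidden_units].
        mlp1_outputs = tf.linalg.matrix_transpose(mlp1_outputs)
        # Add skip connection.
        x = mlp1_outputs + inputs
        # Apply layer normalization.
        x_patches = self.normalize(x)
        # Apply mlp2 on each patch independently.
        mlp2_outputs = self.mlp2(x_patches)
        # Add skip connection.
        x = x + mlp2_outputs
        return x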
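
If you want to trace the data flow of one Mixer layer without TensorFlow, the following NumPy sketch mirrors the same steps: normalize, transpose, token-mixing MLP, transpose back, skip connection, normalize, channel-mixing MLP, skip connection. It is a simplified sketch under stated assumptions: dropout, biases, and the learnable layer-norm parameters are left out, GELU uses the tanh approximation, and all weights are random placeholders.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def layer_norm(x, eps=1e-6):
    # Normalize over the last (channel) axis; learnable scale/offset omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def mlp(x, w1, w2):
    # Two dense layers with a GELU in between, applied to the last axis (no biases).
    return gelu(x @ w1) @ w2


def mixer_block(x, token_w1, token_w2, channel_w1, channel_w2):
    # x: [batch, num_patches, hidden_units]
    y = layer_norm(x)
    y = np.swapaxes(y, 1, 2)            # -> [batch, hidden_units, num_patches]
    y = mlp(y, token_w1, token_w2)      # mixes information across patches
    y = np.swapaxes(y, 1, 2)            # -> [batch, num_patches, hidden_units]
    x = x + y                           # skip connection
    y = mlp(layer_norm(x), channel_w1, channel_w2)  # mixes across channels
    return x + y                        # skip connection


rng = np.random.default_rng(0)
batch, num_patches, hidden_units = 2, 64, 128
x = rng.normal(size=(batch, num_patches, hidden_units))
token_w1 = rng.normal(scale=0.02, size=(num_patches, num_patches))
token_w2 = rng.normal(scale=0.02, size=(num_patches, num_patches))
channel_w1 = rng.normal(scale=0.02, size=(hidden_units, num_patches))
channel_w2 = rng.normal(scale=0.02, size=(num_patches, hidden_units))

out = mixer_block(x, token_w1, token_w2, channel_w1, channel_w2)
print(out.shape)  # (2, 64, 128)
```

The output has the same shape as the input, which is what lets several Mixer layers be stacked one after another in Step 4.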

Step 3: Build the Classifier

This function takes the MLPMixer blocks as an argument and returns the final model to be trained. It basically puts all the layers together:

  1. Take the input data and augment it
  2. Pass it to the Patches layer
  3. Embed the patches with a Dense layer and pass them to the MLPMixer blocks
  4. Add a global average pooling layer, followed by a Dropout layer and the final Dense layer

def build_classifier(blocks):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size, num_patches)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)

Step 4: Build and Train the Model

mlpmixer_blocks = keras.Sequential(
    [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks)

optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
# Compile the model.
mlpmixer_classifier.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
    ],
)
# Create a learning rate scheduler callback.
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5)
# Create an early stopping callback.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)
# Fit the model.
history = mlpmixer_classifier.fit(
    x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
    validation_split=0.1, callbacks=[early_stopping, reduce_lr],
)
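
To see how the two callbacks interact, here is a plain-Python sketch that replays a list of validation losses through simplified versions of their rules: halve the learning rate after 5 epochs without improvement, and stop after 10. This is a deliberately simplified model, not the Keras implementation (it ignores details such as `min_delta` and cooldown), and the function name and the loss values are made up for illustration.

```python
def simulate_callbacks(val_losses, lr=0.005, factor=0.5, lr_patience=5, stop_patience=10):
    """Replay validation losses through simplified plateau and early-stop rules."""
    best = float("inf")
    lr_wait = 0      # epochs without improvement since the last LR cut
    stop_wait = 0    # epochs without improvement overall
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # Improvement: remember it and reset both counters.
            best = loss
            lr_wait = 0
            stop_wait = 0
            continue
        lr_wait += 1
        stop_wait += 1
        if lr_wait >= lr_patience:
            lr *= factor     # plateau detected: reduce the learning rate
            lr_wait = 0
        if stop_wait >= stop_patience:
            return epoch, lr  # no improvement for too long: stop training
    return len(val_losses), lr


# Hypothetical history: improves for 3 epochs, then plateaus.
val_losses = [4.0, 3.5, 3.2] + [3.3] * 20
stopped_at, final_lr = simulate_callbacks(val_losses)
print(stopped_at, final_lr)
```

With this made-up history, the learning rate is halved twice before training stops at epoch 13, well short of the 100 epochs requested.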

Step 5: Testing

_, accuracy, top_5_accuracy = mlpmixer_classifier.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

In Step 4, we tracked top-5 accuracy as an additional metric alongside plain accuracy and trained the model for up to 100 epochs (early stopping may end training sooner). Let's have a look at the results.

So, we reached a top-5 accuracy of 81.07%, which means that 81% of the time the actual label of the image (out of 100 labels) was among the top-5 predicted labels. And YES, we just trained an MLP to recognize images from a dataset with 100 distinct labels. (You can also do the same with MNIST, but doing it on a dataset as diverse as CIFAR-100 is a much more convincing test.)
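
If the idea of top-5 accuracy is new to you, this small plain-Python sketch shows how it can be computed. The `top_k_accuracy` helper and the toy scores are made up for illustration; in the tutorial itself Keras computes this metric for us.

```python
def top_k_accuracy(logits, labels, k=5):
    """Fraction of examples whose true label is among the k highest-scoring classes."""
    hits = 0
    for scores, label in zip(logits, labels):
        # Indices of the k largest scores for this example.
        top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)


# Toy scores for 3 examples over 6 classes (made up for illustration).
logits = [
    [0.1, 0.9, 0.3, 0.2, 0.05, 0.4],   # true class 1: ranked first
    [0.8, 0.1, 0.7, 0.6, 0.5, 0.2],    # true class 5: ranked fifth
    [0.9, 0.8, 0.7, 0.6, 0.5, 0.1],    # true class 5: ranked last
]
labels = [1, 5, 5]

print(top_k_accuracy(logits, labels, k=1))  # one of three examples is a top-1 hit
print(top_k_accuracy(logits, labels, k=5))  # two of three examples are top-5 hits
```

Top-5 accuracy is always at least as high as top-1 accuracy, because a prediction counts as correct as long as the true label appears anywhere in the five best guesses.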


Check out the public notebook here:

The code used in this tutorial is a boiled-down version of Khalid Salama's original Keras examples post.


If you think this was cool, check this out: (They used MLPs to obtain SOTA comparable performance on both downstream CV and NLP tasks)

Pay Attention to MLPs
Transformers have become one of the most important architectural innovations in deep learning and have enabled many breakthroughs over the past few years. Here we propose a simple network architecture, gMLP, based on MLPs with gating, and show that it can perform as well as Transformers in key langu…