Train your first gMLP Model for classifying images

Computer Vision May 8, 2022

My last post discussed the MLPMixer model, which achieved performance comparable to Transformer-based vision models. In this post, we will talk about similar research that also came out in May 2021: Pay Attention to MLPs, which proposes the gMLP architecture and performs comparably to Transformers on key language and vision tasks. In a standard Transformer, MLPs make up over 60% of the model, and the rest is the more complex computation required by the self-attention layers. gMLP, in contrast, uses MLPs throughout, combined with gating (hence the name gMLP). As for training time, the ImageNet classification models take about 1-4 hours to train on a TPUv3 with 128 cores, and the MLM (Masked Language Modeling) models in the full BERT setup take 1-5 days. Just like MLPMixer, gMLP can close the gap with Transformers on the tasks where it performs worse when it is made substantially larger.

According to the comparison below, the gMLP model outperforms Transformer-based models but not ConvNets (because of course!).

Performance Comparison of gMLP Model

Let's look at how the gMLP model works!

gMLP Block

gMLP is an MLP-based alternative to Transformers without self-attention; it simply consists of channel projections and spatial projections with static parameterization. It is built out of basic MLP layers with gating. The model consists of a stack of blocks with identical size and structure.
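For reference, the Pay Attention to MLPs paper writes each block (leaving out the normalization and the residual shortcut) roughly as below. Here X is the n × d matrix of n token embeddings, σ is an activation such as GeLU, U and V are linear projections along the channel dimension, and s(·) is the layer that captures spatial interactions. In the Keras code of Step 2, U loosely corresponds to channel_projection1 and V to channel_projection2; treat that mapping as an informal reading rather than an exact correspondence.

```latex
\begin{aligned}
Z &= \sigma(XU), \\
\tilde{Z} &= s(Z), \\
Y &= \tilde{Z} V
\end{aligned}
```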

A key ingredient is s(·), a layer that captures spatial interactions. When s(·) is an identity mapping, each block degenerates to a regular FFN (feed-forward network), where individual tokens are processed independently without any cross-token communication. One of the major focuses is therefore to design a good s(·) capable of capturing complex spatial interactions across tokens. This leads to the Spatial Gating Unit (SGU), which uses a modified linear gating: it splits the channels into two halves and multiplies one half element-wise by a linear projection of the other half taken across the token (spatial) dimension.
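To make the gating concrete, here is a minimal NumPy sketch of the unit for a single example (no batch axis), following the paper's description: split Z along the channels into Z1 and Z2, project Z2 across tokens with f(Z2) = W·Z2 + b, where W is an n × n matrix over the n tokens, and return Z1 ⊙ f(Z2). The function name and shapes are illustrative assumptions, not the paper's or Keras's API. With W = 0 and b = 1 the gate is all ones and the unit simply passes Z1 through, which is the near-identity starting point the paper aims for by initializing W near zero and b to ones.

```python
import numpy as np


def spatial_gating_unit(z, w, b):
    """Sketch of an SGU: s(Z) = Z1 * (W @ Z2 + b).

    z: [n_tokens, channels] with an even number of channels.
    w: [n_tokens, n_tokens] spatial projection (mixes tokens, not channels).
    b: [n_tokens, 1] bias, broadcast across channels.
    """
    z1, z2 = np.split(z, 2, axis=-1)  # split along the channel axis
    gate = w @ z2 + b                 # linear projection across tokens
    return z1 * gate                  # element-wise gating


n, d = 4, 6
rng = np.random.default_rng(0)
z = rng.normal(size=(n, d))

# Near-identity initialization: W = 0, b = 1 makes the gate all ones.
out = spatial_gating_unit(z, np.zeros((n, n)), np.ones((n, 1)))
print(np.allclose(out, z[:, : d // 2]))  # True
```

With a non-zero W, each token's gate is a weighted mix of other tokens' Z2 values, which is where the cross-token communication comes from.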

Let's get into the tutorial!

In this tutorial, we are going to train the gMLP model on the CIFAR-100 dataset. We will build the model using Keras and train it on Kaggle with GPU acceleration, where each epoch takes around 15 seconds.

Step 1: Set up the dataset and helper code

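import tensorflow as tf
import tensorflow_addons as tfa  # Provides AdamW (Step 4). TF >= 2.11 also ships keras.optimizers.AdamW.
from tensorflow import keras
from tensorflow.keras import layers

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()

weight_decay = 0.0001
batch_size = 128
num_epochs = 100
dropout_rate = 0.2
image_size = 32  # We'll resize input images to this size.
patch_size = 4  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2  # Number of patches per image.
embedding_dim = 128  # Number of hidden units.
num_blocks = 4  # Number of gMLP blocks.

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)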

class Patches(layers.Layer):
    """Splits each image into non-overlapping patch_size x patch_size patches."""

    def __init__(self, patch_size, num_patches, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.num_patches = num_patches

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Each patch is flattened to patch_size * patch_size * channels values.
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, self.num_patches, patch_dims])
        return patches
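As a sanity check on the shapes, the same non-overlapping split can be written in plain NumPy for one image: a 32 × 32 × 3 image with patch_size = 4 yields (32 // 4) ** 2 = 64 patches of 4 * 4 * 3 = 48 values each, which is what num_patches and patch_dims evaluate to in the Keras layer. The helper extract_patches_np is hypothetical and only for illustration; the per-patch value order (row, then column, then channel) is an assumption about how tf.image.extract_patches flattens patches, so compare the two directly before relying on it.

```python
import numpy as np


def extract_patches_np(image, patch_size):
    """Split an [H, W, C] image into non-overlapping flattened patches.

    Returns an array of shape [num_patches, patch_size * patch_size * C].
    """
    h, w, c = image.shape
    rows, cols = h // patch_size, w // patch_size
    # Drop border pixels that do not fill a whole patch (like "VALID" padding).
    image = image[: rows * patch_size, : cols * patch_size]
    patches = image.reshape(rows, patch_size, cols, patch_size, c)
    # Reorder to [rows, cols, patch_row, patch_col, channel], then flatten each patch.
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(rows * cols, patch_size * patch_size * c)


image = np.arange(32 * 32 * 3).reshape(32, 32, 3)
patches = extract_patches_np(image, patch_size=4)
print(patches.shape)  # (64, 48)
```

Patches are numbered row by row, so patch 1 is the second patch in the top row and patch 8 is the first patch in the second row.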

Step 2: Define the gMLP Model

class gMLPLayer(layers.Layer):
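    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # First channel projection: widen to 2 * embedding_dim, then GELU and dropout.
        self.channel_projection1 = keras.Sequential(
            [
                layers.Dense(units=embedding_dim * 2),
                layers.Activation("gelu"),
                layers.Dropout(rate=dropout_rate),
            ]
        )

        # Second channel projection: map the gated channels back to embedding_dim.
        self.channel_projection2 = layers.Dense(units=embedding_dim)

        # Spatial projection: mixes information across the num_patches tokens.
        # The paper initializes the gate bias to 1 (and its weights near zero) so each
        # block starts close to a regular FFN; the kernel here keeps the Keras default.
        self.spatial_projection = layers.Dense(
            units=num_patches, bias_initializer="Ones"
        )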
        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def spatial_gating_unit(self, x):
        # Split x along the channel dimension.
        # Tensors u and v will each have shape [batch_size, num_patches, embedding_dim].
        u, v = tf.split(x, num_or_size_splits=2, axis=2)
        # Apply layer normalization.
        v = self.normalize2(v)
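        # Apply spatial projection. Dense acts on the last axis, so transpose v to
        # [batch_size, embedding_dim, num_patches], project across patches, then transpose back.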
        v_channels = tf.linalg.matrix_transpose(v)
        v_projected = self.spatial_projection(v_channels)
        v_projected = tf.linalg.matrix_transpose(v_projected)
        # Apply element-wise multiplication.
        return u * v_projected

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize1(inputs)
        # Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].
        x_projected = self.channel_projection1(x)
        # Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].
        x_spatial = self.spatial_gating_unit(x_projected)
        # Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].
        x_projected = self.channel_projection2(x_spatial)
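        # Add skip connection. Note that the shortcut here is the normalized x;
        # the paper's pseudocode adds the un-normalized input (inputs) instead.
        return x + x_projected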
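One consequence of this design shows up in the weight shapes: the spatial projection is a num_patches × num_patches matrix, so the model is tied to a fixed number of patches (and therefore a fixed input size). The hypothetical helper below does a back-of-the-envelope count of one gMLPLayer's trainable parameters with the Step 1 settings, assuming Keras's defaults of a bias on every Dense layer and a learned scale and offset in LayerNormalization. You could compare the total with gmlp_blocks.summary() after Step 4.

```python
def gmlp_block_params(num_patches, embedding_dim):
    """Count trainable parameters in one gMLPLayer from Step 2.

    Assumes Dense layers use a bias and LayerNormalization learns a scale and
    an offset per channel (the Keras defaults). GELU and Dropout have no weights.
    """

    def dense(n_in, n_out):
        return n_in * n_out + n_out  # kernel + bias

    counts = {
        "channel_projection1": dense(embedding_dim, 2 * embedding_dim),
        "channel_projection2": dense(embedding_dim, embedding_dim),
        # Applied to the transposed v, so its input and output size is num_patches.
        "spatial_projection": dense(num_patches, num_patches),
        "layer_norms": 2 * (2 * embedding_dim),  # normalize1 and normalize2
    }
    counts["total"] = sum(counts.values())
    return counts


counts = gmlp_block_params(num_patches=64, embedding_dim=128)
for name, n in counts.items():
    print(f"{name}: {n:,}")
# total: 54,208 parameters per block
```

The token-mixing weights are a small share here only because CIFAR-100 images give just 64 patches; the spatial projection grows quadratically with the number of patches.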

Step 3: Build the Classifier

This function takes the stack of gMLP blocks as an argument and returns the final model to be trained. It puts all the layers together:

  1. Take the input images and apply data augmentation
  2. Split the augmented images into patches with the Patches layer
  3. Project each patch to an embedding_dim-dimensional vector with a Dense layer
  4. Pass the patch embeddings through the gMLP blocks
  5. Add a global average pooling layer, followed by a Dropout layer, followed by the final Dense layer

def build_classifier(blocks):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size, num_patches)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)
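The tail of build_classifier is easy to trace by hand. This NumPy sketch uses random stand-in values (hypothetical, untrained weights) only to show the shapes: global average pooling averages the block output over the patch axis, turning [batch_size, num_patches, embedding_dim] into [batch_size, embedding_dim], and the final Dense layer maps that to num_classes logits. Dropout is left out because Keras disables it at inference time.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_patches, embedding_dim, num_classes = 2, 64, 128, 100

# Stand-in for the output of the gMLP blocks: [batch, num_patches, embedding_dim].
x = rng.normal(size=(batch, num_patches, embedding_dim))

# GlobalAveragePooling1D averages over the token (patch) axis.
representation = x.mean(axis=1)  # [batch, embedding_dim]

# Final Dense layer with hypothetical random weights and zero bias.
w = rng.normal(size=(embedding_dim, num_classes)) * 0.01
b = np.zeros(num_classes)
logits = representation @ w + b  # [batch, num_classes]
print(representation.shape, logits.shape)  # (2, 128) (2, 100)
```

Because the final Dense layer has no activation, the model outputs raw logits, which is why the loss in Step 4 is created with from_logits=True.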

Step 4: Build and Train the Model

gmlp_blocks = keras.Sequential(
    [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.003
gmlp_classifier = build_classifier(gmlp_blocks)

optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
# Compile the model with top-1 and top-5 accuracy as metrics.
gmlp_classifier.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
    ],
)
# Create a learning rate scheduler callback.
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5)
# Create an early stopping callback.
early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)
# Fit the model, holding out 10% of the training data for validation.
history = gmlp_classifier.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=0.1,
    callbacks=[early_stopping, reduce_lr],
)
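Because of the two callbacks, training does not necessarily run all num_epochs. The pure-Python sketch below is a simplified model of how they interact: after 5 epochs without a new best val_loss the learning rate is halved, and after 10 such epochs training stops. It ignores Keras's min_delta, cooldown and min_lr options, so treat it as an approximation rather than the Keras implementation; the val_losses list is made-up input, not a result from this model.

```python
def simulate_callbacks(val_losses, lr=0.003, factor=0.5, lr_patience=5, stop_patience=10):
    """Approximate ReduceLROnPlateau + EarlyStopping on a val_loss sequence.

    Returns the learning rate used in each epoch and the number of epochs run.
    """
    best = float("inf")
    wait_lr = wait_stop = 0
    lrs = []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)  # learning rate used during this epoch
        if loss < best:
            best = loss
            wait_lr = wait_stop = 0
            continue
        wait_lr += 1
        wait_stop += 1
        if wait_lr >= lr_patience:
            lr *= factor  # takes effect from the next epoch
            wait_lr = 0
        if wait_stop >= stop_patience:
            return lrs, epoch + 1  # training stops early
    return lrs, len(val_losses)


# Made-up curve: improves for two epochs, then plateaus.
lrs, epochs_run = simulate_callbacks([1.0, 0.9] + [0.95] * 12)
print(epochs_run)  # 12
```

With restore_best_weights=True, the weights from the best epoch (here the second one) are restored when training stops.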

Step 5: Testing

_, accuracy, top_5_accuracy = gmlp_classifier.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

In Step 4, we used top-5 accuracy as an additional metric besides categorical accuracy and trained the model for up to 100 epochs (the early stopping callback can end training sooner). Let's have a look at the results.

So, we reached a top-5 accuracy of 78.91%, which means that about 79% of the time the actual label of the image (out of 100 labels) was among the top-5 predicted labels.
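To make the metric concrete, here is a minimal NumPy version of top-k accuracy on two made-up predictions over 7 classes: a prediction counts as correct if the true label is among the k highest-scoring classes. It is meant to mirror what SparseTopKCategoricalAccuracy(5) reports in Step 5, but tie-breaking between equal logits may differ from Keras.

```python
import numpy as np


def top_k_accuracy(logits, labels, k=5):
    """Fraction of rows whose true label is among the k largest logits."""
    # Indices of the k largest logits per row (their order within the top k doesn't matter).
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return hits.mean()


logits = np.array([
    [0.1, 0.9, 0.3, 0.2, 0.5, 0.0, 0.4],  # true label 3 ranks 5th: a top-5 hit
    [0.8, 0.1, 0.2, 0.3, 0.4, 0.6, 0.7],  # true label 1 ranks last: a miss
])
labels = np.array([3, 1])
print(top_k_accuracy(logits, labels, k=5))  # 0.5
```

Neither example is correct under top-1 accuracy, which is why top-5 accuracy is always at least as high as top-1 accuracy.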


Check out the public notebook here:

The code used in this tutorial is a boiled-down version of Khalid Salama's original Keras examples post.


Check out an explanatory video on gMLP on YouTube.


P.S. - There haven't been any blog posts in the past 2 months, but I am getting back on the wagon! Things have been moving in a certain direction since the last week of March :)