Train your first gMLP Model for classifying images
My last post discussed the MLPMixer model, which achieved performance competitive with Transformer-based vision models. In this post, we will talk about similar research that also came out in May 2021: Pay Attention to MLPs, which proposes the gMLP architecture and shows that it can perform as well as Transformers on key language and vision tasks. In a standard Transformer, MLPs account for over 60% of the model, and the rest is the more complex computation of the self-attention layers. In gMLP, by contrast, MLPs with gating are used throughout (hence the name gMLP). As for training time, the ImageNet classification models take about 1-4 hours to train on a TPUv3 with 128 cores, and the MLM (Masked Language Modeling) models in the full BERT setup take 1-5 days. Just as with MLPMixer, making the gMLP model substantially larger can close the gap with Transformers on the tasks where it performs worse.
According to the image below, the gMLP model outperforms Transformer-based models but not ConvNets (of course!).

Let's look at how the gMLP model works!

gMLP is an MLP-based alternative to Transformers without self-attention, which simply consists of channel projections and spatial projections with static parameterization. It is built out of basic MLP layers with gating. The model consists of a stack of blocks with identical size and structure.
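To make the difference between the two kinds of projection concrete, here is a small NumPy sketch. It is not part of the tutorial code, and the toy sizes are arbitrary. A channel projection multiplies every token's feature vector by the same matrix, so tokens never see each other. A spatial projection multiplies along the token axis, so each output token is a weighted mix of all input tokens.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d = 4, 3                          # toy sizes: 4 tokens (patches), 3 channels
X = rng.normal(size=(n_tokens, d))          # rows are tokens, columns are channels
U = rng.normal(size=(d, d))                 # channel projection: one matrix shared by all tokens
W = rng.normal(size=(n_tokens, n_tokens))   # spatial projection: mixes along the token axis

def channel_projection(X):
    # Multiplies along the channel axis: output token i depends only on input token i.
    return X @ U

def spatial_projection(X):
    # Multiplies along the token axis: output token i is a weighted sum of all tokens.
    return W @ X

# Perturb only token 0, then check which output tokens (rows) change.
X_perturbed = X.copy()
X_perturbed[0] += 1.0

def changed_rows(f):
    return ~np.all(np.isclose(f(X), f(X_perturbed)), axis=1)

channel_changed = changed_rows(channel_projection)
spatial_changed = changed_rows(spatial_projection)
print("channel projection changed rows:", channel_changed)
print("spatial projection changed rows:", spatial_changed)
```

Only row 0 changes under the channel projection. Every row changes under the spatial projection, because this random W is dense. That is why a stack of channel projections alone cannot exchange information between patches.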
A key ingredient is s(·), a layer that captures spatial interactions. Each gMLP block projects its input along the channel dimension, applies an activation (GELU), applies s(·), and projects back along the channel dimension. When s(·) is an identity mapping, this transformation degenerates to a regular FFN, where individual tokens are processed independently without any cross-token communication. One of the major focuses is therefore to design a good s(·) capable of capturing complex spatial interactions across tokens. This leads to the Spatial Gating Unit (SGU), which uses a modified linear gating.
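For reference, here is the block written out in the paper's notation, restated as a reading aid with layer normalization and the shortcut connection omitted. X is the n × d input (n tokens, d channels), U and V are the channel projections, σ is GELU, Z is split along the channel dimension into Z1 and Z2, and W (n × n) with bias b is the spatial projection:

```latex
\begin{aligned}
Z &= \sigma(XU), \qquad \tilde{Z} = s(Z), \qquad Y = \tilde{Z}\,V, \\
s(Z) &= Z_1 \odot f_{W,b}(Z_2), \qquad f_{W,b}(Z_2) = W Z_2 + b .
\end{aligned}
```

If s were the identity, this would reduce to Y = σ(XU)V, the ordinary per-token FFN. In the Step 2 code, U roughly corresponds to channel_projection1, W and b to spatial_projection (applied to the transposed tensor so it acts across patches), and V to channel_projection2. Z2 (called v there) is layer-normalized before the projection. The paper also initializes W near zero and b to ones, so the gate starts close to passing Z1 through unchanged. The tutorial code sets only the bias to ones.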
Let's get into the tutorial!
In this tutorial, we are going to train the gMLP model on the CIFAR-100 dataset. We will build the model using Keras and train it on Kaggle with GPU acceleration, where each epoch takes around 15 seconds.
Step 1: Set up the dataset and helper code
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# TensorFlow Addons provides the GELU layer and the AdamW optimizer used below.
# (On recent TensorFlow versions, layers.Activation("gelu") and keras.optimizers.AdamW are built-in alternatives.)
import tensorflow_addons as tfa

num_classes = 100
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()

weight_decay = 0.0001
batch_size = 128
num_epochs = 100
dropout_rate = 0.2
image_size = 32  # We'll resize input images to this size.
patch_size = 4  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2  # Size of the data array.
embedding_dim = 128  # Number of hidden units.
num_blocks = 4  # Number of blocks.

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
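The adapt call above computes per-channel statistics, because the Normalization layer's default axis=-1 keeps the last (channel) axis and reduces over all the others. The NumPy sketch below approximates what it computes and then applies. It is for illustration only: a random array stands in for x_train, and the small epsilon guard mirrors Keras's idea but is not guaranteed to match its exact value.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for x_train: 200 fake 32x32 RGB images with pixel values in [0, 255].
images = rng.integers(0, 256, size=(200, 32, 32, 3)).astype(np.float64)

# What adapt() learns: one mean and one variance per channel,
# reduced over the image, row and column axes.
mean = images.mean(axis=(0, 1, 2))       # shape (3,)
variance = images.var(axis=(0, 1, 2))    # shape (3,)

def normalize(x, eps=1e-7):
    # Applied at call time: standardize each channel (eps guards against zero variance).
    return (x - mean) / np.maximum(np.sqrt(variance), eps)

normalized = normalize(images)
print(normalized.mean(axis=(0, 1, 2)))  # close to 0 for every channel
print(normalized.std(axis=(0, 1, 2)))   # close to 1 for every channel
```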
class Patches(layers.Layer):
    """Splits a batch of images into flattened, non-overlapping square patches."""

    def __init__(self, patch_size, num_patches):
        super(Patches, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # patches: [batch_size, rows, cols, patch_size * patch_size * channels]
        patch_dims = patches.shape[-1]
        # Flatten the patch grid: [batch_size, num_patches, patch_dims]
        patches = tf.reshape(patches, [batch_size, self.num_patches, patch_dims])
        return patches
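With image_size = 32 and patch_size = 4, each image yields (32 // 4) ** 2 = 64 patches of 4 × 4 × 3 = 48 values. Patches therefore turns a [batch_size, 32, 32, 3] batch into [batch_size, 64, 48]. The NumPy sketch below performs the same non-overlapping split with reshape and transpose. It is an illustrative stand-in, not the tutorial's TensorFlow code, and it assumes tf.image.extract_patches flattens each patch in row, column, channel order.

```python
import numpy as np

def extract_patches(images, patch_size):
    """Split [batch, H, W, C] images into [batch, num_patches, patch_size * patch_size * C]."""
    batch, height, width, channels = images.shape
    rows, cols = height // patch_size, width // patch_size
    # Drop any remainder (like padding="VALID"), then expose the patch grid:
    # [batch, rows, patch_size, cols, patch_size, C]
    x = images[:, :rows * patch_size, :cols * patch_size].reshape(
        batch, rows, patch_size, cols, patch_size, channels)
    # Put the grid axes next to each other: [batch, rows, cols, patch_size, patch_size, C]
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Patch index = row * cols + col; each patch is flattened in (row, column, channel) order.
    return x.reshape(batch, rows * cols, patch_size * patch_size * channels)

images = np.arange(2 * 32 * 32 * 3).reshape(2, 32, 32, 3)
patches = extract_patches(images, patch_size=4)
print(patches.shape)  # (2, 64, 48)
```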
Step 2: Define the gMLP Model
class gMLPLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super(gMLPLayer, self).__init__(*args, **kwargs)
        self.channel_projection1 = keras.Sequential(
            [
                layers.Dense(units=embedding_dim * 2),
                tfa.layers.GELU(),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.channel_projection2 = layers.Dense(units=embedding_dim)
        # Projects across patches (it is applied to the transposed tensor); the bias starts at ones.
        self.spatial_projection = layers.Dense(
            units=num_patches, bias_initializer="Ones"
        )
        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def spatial_gating_unit(self, x):
        # Split x along the channel dimension.
        # Tensors u and v have the shape [batch_size, num_patches, embedding_dim].
        u, v = tf.split(x, num_or_size_splits=2, axis=2)
        # Apply layer normalization.
        v = self.normalize2(v)
        # Apply the spatial projection: transpose to [batch_size, embedding_dim, num_patches]
        # so the Dense layer mixes information across patches, then transpose back.
        v_channels = tf.linalg.matrix_transpose(v)
        v_projected = self.spatial_projection(v_channels)
        v_projected = tf.linalg.matrix_transpose(v_projected)
        # Apply element-wise multiplication (the gating).
        return u * v_projected

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize1(inputs)
        # Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].
        x_projected = self.channel_projection1(x)
        # Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].
        x_spatial = self.spatial_gating_unit(x_projected)
        # Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].
        x_projected = self.channel_projection2(x_spatial)
        # Add skip connection. Note: this adds the normalized tensor x; the paper's
        # pseudocode adds the un-normalized block input (inputs here) instead.
        return x + x_projected
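To see where the weights of one gMLPLayer are, the arithmetic below counts its parameters for the Step 1 settings (embedding_dim = 128, num_patches = 64). It is a hand count that assumes standard Dense (kernel + bias) and LayerNormalization (gamma + beta) parameters. If those assumptions hold, it should agree with the per-layer counts in gmlp_blocks.summary().

```python
def dense_params(in_dim, units):
    return in_dim * units + units  # kernel + bias

def layer_norm_params(dim):
    return 2 * dim  # gamma + beta

def gmlp_layer_params(num_patches, embedding_dim):
    return (
        layer_norm_params(embedding_dim)                    # normalize1 on the block input
        + dense_params(embedding_dim, embedding_dim * 2)    # channel_projection1
        + layer_norm_params(embedding_dim)                  # normalize2 on v (half of the 2 * embedding_dim channels)
        + dense_params(num_patches, num_patches)            # spatial_projection, across patches
        + dense_params(embedding_dim, embedding_dim)        # channel_projection2 on u * v_projected
    )

per_block = gmlp_layer_params(num_patches=64, embedding_dim=128)
print(per_block)      # 54208
print(4 * per_block)  # 216832 for num_blocks = 4
```

The spatial projection is cheap here (4,160 parameters), because it scales with the number of patches rather than the embedding width. Most of the weights are in the two channel projections.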
Step 3: Build the Classifier
This function takes the stack of gMLP blocks as an argument and returns the final model to be trained. It basically puts all the layers together:
- Take the input data and augment it
- Pass it to the Patches layer
- Encode the patches with a Dense layer and pass them to the gMLP blocks
- Add a global average pooling layer, followed by a Dropout layer and the final Dense layer
def build_classifier(blocks):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size, num_patches)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)
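The helper below traces the tensor shapes through build_classifier for the Step 1 settings. trace_shapes is a hypothetical helper, plain Python bookkeeping rather than part of Keras. Its assert is a design choice of this sketch: it flags a patch_size that does not divide image_size, where padding="VALID" would silently drop edge pixels.

```python
def trace_shapes(batch_size=128, image_size=32, channels=3, patch_size=4,
                 embedding_dim=128, num_classes=100):
    assert image_size % patch_size == 0, "patch_size should divide image_size"
    num_patches = (image_size // patch_size) ** 2
    patch_dims = patch_size * patch_size * channels
    return {
        "inputs / augmented": (batch_size, image_size, image_size, channels),
        "patches": (batch_size, num_patches, patch_dims),
        "patch embeddings (Dense)": (batch_size, num_patches, embedding_dim),
        "after gMLP blocks": (batch_size, num_patches, embedding_dim),  # blocks keep the shape
        "GlobalAveragePooling1D": (batch_size, embedding_dim),           # mean over patches
        "logits": (batch_size, num_classes),
    }

for name, shape in trace_shapes().items():
    print(f"{name}: {shape}")
```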
Step 4: Build and Train the Model
gmlp_blocks = keras.Sequential(
    [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.003
gmlp_classifier = build_classifier(gmlp_blocks)
optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
# Compile the model.
gmlp_classifier.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
    ],
)
# Create a learning rate scheduler callback.
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5)
# Create an early stopping callback.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=10, restore_best_weights=True
)
# Fit the model.
history = gmlp_classifier.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=0.1,
    callbacks=[early_stopping, reduce_lr],
)
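ReduceLROnPlateau and EarlyStopping both watch val_loss but use different patience values. Below is a simplified pure-Python simulation of how the two interact. It assumes Keras's defaults: ReduceLROnPlateau needs an improvement larger than min_delta = 1e-4, EarlyStopping counts any decrease, and there is no cooldown or min_lr. simulate_callbacks is a hypothetical helper, not the Keras source.

```python
def simulate_callbacks(val_losses, lr=0.003, factor=0.5,
                       lr_patience=5, stop_patience=10, lr_min_delta=1e-4):
    """Return (learning rate used in each epoch, index of the last epoch that ran)."""
    lr_best, lr_wait = float("inf"), 0
    stop_best, stop_wait = float("inf"), 0
    used_lrs = []
    for epoch, loss in enumerate(val_losses):
        used_lrs.append(lr)
        # Early stopping: any decrease in val_loss resets the counter.
        if loss < stop_best:
            stop_best, stop_wait = loss, 0
        else:
            stop_wait += 1
            if stop_wait >= stop_patience:
                return used_lrs, epoch
        # Plateau: the improvement must beat the best value by lr_min_delta.
        if loss < lr_best - lr_min_delta:
            lr_best, lr_wait = loss, 0
        else:
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr *= factor
                lr_wait = 0
    return used_lrs, len(val_losses) - 1

# val_loss improves for 3 epochs, then stays flat.
lrs, last_epoch = simulate_callbacks([1.0, 0.9, 0.8] + [0.8] * 12)
print(lrs)
print("stopped after epoch", last_epoch)
```

With the loss flat from epoch 3 onward, the learning rate halves once after five stale epochs, and training stops after ten. Because of restore_best_weights=True, Keras then restores the weights from the epoch with the lowest val_loss.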
Step 5: Testing
_, accuracy, top_5_accuracy = gmlp_classifier.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
In Step 4, we tracked top-5 accuracy as an additional metric besides categorical accuracy and trained the model for up to 100 epochs, with early stopping. Let's have a look at the results.

So, we reached a top-5 accuracy of 78.91%, which means that about 79% of the time, the actual label of the image (out of 100 labels) was among the top-5 predicted labels.
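To make the metric concrete, here is a NumPy sketch of top-k accuracy computed from logits. It follows the same idea as SparseTopKCategoricalAccuracy(5) in Step 4, but top_k_accuracy is a hypothetical helper, not the Keras implementation, and it does not handle ties specially.

```python
import numpy as np

def top_k_accuracy(logits, labels, k=5):
    """Fraction of rows whose true label is among the k highest logits."""
    # argpartition places the indices of the k largest logits in the last k columns (unsorted).
    top_k = np.argpartition(logits, -k, axis=1)[:, -k:]
    hits = np.any(top_k == labels[:, None], axis=1)
    return float(hits.mean())

# Toy example: 3 images, 6 classes.
logits = np.array([
    [0.1, 0.9, 0.0, 0.2, 0.3, 0.4],  # true label 1 is the top prediction
    [0.5, 0.1, 0.2, 0.3, 0.4, 0.0],  # true label 5 has the lowest logit
    [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],  # true label 2 is only the 4th highest
])
labels = np.array([1, 5, 2])
print(top_k_accuracy(logits, labels, k=1))  # 1 of 3 correct
print(top_k_accuracy(logits, labels, k=5))  # 2 of 3 correct
```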
[Optional]
Check out the public notebook here:
The code used in this tutorial is a boiled-down version of Khalid Salama's original Keras example.
[Bonus]
Check out an explanatory video on gMLP on YouTube.
Cheers!
P.S. There haven't been any blog posts in the past two months, but I am getting back on the wagon! Things have been moving in a certain direction since the last week of March :)