Train your first MLPMixer model for classifying images
Whenever we think about solving a problem in computer vision, we think Convolutional Neural Networks. Then the Vision Transformer (ViT) came along and things changed. ViT applies a Transformer-like architecture over patches of the image: an image is split into fixed-size patches, each patch is linearly embedded, position embeddings are added, and the resulting sequence of vectors is fed to a standard Transformer encoder. Pretty straightforward, right?
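To make that patch bookkeeping concrete, here is a tiny illustrative calculation (the 224x224 image and 16x16 patch sizes are ViT's usual defaults, not this tutorial's settings):
# Illustrative patch arithmetic for a typical ViT setup (not this tutorial's config).
vit_image_size, vit_patch_size, vit_channels = 224, 16, 3
vit_num_patches = (vit_image_size // vit_patch_size) ** 2    # 14 * 14 = 196 patches
vit_patch_dim = vit_patch_size * vit_patch_size * vit_channels  # 16 * 16 * 3 = 768 values per flattened patch
# Each flattened patch is linearly projected, position embeddings are added,
# and the Transformer encoder sees a sequence of 196 vectors.
print(vit_num_patches, vit_patch_dim)  # 196 768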
But then, in May 2021, the MLPMixer paper came out and showed that you can reach comparable performance (to state-of-the-art models such as ViT and ResNet on downstream CV tasks) using Multilayer Perceptrons (MLPs) only. This came as a breakthrough because, given a large amount of data and compute, the team at Google Brain reached state-of-the-art Vision Transformer level performance in consistently less time.
Paper's Experiment Results:
When pre-trained on ImageNet-21k data with extra regularization:
- Mixer-Large reaches 93.63% average top-1 accuracy across all five downstream tasks: ImageNet, CIFAR-10, CIFAR-100, Pets, and Flowers. In comparison, ViT reaches 94.49% top-1 in the same setting.
- Mixer reaches 74.95% accuracy compared to ViT's 72.72% on the Visual Task Adaptation Benchmark (VTAB-1k), which consists of 19 diverse datasets, each with 1k training examples.
- The real thing to notice is throughput: Mixer-Large runs at 105 img/sec/core versus ViT-Large's 32 img/sec/core. That gap is what makes Mixer the most exciting.

Now, let's look at how the MLPMixer works!

MLPMixer consists of per-patch linear embeddings, Mixer layers, and a classifier head. The input to the network is a sequence of image patches. These patches are projected linearly into an H-dimensional latent representation (the hidden dimension) and passed on to the Mixer layers. One thing to note here is that H is independent of the number of patches and the patch size, which lets the network's cost grow linearly with the input instead of quadratically as in ViT. This results in reduced computation, fewer parameters, and a higher throughput of about 120 images/sec/core.
Mixer layers contain one token-mixing MLP and one channel-mixing MLP, each consisting of two fully-connected layers and a GELU nonlinearity. Other components include: skip-connections, dropout, and layer norm on the channels.
- The channel-mixing MLPs allow communication between different channels; they operate on each token independently and take individual rows of the table as inputs.
- The token-mixing MLPs allow communication between different spatial locations (tokens); they operate on each channel independently and take individual columns of the table as inputs.
These two types of layers are interleaved to enable interaction of both input dimensions.
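To make the row/column picture concrete, here is a minimal shape-level sketch of the two mixing directions (illustrative only; the actual Mixer layer is defined in Step 2 below):
import tensorflow as tf
from tensorflow.keras import layers

tokens = tf.zeros((1, 64, 128))  # [batch, num_patches, channels]
# Channel mixing: a Dense layer acts on the last axis, i.e. on each token's
# 128 channels independently.
print(layers.Dense(128)(tokens).shape)  # (1, 64, 128)
# Token mixing: transpose so the patch axis is last, mix across the 64 tokens
# of each channel, then transpose back.
transposed = tf.linalg.matrix_transpose(tokens)  # [batch, channels, num_patches]
mixed = tf.linalg.matrix_transpose(layers.Dense(64)(transposed))
print(mixed.shape)  # (1, 64, 128)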
Let's get into the tutorial!
In this tutorial, we are going to train the MLPMixer model on the CIFAR-100 dataset. We will build the model using Keras and train it on Kaggle with GPU acceleration, so that each epoch takes around 15 seconds.
Step 1: Set up the dataset and helper code
- Download the dataset; it comes with pre-loaded `train` and `test` sets.
- Set up a `data_augmentation` layer to increase the diversity of the data; it acts as a regularizer and helps reduce overfitting during training.
- Define a `Patches` layer, since the model is going to be trained on patches.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa  # provides the GELU layer and AdamW optimizer used below

num_classes = 100
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
weight_decay = 0.0001
batch_size = 128
num_epochs = 100
dropout_rate = 0.2
image_size = 32 # We'll resize input images to this size.
patch_size = 4 # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2 # Number of patches per image.
embedding_dim = 128 # Number of hidden units.
num_blocks = 4 # Number of blocks.
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
class Patches(layers.Layer):
    def __init__(self, patch_size, num_patches):
        super(Patches, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches

    def call(self, images):
        # Extract non-overlapping patch_size x patch_size patches from each image.
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Flatten the spatial grid of patches into a sequence:
        # [batch_size, num_patches, patch_size * patch_size * channels].
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, self.num_patches, patch_dims])
        return patches
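If you want to verify the Patches layer before wiring up the full model, a quick shape check (not part of the original tutorial) could look like this:
# Sanity check: 32x32 images with 4x4 patches should give 64 patches of 4 * 4 * 3 = 48 values.
sample_batch = tf.convert_to_tensor(x_train[:2], dtype=tf.float32)
sample_patches = Patches(patch_size, num_patches)(sample_batch)
print(sample_patches.shape)  # (2, 64, 48)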
Step 2: Define the MLPMixer model
class MLPMixerLayer(layers.Layer):
    def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
        super(MLPMixerLayer, self).__init__(*args, **kwargs)
        # Token-mixing MLP: mixes information across patches (applied per channel).
        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches),
                tfa.layers.GELU(),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        # Channel-mixing MLP: mixes information across channels (applied per patch).
        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches),
                tfa.layers.GELU(),
                layers.Dense(units=embedding_dim),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.normalize = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize(inputs)
        # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].
        x_channels = tf.linalg.matrix_transpose(x)
        # Apply mlp1 on each channel independently (token mixing).
        mlp1_outputs = self.mlp1(x_channels)
        # Transpose mlp1_outputs back to [num_batches, num_patches, hidden_units].
        mlp1_outputs = tf.linalg.matrix_transpose(mlp1_outputs)
        # Add skip connection.
        x = mlp1_outputs + inputs
        # Apply layer normalization.
        x_patches = self.normalize(x)
        # Apply mlp2 on each patch independently (channel mixing).
        mlp2_outputs = self.mlp2(x_patches)
        # Add skip connection.
        x = x + mlp2_outputs
        return x
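A Mixer layer should preserve the shape of its input, so a quick optional check (my addition, not in the original post) can catch wiring mistakes early:
# Optional shape check: the output shape of a Mixer layer matches its input shape.
mixer_layer = MLPMixerLayer(num_patches, embedding_dim, dropout_rate)
dummy = tf.zeros((2, num_patches, embedding_dim))  # [batch, 64, 128]
print(mixer_layer(dummy).shape)                    # (2, 64, 128)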
Step 3: Build the Classifier
This function takes the MLPMixer blocks as an argument and returns the final model to be trained. It essentially wires all the layers together:
- Take the input data.
- Pass it to the `Patches` layer.
- Pass the patches to the MLPMixer blocks.
- Add a global average pooling layer, followed by a `Dropout` layer and the final `Dense` layer.
def build_classifier(blocks):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size, num_patches)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)
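Before moving on, you can optionally smoke-test build_classifier with a single Mixer block to confirm that a batch of images maps to 100 logits (this check is my addition, not part of the original post):
# Optional smoke test: one Mixer block, zero input, expect [batch, num_classes] logits.
smoke_blocks = keras.Sequential([MLPMixerLayer(num_patches, embedding_dim, dropout_rate)])
smoke_model = build_classifier(smoke_blocks)
print(smoke_model(tf.zeros((2, *input_shape))).shape)  # (2, 100)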
Step 4: Build and Train the Model
mlpmixer_blocks = keras.Sequential(
    [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks)
optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
# Compile the model.
mlpmixer_classifier.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
    ],
)
# Create a learning rate scheduler callback.
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5)
# Create an early stopping callback.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)
# Fit the model.
history = mlpmixer_classifier.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=0.1,
    callbacks=[early_stopping, reduce_lr],
)
Step 5: Testing
_, accuracy, top_5_accuracy = mlpmixer_classifier.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
In Step 4, we used top-5 accuracy as an additional metric besides categorical accuracy and trained the model for 100 epochs. Let's have a look at the results.
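If you want to visualize the learning curves yourself, a minimal matplotlib sketch over the history object returned by fit (assuming the training cell above has run) looks like this:
import matplotlib.pyplot as plt

# Plot training vs. validation accuracy recorded by model.fit().
plt.plot(history.history["acc"], label="train accuracy")
plt.plot(history.history["val_acc"], label="validation accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()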

So, we reached a top-5 accuracy of 81.07%, which means that 81% of the time the true label of the image (out of 100 labels) was among the top-5 predicted labels. And YES, we just trained an MLP to recognize images from a dataset with 100 distinct classes. (You can also do the same with MNIST, but doing it on a dataset as diverse as CIFAR-100 is much more meaningful.)
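To see what a top-5 prediction looks like in practice, here is a small illustrative snippet (the predicted class indices will vary from run to run):
# Inspect the top-5 predicted classes for a single test image.
logits = mlpmixer_classifier.predict(x_test[:1])
top5 = tf.math.top_k(logits[0], k=5)
print("Top-5 predicted class indices:", top5.indices.numpy())
print("True class index:", y_test[0][0])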
[Optional]
Check out the public notebook here:
The code used in this tutorial is a boiled-down version of Khalid Salama's original Keras examples post.
[Bonus]
If you think this was cool, check this out: (they used MLPs to obtain SOTA-comparable performance on both downstream CV and NLP tasks)

Cheers!