
CIFAR-10 Classification Using Mixture of Experts

You may have recently heard reports that the large language models behind products such as ChatGPT use an approach called Mixture of Experts (MoE). MoE has gained traction in the machine learning field as a powerful paradigm for handling complex, high-dimensional data. In this blog post, we go through a step-by-step tutorial to develop, train, test, and validate a Mixture of Experts that classifies images from the CIFAR-10 dataset.
To implement MoE for image classification, we use the CIFAR-10 dataset, a standard benchmark in computer vision. With 60,000 32x32 color images across 10 classes (50,000 for training and 10,000 for testing), CIFAR-10 is a challenging playground to showcase the capabilities of MoE.
By the end of this story, you will understand the basics of a Mixture of Experts and how to develop an MoE for simple classification problems.
P.S. This is not a very theoretical article; it is a how-to article on getting started with MoE for classification.

Understanding Mixture of Experts:

Mixture of Experts is a neural network architecture that divides the learning task into multiple sub-tasks, assigning each to a specialized expert. These experts operate independently and contribute to the final prediction through a gating mechanism. This allows the model to adaptively choose which expert or combination of experts to rely on, enhancing its ability to handle diverse and intricate datasets.
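In the standard soft-gating formulation (a general description, not something specific to the code in this post), an MoE with N experts f_1, ..., f_N and a gating network g combines the expert outputs as a weighted sum:

```latex
y(x) = \sum_{i=1}^{N} g_i(x)\, f_i(x),
\qquad
g(x) = \operatorname{softmax}\!\big(W_g\, h(x) + b_g\big),
\qquad
g_i(x) \ge 0,\quad \sum_{i=1}^{N} g_i(x) = 1
```

Here h(x) is some feature representation of the input, and W_g and b_g are the gating parameters; this notation is assumed for illustration. If every f_i(x) is a probability distribution over the 10 CIFAR-10 classes, then y(x) is one as well, because it is a convex combination of distributions.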

Model Architecture and Design:

Our MoE architecture comprises expert networks responsible for handling specific features of the images. These experts work in parallel with a gating network, which learns to assign weights to each expert based on the input data. The collective decision made by the experts through the gating mechanism yields the final classification output.
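To make this "collective decision" concrete, here is a minimal NumPy sketch of soft gating. It assumes each expert already outputs class probabilities and the gate outputs one non-negative weight per expert; the random arrays and the helper name `combine_experts` are hypothetical, chosen only for illustration.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def combine_experts(expert_probs, gate_weights):
    """Weighted sum of expert outputs (hypothetical helper).

    expert_probs: (batch, num_experts, num_classes), each row a distribution
    gate_weights: (batch, num_experts), each row sums to 1
    returns:      (batch, num_classes)
    """
    return np.einsum("be,bec->bc", gate_weights, expert_probs)

rng = np.random.default_rng(0)
batch, num_experts, num_classes = 4, 3, 10
expert_probs = softmax(rng.normal(size=(batch, num_experts, num_classes)))
gate_weights = softmax(rng.normal(size=(batch, num_experts)))

mixture = combine_experts(expert_probs, gate_weights)
print(mixture.shape)        # (4, 10)
print(mixture.sum(axis=1))  # each row sums to 1, up to floating-point error
```

If the gate puts weight 1 on a single expert and 0 on the rest, the mixture reduces to that expert's prediction, which is the "hard" selection case.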
An important aspect of building an MoE is choosing the right architecture for the expert models and a good gating model; usually, this depends on the task itself. In this article, our experts are simple convolutional neural networks (CNNs), while the gating model is a simple feedforward network that learns the gating weights.

1. Expert Networks:

Our MoE uses several expert networks. For CIFAR-10, the idea is that each expert comes to specialize in recognizing distinctive patterns, for example those typical of particular classes; nothing in the architecture enforces this, so any specialization has to emerge during training. Each expert stacks convolutional layers with non-linear activation functions, such as ReLU, to capture hierarchical features.
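As a quick sanity check on the expert architecture used in the code below (three 3x3 convolutions, each followed by 2x2 max pooling), this sketch tracks how the spatial size of a 32x32 CIFAR-10 image shrinks. The helper functions are hypothetical and assume Keras' defaults: `padding='valid'` and stride 1 for the convolutions, and a pooling stride equal to the pool size.

```python
def conv_out(size, kernel=3, stride=1):
    # 'valid' padding: no zero-padding around the input
    return (size - kernel) // stride + 1

def pool_out(size, pool=2):
    # Max pooling with stride equal to the pool size (Keras default), no padding
    return (size - pool) // pool + 1

size = 32
shapes = []
for filters in (32, 64, 128):
    size = pool_out(conv_out(size))
    shapes.append((size, size, filters))

print(shapes)  # [(15, 15, 32), (6, 6, 64), (2, 2, 128)]
flat = shapes[-1][0] * shapes[-1][1] * shapes[-1][2]
print(flat)    # 512 features after Flatten, fed to the Dense(128) layer
```

This is also why a Flatten layer belongs between the last pooling layer and the first Dense layer: a Keras Dense layer applied to a 4D feature map acts only on the last axis, so the model would output a (2, 2, 10) map per image instead of 10 class probabilities.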

2. Gating Network:

The gating network is important for selecting the most relevant expert for a given input. We employ a separate neural network, i.e., a simple feedforward network, to learn the gating weights. This network is trained concurrently with the expert networks to optimize the overall performance.
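Here is a minimal NumPy sketch of such a gating network's forward pass, assuming it reads the flattened image, has one hidden ReLU layer, and ends in a softmax over the experts. The weights are random placeholders (in an MoE they would be learned jointly with the experts), and the names `gate_forward`, `W1`, and `W2` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(42)
num_experts, hidden = 3, 64

# Stand-in batch of 8 CIFAR-10-sized images (random pixels, for shape checking only)
images = rng.random((8, 32, 32, 3)).astype("float32")
x = images.reshape(len(images), -1)  # flatten to (8, 3072)

# Placeholder parameters; in a real MoE these are learned jointly with the experts
W1 = rng.normal(0.0, 0.01, size=(x.shape[1], hidden))
b1 = np.zeros(hidden)
W2 = rng.normal(0.0, 0.01, size=(hidden, num_experts))
b2 = np.zeros(num_experts)

def gate_forward(x):
    h = np.maximum(0.0, x @ W1 + b1)                     # hidden ReLU layer
    logits = h @ W2 + b2                                 # one score per expert
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    w = np.exp(logits)
    return w / w.sum(axis=1, keepdims=True)              # softmax: rows sum to 1

weights = gate_forward(x)            # soft weights, shape (8, num_experts)
top_expert = weights.argmax(axis=1)  # hard choice: most relevant expert per image
print(weights.shape)                 # (8, 3)
print(top_expert)
```

Because the weights are computed from the image itself, different images can be routed to different experts. Keeping the soft weights lets several experts contribute; taking the argmax gives hard selection of a single expert.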

Training the Entire MoE

Training an MoE model involves optimizing both the expert networks and the gating network. Strategies such as layer-wise pre-training and curriculum learning can improve convergence and stability, and regularization techniques such as dropout and weight decay help prevent overfitting.
Training can also follow an iterative approach: first train the expert networks independently, then gradually introduce the gating network to fine-tune the interactions. To keep things simple, the implementation below trains the experts and the gating network jointly, end to end, and does not use pre-training, curriculum learning, or explicit regularization.
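If you want to add the regularization mentioned above, one option is to insert Dropout layers and an L2 penalty on the weights (closely related to weight decay) into each expert. This sketch mirrors the expert architecture used in the code below; the dropout rate of 0.3, the L2 factor of 1e-4, and the name `create_regularized_expert` are illustrative assumptions, not values tuned for CIFAR-10.

```python
import numpy as np
from tensorflow.keras import layers, models, regularizers

def create_regularized_expert(input_shape, num_outputs, dropout_rate=0.3, l2_factor=1e-4):
    l2 = regularizers.l2(l2_factor)  # penalty added to the training loss
    return models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation="relu", kernel_regularizer=l2),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation="relu", kernel_regularizer=l2),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation="relu", kernel_regularizer=l2),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dropout(dropout_rate),  # only active during training
        layers.Dense(128, activation="relu", kernel_regularizer=l2),
        layers.Dropout(dropout_rate),
        layers.Dense(num_outputs, activation="softmax"),
    ])

expert = create_regularized_expert((32, 32, 3), 10)
probs = expert.predict(np.zeros((2, 32, 32, 3), dtype="float32"), verbose=0)
print(probs.shape)  # (2, 10)
```

The gap in the results below (about 96% training accuracy versus about 71% test accuracy) is the kind of overfitting these techniques are meant to reduce.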

Implementing the MoE for CIFAR-10 Image Classification in TensorFlow

The code below shows the implementation of the experts, the gating model, the MoE, and finally the training process of the whole MoE model. 
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import to_categorical

# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Normalize pixel values to the range [0, 1]
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

# Convert labels to one-hot encoding
num_classes = 10  # number of CIFAR-10 classes
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

from tensorflow.keras import layers, models

# Expert Model: a small CNN that outputs class probabilities
def create_expert_model(input_shape, num_outputs):
    model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),  # (2, 2, 128) feature map -> 512-dim vector
        layers.Dense(128, activation='relu'),
        layers.Dense(num_outputs, activation='softmax')
    ])
    return model

# Gating Model: a feedforward network that outputs one weight per class
def create_gating_model(num_experts, num_outputs):
    model = models.Sequential([
        layers.Input(shape=(num_experts,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_outputs, activation='softmax')
    ])
    return model

# Mixture of Experts (MoE) Model
# Note: as written, this builds a single expert network, and the gating
# network gets its own num_experts-dimensional input instead of the image.
def create_moe_model(input_shape, num_experts, num_outputs):
    expert_input = layers.Input(shape=input_shape, name='expert_input')
    expert_model = create_expert_model(input_shape, num_outputs)
    expert_output = expert_model(expert_input)  # (batch, num_outputs)

    gating_input = layers.Input(shape=(num_experts,), name='gating_input')
    gating_model = create_gating_model(num_experts, num_outputs)
    gating_output = gating_model(gating_input)  # (batch, num_outputs), same shape as expert_output

    # Multiply expert and gating outputs element-wise
    mixture_output = layers.Multiply()([expert_output, gating_output])

    moe_model = models.Model(inputs=[expert_input, gating_input], outputs=mixture_output)
    return moe_model

# Example usage:
input_shape = (32, 32, 3)  # CIFAR-10 images: 32x32 pixels, 3 color channels
num_experts = 5  # size of the gating input vector
num_outputs = 10  # number of CIFAR-10 image classes
epochs = 35  # the training log below was produced with 35 epochs

moe_model = create_moe_model(input_shape, num_experts, num_outputs)
moe_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

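# Train and test
import matplotlib.pyplot as plt

# The gating network is fed random vectors of shape (num_samples, num_experts)
history = moe_model.fit(
    [x_train, np.random.rand(len(x_train), num_experts)],
    y_train,
    epochs=epochs,
    batch_size=64,
    validation_data=([x_test, np.random.rand(len(x_test), num_experts)], y_test),
)

plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Evaluate the model on the test set (labels are already one-hot encoded)
test_loss, test_accuracy = moe_model.evaluate([x_test, np.random.rand(len(x_test), num_experts)], y_test)
print(f'Test Accuracy: {test_accuracy * 100:.2f}%')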


Here is the training log, followed by the test-set evaluation:
Epoch 1/35
782/782 [==============================] - 17s 21ms/step - loss: 1.5528 - accuracy: 0.4363
Epoch 2/35
782/782 [==============================] - 22s 28ms/step - loss: 1.1602 - accuracy: 0.5914
Epoch 3/35
782/782 [==============================] - 26s 34ms/step - loss: 1.0008 - accuracy: 0.6504
Epoch 4/35
782/782 [==============================] - 20s 25ms/step - loss: 0.8857 - accuracy: 0.6912
Epoch 5/35
782/782 [==============================] - 20s 25ms/step - loss: 0.7989 - accuracy: 0.7229
Epoch 6/35
782/782 [==============================] - 16s 21ms/step - loss: 0.7347 - accuracy: 0.7453
Epoch 7/35
782/782 [==============================] - 17s 22ms/step - loss: 0.6759 - accuracy: 0.7665
Epoch 8/35
782/782 [==============================] - 19s 25ms/step - loss: 0.6204 - accuracy: 0.7849
Epoch 9/35
782/782 [==============================] - 21s 27ms/step - loss: 0.5731 - accuracy: 0.7991
Epoch 10/35
782/782 [==============================] - 31s 39ms/step - loss: 0.5289 - accuracy: 0.8151
Epoch 11/35
782/782 [==============================] - 18s 23ms/step - loss: 0.4891 - accuracy: 0.8282
Epoch 12/35
782/782 [==============================] - 22s 29ms/step - loss: 0.4547 - accuracy: 0.8408
Epoch 13/35
782/782 [==============================] - 25s 32ms/step - loss: 0.4145 - accuracy: 0.8548
Epoch 14/35
782/782 [==============================] - 21s 27ms/step - loss: 0.3740 - accuracy: 0.8675
Epoch 15/35
782/782 [==============================] - 24s 30ms/step - loss: 0.3480 - accuracy: 0.8753
Epoch 16/35
782/782 [==============================] - 22s 28ms/step - loss: 0.3149 - accuracy: 0.8863
Epoch 17/35
782/782 [==============================] - 25s 32ms/step - loss: 0.2888 - accuracy: 0.8971
Epoch 18/35
782/782 [==============================] - 18s 23ms/step - loss: 0.2666 - accuracy: 0.9052
Epoch 19/35
782/782 [==============================] - 17s 21ms/step - loss: 0.2437 - accuracy: 0.9112
Epoch 20/35
782/782 [==============================] - 18s 23ms/step - loss: 0.2219 - accuracy: 0.9208
Epoch 21/35
782/782 [==============================] - 20s 26ms/step - loss: 0.1963 - accuracy: 0.9286
Epoch 22/35
782/782 [==============================] - 22s 28ms/step - loss: 0.1874 - accuracy: 0.9329
Epoch 23/35
782/782 [==============================] - 23s 30ms/step - loss: 0.1712 - accuracy: 0.9387
Epoch 24/35
782/782 [==============================] - 20s 26ms/step - loss: 0.1580 - accuracy: 0.9440
Epoch 25/35
782/782 [==============================] - 20s 25ms/step - loss: 0.1537 - accuracy: 0.9444
Epoch 26/35
782/782 [==============================] - 17s 22ms/step - loss: 0.1370 - accuracy: 0.9500
Epoch 27/35
782/782 [==============================] - 21s 27ms/step - loss: 0.1380 - accuracy: 0.9506
Epoch 28/35
782/782 [==============================] - 17s 22ms/step - loss: 0.1339 - accuracy: 0.9520
Epoch 29/35
782/782 [==============================] - 17s 22ms/step - loss: 0.1068 - accuracy: 0.9615
Epoch 30/35
782/782 [==============================] - 20s 26ms/step - loss: 0.1191 - accuracy: 0.9579
Epoch 31/35
782/782 [==============================] - 17s 22ms/step - loss: 0.1194 - accuracy: 0.9574
Epoch 32/35
782/782 [==============================] - 20s 25ms/step - loss: 0.1006 - accuracy: 0.9643
Epoch 33/35
782/782 [==============================] - 18s 23ms/step - loss: 0.1136 - accuracy: 0.9599
Epoch 34/35
782/782 [==============================] - 18s 23ms/step - loss: 0.1031 - accuracy: 0.9634
Epoch 35/35
782/782 [==============================] - 18s 23ms/step - loss: 0.1026 - accuracy: 0.9646

313/313 [==============================] - 2s 5ms/step - loss: 1.9688 - accuracy: 0.7066
Test Accuracy: 70.66%
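Note that `create_moe_model` above instantiates only one expert network, and its gating network receives `np.random.rand(...)` noise rather than the image, drawn afresh for training and for testing; `num_experts = 5` only sets the width of that noise vector. The gate therefore cannot learn input-dependent weights, so in effect the model is a single CNN whose class probabilities are rescaled by weights that carry no information about the image. Below is a minimal sketch of a more conventional soft-gating MoE, assuming several independently initialized CNN experts, a gate that reads the image, and a weighted sum of the expert outputs. The class name `SoftMoE`, the smaller two-block experts, and the choice of three experts are assumptions for illustration, not the author's code.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

def create_expert(image_shape, num_classes):
    # Small CNN expert that outputs class probabilities
    return models.Sequential([
        layers.Input(shape=image_shape),
        layers.Conv2D(32, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(num_classes, activation="softmax"),
    ])

def create_gate(image_shape, num_experts):
    # Gating network that reads the image and outputs one weight per expert
    return models.Sequential([
        layers.Input(shape=image_shape),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),
        layers.Dense(num_experts, activation="softmax"),
    ])

class SoftMoE(tf.keras.Model):
    def __init__(self, image_shape, num_experts, num_classes):
        super().__init__()
        self.experts = [create_expert(image_shape, num_classes) for _ in range(num_experts)]
        self.gate = create_gate(image_shape, num_experts)

    def call(self, x, training=False):
        # (batch, num_experts, num_classes): every expert sees the same image
        expert_probs = tf.stack([e(x, training=training) for e in self.experts], axis=1)
        weights = self.gate(x, training=training)  # (batch, num_experts)
        # Weighted sum over experts -> (batch, num_classes), still a distribution
        return tf.einsum("be,bec->bc", weights, expert_probs)

moe = SoftMoE((32, 32, 3), num_experts=3, num_classes=10)
moe.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
probs = moe(np.zeros((2, 32, 32, 3), dtype="float32"))
print(probs.shape)  # (2, 10)
```

Since every expert ends in a softmax and the gate weights sum to 1, the output is already a probability distribution, so `categorical_crossentropy` applies directly, and training needs no extra gating input: `moe.fit(x_train, y_train, batch_size=64, epochs=...)`. To follow the staged schedule described earlier, you could set `trainable = False` on every model in `moe.experts`, recompile, and train only the gate, then unfreeze the experts, recompile, and fine-tune everything; in Keras, changes to `trainable` take effect once the model is recompiled.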
