Supernet

In machine learning, a Supernet is a single neural network that is trained once and can represent many related network architectures as sub-networks. It is mostly used in neural architecture search (NAS), a family of techniques for automatically discovering a well-performing architecture for a given task.

Traditionally, NAS trains and evaluates each candidate architecture separately, which is computationally expensive and time-consuming. A Supernet is more efficient because one network encompasses a diverse set of architectural choices.

The Supernet combines architectural components such as convolutional layers, pooling operations, and skip connections, each of which has several possible options or configurations. During training, the Supernet learns shared weights for these options and, in some methods, explicit weights or probabilities over the architectural choices, so that a single model represents a wide range of network architectures.
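To make the idea of weighting architectural choices concrete, here is a minimal sketch in plain Python of one way it can be done: a softmax over per-option scores blends the outputs of all candidate operations. This is the continuous relaxation used in differentiable NAS, not necessarily the scheme every Supernet uses. The names `CANDIDATE_OPS`, `mixed_op`, and `alphas` are illustrative assumptions, and the operations are toy functions on lists rather than real layers.

```python
import math

def softmax(logits):
    """Convert raw architecture scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical candidate operations for one slot of the Supernet.
# Each maps a list of floats to a list of the same length.
CANDIDATE_OPS = {
    "identity": lambda xs: list(xs),
    "relu": lambda xs: [max(0.0, x) for x in xs],
    "zero": lambda xs: [0.0 for _ in xs],
}

def mixed_op(xs, alphas):
    """Blend the outputs of all candidate ops, weighted by softmax(alphas)."""
    probs = softmax(alphas)
    outputs = [op(xs) for op in CANDIDATE_OPS.values()]
    return [
        sum(p * out[i] for p, out in zip(probs, outputs))
        for i in range(len(xs))
    ]

# With equal scores, every operation receives weight 1/3.
alphas = [0.0, 0.0, 0.0]
x = [-3.0, 3.0]
y = mixed_op(x, alphas)
print(y)  # approximately [-1.0, 2.0]
```

In a real Supernet the scores in `alphas` would be trainable parameters updated together with the layer weights, and the option with the largest score would typically be kept once training ends.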

Once the Supernet is trained, a process called architecture sampling selects specific network architectures from it. A search algorithm, such as reinforcement learning or an evolutionary algorithm, proposes combinations of architectural choices, and each candidate is evaluated with the weights it inherits from the Supernet in order to find a well-performing combination for a given task or dataset.
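The sampling step can be sketched as follows. For brevity the sketch uses plain random search rather than reinforcement learning or an evolutionary algorithm. The search space, the function names, and above all `proxy_score` are hypothetical: in practice the score would be the validation accuracy of the sampled sub-network using weights inherited from the trained Supernet, whereas here it is a made-up formula so that the example runs on its own.

```python
import random

# Hypothetical search space: each slot lists the options the Supernet contains.
SEARCH_SPACE = {
    "kernel_size": [3, 5, 7],
    "pooling": ["max", "avg"],
    "skip_connection": [False, True],
}

def sample_architecture(rng):
    """Pick one option per slot, which identifies one sub-network."""
    return {slot: rng.choice(options) for slot, options in SEARCH_SPACE.items()}

def proxy_score(arch):
    """Placeholder for accuracy measured with inherited Supernet weights.

    The numbers are invented so the example runs without a trained model.
    """
    score = 0.90
    score += {3: 0.02, 5: 0.03, 7: 0.01}[arch["kernel_size"]]
    score += 0.01 if arch["pooling"] == "max" else 0.0
    score += 0.02 if arch["skip_connection"] else 0.0
    return score

def random_search(num_samples, seed=0):
    """Sample architectures and keep the best-scoring one."""
    rng = random.Random(seed)
    best_arch, best_score = None, float("-inf")
    for _ in range(num_samples):
        arch = sample_architecture(rng)
        score = proxy_score(arch)
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch, best_score

best_arch, best_score = random_search(num_samples=200)
print(best_arch, round(best_score, 4))
```

A reinforcement learning or evolutionary search would replace `random_search` with a loop that uses earlier scores to decide what to sample next, while the evaluation step stays the same.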

Using a Supernet significantly accelerates the search for a good architecture, because candidate architectures no longer need to be trained from scratch. The Supernet acts as a compact representation of a large search space of possible architectures, enabling faster and more efficient exploration.
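A short calculation shows why the search space grows so quickly. The option counts and the cost figures below are assumptions chosen for illustration, not measurements.

```python
import math

# Hypothetical number of options at each of five choice points.
options_per_slot = [3, 3, 2, 2, 4]

# Every combination of one option per slot is a distinct architecture.
num_architectures = math.prod(options_per_slot)
print(num_architectures)  # 144

# Illustrative cost model in units of "one full training run".
cost_separate = num_architectures * 1.0  # train every candidate from scratch
cost_supernet = 5.0                      # assumed cost of training the Supernet once
print(cost_separate / cost_supernet)     # 28.8
```

The comparison ignores the cost of evaluating each sampled sub-network, which is usually much smaller than a training run.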

Overall, a Supernet enables efficient neural architecture search: one network is trained to represent many architectures, and good models are then selected from it instead of being trained one by one.

[1]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Input

# Load and preprocess the MNIST dataset
(input_train, target_train), (input_test, target_test) = mnist.load_data()
input_shape = input_train.shape[1:]
num_classes = len(set(target_train))

input_train = input_train.reshape(input_train.shape[0], *input_shape, 1)
input_test = input_test.reshape(input_test.shape[0], *input_shape, 1)
target_train = tf.keras.utils.to_categorical(target_train, num_classes)
target_test = tf.keras.utils.to_categorical(target_test, num_classes)

# Regular model without Supernet
regular_model = tf.keras.Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape+(1,)),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(num_classes, activation='softmax')
])

regular_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

regular_model.fit(input_train, target_train, batch_size=64, epochs=10, validation_data=(input_test, target_test))

regular_model_score = regular_model.evaluate(input_test, target_test, verbose=0)
print("Regular Model Test Accuracy:", regular_model_score[-1])

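# Supernet model (simplified): a shared backbone with several parallel heads.
# Note: every head below has the same structure, so the "choices" differ only
# in their randomly initialized weights, not in their architecture.
def build_supernet(input_shape, num_classes, num_architectural_choices):
    """Build a multi-output model with one softmax head per architectural choice."""
    inputs = Input(shape=input_shape + (1,))
    # Shared backbone: its weights are reused by every head.
    backbone = Conv2D(32, kernel_size=(3, 3), activation='relu')(inputs)
    backbone = MaxPooling2D(pool_size=(2, 2))(backbone)
    backbone = Flatten()(backbone)

    choices = []
    for _ in range(num_architectural_choices):
        # One candidate branch: a hidden layer followed by a classifier.
        choice = Dense(64, activation='relu')(backbone)
        choice = Dense(num_classes, activation='softmax')(choice)
        choices.append(choice)

    supernet_model = Model(inputs=inputs, outputs=choices)
    return supernet_model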

num_architectural_choices = 3
supernet_model = build_supernet(input_shape, num_classes, num_architectural_choices)

supernet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
supernet_model.fit(input_train, [target_train] * num_architectural_choices, batch_size=64, epochs=10, validation_data=(input_test, [target_test] * num_architectural_choices))

supernet_model_score = supernet_model.evaluate(input_test, [target_test] * num_architectural_choices, verbose=0)
# The score list ends with the per-head accuracies, so [-1] is the accuracy of the last head only.
print("Supernet Model Test Accuracy:", supernet_model_score[-1])
Epoch 1/10
938/938 [==============================] - 3s 3ms/step - loss: 0.7238 - accuracy: 0.9334 - val_loss: 0.1532 - val_accuracy: 0.9615
Epoch 2/10
938/938 [==============================] - 3s 3ms/step - loss: 0.0911 - accuracy: 0.9746 - val_loss: 0.0957 - val_accuracy: 0.9736
Epoch 3/10
938/938 [==============================] - 3s 3ms/step - loss: 0.0603 - accuracy: 0.9815 - val_loss: 0.1234 - val_accuracy: 0.9688
Epoch 4/10
938/938 [==============================] - 3s 3ms/step - loss: 0.0531 - accuracy: 0.9834 - val_loss: 0.1156 - val_accuracy: 0.9718
Epoch 5/10
938/938 [==============================] - 3s 3ms/step - loss: 0.0444 - accuracy: 0.9864 - val_loss: 0.1225 - val_accuracy: 0.9725
Epoch 6/10
938/938 [==============================] - 3s 3ms/step - loss: 0.0448 - accuracy: 0.9866 - val_loss: 0.1105 - val_accuracy: 0.9743
Epoch 7/10
938/938 [==============================] - 3s 3ms/step - loss: 0.0439 - accuracy: 0.9872 - val_loss: 0.1013 - val_accuracy: 0.9787
Epoch 8/10
938/938 [==============================] - 3s 3ms/step - loss: 0.0365 - accuracy: 0.9891 - val_loss: 0.1461 - val_accuracy: 0.9742
Epoch 9/10
938/938 [==============================] - 3s 3ms/step - loss: 0.0417 - accuracy: 0.9885 - val_loss: 0.1424 - val_accuracy: 0.9767
Epoch 10/10
938/938 [==============================] - 3s 3ms/step - loss: 0.0318 - accuracy: 0.9909 - val_loss: 0.1335 - val_accuracy: 0.9765
Regular Model Test Accuracy: 0.9764999747276306
Epoch 1/10
938/938 [==============================] - 7s 7ms/step - loss: 1.7017 - dense_2_loss: 0.6054 - dense_4_loss: 0.5263 - dense_6_loss: 0.5700 - dense_2_accuracy: 0.9118 - dense_4_accuracy: 0.9221 - dense_6_accuracy: 0.9056 - val_loss: 0.3930 - val_dense_2_loss: 0.1290 - val_dense_4_loss: 0.1295 - val_dense_6_loss: 0.1345 - val_dense_2_accuracy: 0.9651 - val_dense_4_accuracy: 0.9658 - val_dense_6_accuracy: 0.9631
Epoch 2/10
938/938 [==============================] - 6s 6ms/step - loss: 0.2655 - dense_2_loss: 0.0891 - dense_4_loss: 0.0865 - dense_6_loss: 0.0899 - dense_2_accuracy: 0.9740 - dense_4_accuracy: 0.9748 - dense_6_accuracy: 0.9742 - val_loss: 0.2734 - val_dense_2_loss: 0.0874 - val_dense_4_loss: 0.0903 - val_dense_6_loss: 0.0958 - val_dense_2_accuracy: 0.9742 - val_dense_4_accuracy: 0.9730 - val_dense_6_accuracy: 0.9739
Epoch 3/10
938/938 [==============================] - 6s 6ms/step - loss: 0.1663 - dense_2_loss: 0.0562 - dense_4_loss: 0.0544 - dense_6_loss: 0.0557 - dense_2_accuracy: 0.9826 - dense_4_accuracy: 0.9831 - dense_6_accuracy: 0.9831 - val_loss: 0.2448 - val_dense_2_loss: 0.0795 - val_dense_4_loss: 0.0832 - val_dense_6_loss: 0.0820 - val_dense_2_accuracy: 0.9763 - val_dense_4_accuracy: 0.9768 - val_dense_6_accuracy: 0.9761
Epoch 4/10
938/938 [==============================] - 6s 6ms/step - loss: 0.1256 - dense_2_loss: 0.0425 - dense_4_loss: 0.0392 - dense_6_loss: 0.0439 - dense_2_accuracy: 0.9866 - dense_4_accuracy: 0.9876 - dense_6_accuracy: 0.9865 - val_loss: 0.2405 - val_dense_2_loss: 0.0774 - val_dense_4_loss: 0.0815 - val_dense_6_loss: 0.0816 - val_dense_2_accuracy: 0.9803 - val_dense_4_accuracy: 0.9776 - val_dense_6_accuracy: 0.9793
Epoch 5/10
938/938 [==============================] - 6s 6ms/step - loss: 0.1043 - dense_2_loss: 0.0335 - dense_4_loss: 0.0322 - dense_6_loss: 0.0386 - dense_2_accuracy: 0.9893 - dense_4_accuracy: 0.9896 - dense_6_accuracy: 0.9879 - val_loss: 0.2638 - val_dense_2_loss: 0.0864 - val_dense_4_loss: 0.0834 - val_dense_6_loss: 0.0940 - val_dense_2_accuracy: 0.9772 - val_dense_4_accuracy: 0.9765 - val_dense_6_accuracy: 0.9757
Epoch 6/10
938/938 [==============================] - 6s 6ms/step - loss: 0.0892 - dense_2_loss: 0.0308 - dense_4_loss: 0.0286 - dense_6_loss: 0.0298 - dense_2_accuracy: 0.9899 - dense_4_accuracy: 0.9908 - dense_6_accuracy: 0.9907 - val_loss: 0.2772 - val_dense_2_loss: 0.0839 - val_dense_4_loss: 0.0887 - val_dense_6_loss: 0.1046 - val_dense_2_accuracy: 0.9796 - val_dense_4_accuracy: 0.9773 - val_dense_6_accuracy: 0.9763
Epoch 7/10
938/938 [==============================] - 6s 6ms/step - loss: 0.0787 - dense_2_loss: 0.0266 - dense_4_loss: 0.0241 - dense_6_loss: 0.0280 - dense_2_accuracy: 0.9913 - dense_4_accuracy: 0.9919 - dense_6_accuracy: 0.9905 - val_loss: 0.2657 - val_dense_2_loss: 0.0913 - val_dense_4_loss: 0.0834 - val_dense_6_loss: 0.0911 - val_dense_2_accuracy: 0.9786 - val_dense_4_accuracy: 0.9803 - val_dense_6_accuracy: 0.9786
Epoch 8/10
938/938 [==============================] - 6s 6ms/step - loss: 0.0686 - dense_2_loss: 0.0215 - dense_4_loss: 0.0222 - dense_6_loss: 0.0248 - dense_2_accuracy: 0.9930 - dense_4_accuracy: 0.9924 - dense_6_accuracy: 0.9925 - val_loss: 0.2868 - val_dense_2_loss: 0.0863 - val_dense_4_loss: 0.0960 - val_dense_6_loss: 0.1045 - val_dense_2_accuracy: 0.9810 - val_dense_4_accuracy: 0.9782 - val_dense_6_accuracy: 0.9785
Epoch 9/10
938/938 [==============================] - 6s 6ms/step - loss: 0.0575 - dense_2_loss: 0.0207 - dense_4_loss: 0.0190 - dense_6_loss: 0.0178 - dense_2_accuracy: 0.9933 - dense_4_accuracy: 0.9941 - dense_6_accuracy: 0.9939 - val_loss: 0.3056 - val_dense_2_loss: 0.1014 - val_dense_4_loss: 0.0918 - val_dense_6_loss: 0.1125 - val_dense_2_accuracy: 0.9783 - val_dense_4_accuracy: 0.9785 - val_dense_6_accuracy: 0.9765
Epoch 10/10
938/938 [==============================] - 6s 6ms/step - loss: 0.0463 - dense_2_loss: 0.0169 - dense_4_loss: 0.0137 - dense_6_loss: 0.0158 - dense_2_accuracy: 0.9946 - dense_4_accuracy: 0.9955 - dense_6_accuracy: 0.9951 - val_loss: 0.3307 - val_dense_2_loss: 0.1123 - val_dense_4_loss: 0.0951 - val_dense_6_loss: 0.1233 - val_dense_2_accuracy: 0.9778 - val_dense_4_accuracy: 0.9804 - val_dense_6_accuracy: 0.9789
Supernet Model Test Accuracy: 0.9789000153541565
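The accuracy printed above belongs to the last head only. A minimal sketch of choosing among the heads, using the final-epoch validation accuracies copied from the training log, could look like this. The dictionary is typed in by hand for illustration; in the notebook the same values would come from the tail of `supernet_model_score`.

```python
# Final-epoch validation accuracies copied from the training log above.
head_accuracy = {
    "dense_2": 0.9778,
    "dense_4": 0.9804,
    "dense_6": 0.9789,
}

# Select the head with the highest accuracy.
best_head = max(head_accuracy, key=head_accuracy.get)
print(best_head, head_accuracy[best_head])  # dense_4 0.9804
```

Because the three heads share the same structure, the gaps between them reflect random initialization rather than architectural differences. The selection here also uses the test set, so a separate validation split would be needed for an unbiased comparison.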