Knowledge Distillation with ResNet for Raw Image Classification


Recently, I ran a simple experiment on a raw image classification task, comparing the accuracy of three setups: a ResNet18 baseline without augmentation, a ResNet50 teacher with augmentation, and a ResNet18 student trained with distillation and augmentation. Distillation improves the ResNet18 significantly.

Dataset: phelber/EuroSAT (EuroSAT: Land Use and Land Cover Classification with Sentinel-2). The dataset is distributed as .tif image files; EuroSAT_MS.zip contains the multi-spectral version, which includes all 13 Sentinel-2 bands in their original value range.
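For reference, below is a minimal sketch of reading one multi-spectral .tif tile into a PyTorch tensor. It assumes the tifffile package and that the bands come out channel-last; the file name and the reflectance scaling are illustrative, not the exact preprocessing used in the experiment.

#Loading one EuroSAT_MS .tif tile (sketch)
import numpy as np
import tifffile
import torch

def load_ms_tif(path):
    img = tifffile.imread(path).astype(np.float32)  # assumed shape (H, W, 13), raw Sentinel-2 values
    img = torch.from_numpy(img).permute(2, 0, 1)    # -> (13, H, W), channel-first for PyTorch
    return img / 10000.0                            # rough reflectance scaling (illustrative)

x = load_ms_tif("AnnualCrop_1.tif")                 # hypothetical file name
print(x.shape)                                      # torch.Size([13, H, W])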

Experiment Setup:

  • Train/test split: 8:2, with exactly the same image data used for every training run.
  • First, I ran ResNet18 with input modified to 32×32×13, without any augmentation or optimization tricks.
  • Then, I trained ResNet50 on the same data, also without augmentation. The model quickly overfit: training accuracy reached 98%, but test accuracy was only 85%.
  • After adding data augmentation (e.g., random flips and other adjustments), the test accuracy improved to 93.09%.
  • Finally, I trained a ResNet18 student (same architecture as the baseline, initialized from scratch) with knowledge distillation from the ResNet50 teacher, computing the loss from both soft labels and hard labels (see the loss sketch after this list). After only 10 epochs, the test accuracy reached 96.65%.
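A minimal sketch of that combined loss, assuming the standard formulation (temperature-scaled KL divergence for the soft labels plus cross-entropy for the hard labels); the temperature and weighting below are illustrative, not the exact values used:

#Distillation loss (sketch): soft labels + hard labels
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),   # student distribution at temperature T
        F.softmax(teacher_logits / T, dim=1),       # teacher distribution at temperature T
        reduction="batchmean",
    ) * (T * T)                                     # scale by T^2 to keep gradients comparable
    hard = F.cross_entropy(student_logits, labels)  # standard hard-label term
    return alpha * soft + (1.0 - alpha) * hard

The temperature softens the teacher's logits so the student also sees inter-class similarity, which is presumably what lets the ResNet18 student converge in only 10 epochs.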

Results:

Network       | ResNet18 Baseline | ResNet50 Teacher | ResNet18 Distillation & Opt
Train Epochs  | 10                | 20               | 10
Test Accuracy | 89.57%            | 93.09%           | 96.65%

Models:

#Resnet50 teacher: stem and head adapted for the 13-band input (channel = 13, classes = 10, device defined elsewhere)
import torch.nn as nn
from torchvision import models

model = models.resnet50(weights=None)
model.conv1 = nn.Conv2d(channel, 64, kernel_size=7, stride=2, padding=3, bias=False)  # accept 13 bands instead of RGB
model.bn1 = nn.BatchNorm2d(64)
model.fc = nn.Linear(model.fc.in_features, classes)  # 10 EuroSAT classes
model = model.to(device)

#Resnet18 student: small-image stem for 32x32 inputs
model = models.resnet18(weights=None)

model.conv1 = nn.Conv2d(channel, 32, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3 stem, 13-band input, 32 filters
model.bn1 = nn.BatchNorm2d(32)

# layer1 expects 64 channels, so widen its first block and project the skip path from 32 to 64
model.layer1[0].conv1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.layer1[0].bn1 = nn.BatchNorm2d(64)
model.layer1[0].downsample = nn.Sequential(
    nn.Conv2d(32, 64, kernel_size=1, stride=1, bias=False),
    nn.BatchNorm2d(64)
)

model.maxpool = nn.Identity()  # keep the full 32x32 resolution in the stem
model.fc = nn.Linear(model.fc.in_features, classes)
model = model.to(device)
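To show how the pieces fit together, here is a sketch of one distillation epoch. It assumes teacher is the trained ResNet50 defined above, student is the modified ResNet18, train_loader yields batches of 13-band images with labels, and distillation_loss is the function sketched earlier; the optimizer and learning rate are illustrative.

#One distillation epoch (sketch): frozen ResNet50 teacher, ResNet18 student
import torch

teacher.eval()                                               # teacher only provides soft targets
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)  # illustrative optimizer choice

for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        teacher_logits = teacher(images)                     # no gradients through the teacher
    student_logits = student(images)
    loss = distillation_loss(student_logits, teacher_logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()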

Find more details on my GitHub.