Image Classification: Fine-Tune ResNet18 on Kaggle Dataset (PyTorch + Lightning)

Image Classification: Fine-Tuning ResNet-18 on Kaggle's Lions vs Cheetahs Dataset

Image classification is a fundamental task in computer vision where the goal is to assign a label or class to an input image. It is widely used in various domains such as medical imaging, autonomous driving, wildlife monitoring, and security. A typical image classification pipeline involves feeding an image into a neural network model, which processes the input and outputs class probabilities corresponding to predefined categories.
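As a small illustration of the last step of that pipeline, the sketch below turns raw network scores (logits) into class probabilities with a softmax. The logit values are made up for illustration and the snippet is not tied to any particular model.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    # Subtract the maximum for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for two classes (values chosen only for illustration)
class_names = ["Cheetahs", "Lions"]
probs = softmax([2.0, 0.5])

# The predicted class is the one with the highest probability
predicted = class_names[probs.index(max(probs))]
```

The class with the largest logit always receives the largest probability, so taking the argmax of the logits gives the same prediction as taking the argmax of the probabilities.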

What is the ImageNet Dataset?

ImageNet is one of the most influential datasets in the history of computer vision. It contains over 14 million labeled images across more than 20,000 categories, with a popular subset of 1,000 categories used in the ImageNet Large Scale Visual Recognition Challenge (ILSVRC). Models trained on ImageNet learn powerful visual features that generalize well to many downstream tasks, making them a popular choice for transfer learning and fine-tuning.

Training from Scratch vs Fine-Tuning Pretrained Models

There are two common approaches to training an image classification model:

  • Training from scratch: The model is initialized with random weights and trained entirely on the target dataset. This requires a large amount of labeled data and computational resources.
  • Fine-tuning a pretrained model: A model pretrained on ImageNet is adapted to a new task. This typically involves replacing the final classification layer and continuing training on the new dataset. Fine-tuning is faster and often yields better performance on small datasets.

How to Download the Kaggle Lions and Cheetahs Dataset

To use the dataset, first ensure you have the Kaggle API installed and configured. Follow these steps:


# 1. Install kaggle if not already installed
pip install kaggle

# 2. Download your API token from https://www.kaggle.com/account
#    and place it as kaggle.json in ~/.kaggle/
mkdir -p ~/.kaggle
cp /path/to/kaggle.json ~/.kaggle/
chmod 600 ~/.kaggle/kaggle.json
# 3. Download the dataset using Kaggle CLI
kaggle datasets download -d mikoajfish99/lions-or-cheetahs-image-classification

# 4. Unzip the dataset
unzip lions-or-cheetahs-image-classification.zip -d /path/to/lions_or_cheetahs
# The unzipped directory has the structure /lions_or_cheetahs/images/Lions and /lions_or_cheetahs/images/Cheetahs.
# Split the images into train, val, and test sets before training.
# A commonly used train:val:test ratio is 7:1:2.
# After splitting, the structure should be:
# lions_or_cheetahs/train/Lions
# lions_or_cheetahs/train/Cheetahs
# lions_or_cheetahs/val/Lions
# lions_or_cheetahs/val/Cheetahs
# lions_or_cheetahs/test/Lions
# lions_or_cheetahs/test/Cheetahs
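
The split itself is not shown above, so here is one possible way to do it in Python using only the standard library. This is a sketch under assumptions: `split_dataset` is a hypothetical helper, files are copied rather than moved, and the 7:1:2 ratio is applied per class with any remainder going to the test set.

```python
import random
import shutil
from pathlib import Path

def split_dataset(src_dir, dst_dir, ratios=(7, 1, 2), seed=42):
    """Copy each class folder in src_dir into dst_dir/{train,val,test}/<class>."""
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    total = sum(ratios)
    counts = {}
    for class_dir in sorted(p for p in src_dir.iterdir() if p.is_dir()):
        files = sorted(f for f in class_dir.iterdir() if f.is_file())
        rng.shuffle(files)
        n_train = len(files) * ratios[0] // total
        n_val = len(files) * ratios[1] // total
        parts = {
            "train": files[:n_train],
            "val": files[n_train:n_train + n_val],
            "test": files[n_train + n_val:],  # the remainder goes to test
        }
        for split, split_files in parts.items():
            out_dir = dst_dir / split / class_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy2(f, out_dir / f.name)
            counts[(split, class_dir.name)] = len(split_files)
    return counts
```

With the paths used in this article, the call would be `split_dataset("/path/to/lions_or_cheetahs/images", "/path/to/lions_or_cheetahs")`. The function returns the number of files written per split and class, which is a quick way to check the ratio.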

Popular Backbone Architectures for Image Classification

Backbone architectures are the core building blocks of deep neural networks used in image classification. Some of the most widely used backbones include:

  • ResNet (Residual Networks): Introduced residual connections that allow very deep networks to be trained effectively. Popular variants include ResNet-18, ResNet-34, ResNet-50, and ResNet-101.
  • EfficientNet: Scales depth, width, and input resolution together using a compound scaling rule, achieving high accuracy with fewer parameters.
  • DenseNet: Within each dense block, every layer receives the feature maps of all preceding layers as input, which strengthens feature propagation and encourages feature reuse.
  • Vision Transformers (ViT): Apply the transformer architecture to sequences of image patches, achieving state-of-the-art results when pretrained on large datasets.

Step-by-Step: Fine-Tune ResNet-18 on Kaggle Lions vs Cheetahs Dataset (PyTorch)


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Define data transforms for preprocessing
# Normalize with the ImageNet mean and std, matching the preprocessing used when the model was pretrained
# For training, a random horizontal flip is added as augmentation to reduce overfitting
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224x224 (standard for ResNet)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),          # Convert PIL image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the training and validation datasets
train_dataset = datasets.ImageFolder('/path/to/lions_or_cheetahs/train', transform=train_transform)
val_dataset = datasets.ImageFolder('/path/to/lions_or_cheetahs/val', transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Load ResNet-18 pretrained on ImageNet
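# (torchvision >= 0.13; on older versions use models.resnet18(pretrained=True))
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)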
# Replace the last fully connected layer to match number of classes (2)
model.fc = nn.Linear(model.fc.in_features, 2)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

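# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()           # Reset gradients
        outputs = model(images)         # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()                 # Backward pass
        optimizer.step()                # Update weights
        running_loss += loss.item() * images.size(0)  # Sum of per-sample losses
    # Report the average loss over the epoch rather than the last batch only
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_dataset):.4f}")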

# Evaluate on validation set
model.eval()
correct = 0
with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images)
        _, preds = torch.max(outputs, 1)       # Get class predictions
        correct += torch.sum(preds == labels).item()  # Count correct predictions
print(f"Validation accuracy: {correct / len(val_dataset):.2%}")
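
The dataset was also split into a test set, which the code above does not use. The sketch below wraps the same accuracy computation in a reusable function with optional GPU support. `evaluate` and its `device` argument are hypothetical names introduced here for illustration.

```python
import torch

def evaluate(model, loader, device="cpu"):
    """Return the classification accuracy of `model` over all batches in `loader`."""
    model.to(device)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)  # Index of the highest score per image
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0
```

To score the held-out test split, one would build a loader from `datasets.ImageFolder('/path/to/lions_or_cheetahs/test', transform=test_transform)` in the same way as `val_loader` and call `evaluate(model, test_loader)`.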

Fine-Tuning ResNet-18 using PyTorch Lightning (Modular & Scalable)

import torch
from torch import nn
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# Define the LightningModule
class ResNet18Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # Load ResNet-18 pretrained on ImageNet
        self.model.fc = nn.Linear(self.model.fc.in_features, 2)  # Adjust final layer
        self.criterion = nn.CrossEntropyLoss()          # Loss function

    def forward(self, x):
        return self.model(x)  # Forward pass

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        return loss  # Lightning runs the backward pass and optimizer step on this loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc, prog_bar=True)  # Log accuracy

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)

# Prepare data loaders
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224x224 (standard for ResNet)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),          # Convert PIL image to tensor (must come before Normalize)
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_ds = ImageFolder('/path/to/lions_or_cheetahs/train', transform=train_transform)
val_ds = ImageFolder('/path/to/lions_or_cheetahs/val', transform=test_transform)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)

# Train model
model = ResNet18Classifier()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_loader, val_loader)

Conclusion

Image classification is a cornerstone of computer vision applications. Training from scratch is powerful but resource-intensive, while fine-tuning a model pretrained on ImageNet, such as ResNet-18, is a more practical and efficient solution, especially for small datasets. Modern frameworks like PyTorch and PyTorch Lightning allow engineers to build, train, and scale models more effectively. The validation code in both the PyTorch and Lightning versions helps verify model performance before the model is used for real-world predictions.
