
Image Classification: Fine-Tune ResNet18 on Kaggle Dataset (PyTorch + Lightning)

Image Classification: Fine-Tuning ResNet-18 on Kaggle's Lions vs Cheetahs Dataset

Image classification is a fundamental task in computer vision where the goal is to assign a label or class to an input image. It is widely used in various domains such as medical imaging, autonomous driving, wildlife monitoring, and security. A typical image classification pipeline involves feeding an image into a neural network model, which processes the input and outputs class probabilities corresponding to predefined categories.

What is the ImageNet Dataset?

ImageNet is one of the most influential datasets in the history of computer vision. It contains over 14 million labeled images across more than 20,000 categories, with a popular subset of 1,000 categories used in the ImageNet Large Scale Visual Recognition Challenge (ILSVRC). Models trained on ImageNet learn powerful visual features that generalize well to many downstream tasks, making them a popular choice for transfer learning and fine-tuning.

Training from Scratch vs Fine-Tuning Pretrained Models

There are two common approaches to training an image classification model:

  • Training from scratch: The model is initialized with random weights and trained entirely on the target dataset. This requires a large amount of labeled data and computational resources.
  • Fine-tuning a pretrained model: A model pretrained on ImageNet is adapted to a new task. This typically involves replacing the final classification layer and continuing training on the new dataset. Fine-tuning is faster and often yields better performance on small datasets.

How to Download the Kaggle Lions and Cheetahs Dataset

To use the dataset, first ensure you have the Kaggle API installed and configured. Follow these steps:


# 1. Install kaggle if not already installed
pip install kaggle

# 2. Download your API token from https://www.kaggle.com/account
#    and place it as kaggle.json in ~/.kaggle/
mkdir -p ~/.kaggle
cp /path/to/kaggle.json ~/.kaggle/
chmod 600 ~/.kaggle/kaggle.json
# 3. Download the dataset using Kaggle CLI
kaggle datasets download -d mikoajfish99/lions-or-cheetahs-image-classification

# 4. Unzip the dataset
unzip lions-or-cheetahs-image-classification.zip -d /path/to/lions_or_cheetahs
# The unzipped directory has the structure /lions_or_cheetahs/images/Lions and /lions_or_cheetahs/images/Cheetahs.
# Split the images into train, val, and test sets before training.
# A common train:val:test ratio is 7:1:2.
# The resulting structure is:
# lions_or_cheetahs/train/Lions
# lions_or_cheetahs/train/Cheetahs
# lions_or_cheetahs/val/Lions
# lions_or_cheetahs/val/Cheetahs
# lions_or_cheetahs/test/Lions
# lions_or_cheetahs/test/Cheetahs
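The split step itself is not shown above. One way to do it is a small Python script. This is a sketch with several assumptions: `split_dataset` is a hypothetical helper, the images sit directly inside one subfolder per class, and files are copied (not moved). It uses a fixed random seed so the split is reproducible.

```python
import random
import shutil
from pathlib import Path


def split_dataset(src_dir, dst_dir, ratios=(0.7, 0.1, 0.2), seed=42):
    """Copy src_dir/<class>/* into dst_dir/{train,val,test}/<class>/ (hypothetical helper)."""
    src, dst = Path(src_dir), Path(dst_dir)
    rng = random.Random(seed)  # fixed seed -> same split on every run
    for class_dir in sorted(p for p in src.iterdir() if p.is_dir()):
        files = sorted(f for f in class_dir.iterdir() if f.is_file())
        rng.shuffle(files)
        n_train = round(len(files) * ratios[0])
        n_val = round(len(files) * ratios[1])
        splits = {
            "train": files[:n_train],
            "val": files[n_train:n_train + n_val],
            "test": files[n_train + n_val:],  # remainder goes to the test set
        }
        for split_name, split_files in splits.items():
            out_dir = dst / split_name / class_dir.name  # keep the class folder name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy2(f, out_dir / f.name)


# Example call (paths are placeholders):
# split_dataset("/path/to/lions_or_cheetahs/images", "/path/to/lions_or_cheetahs")
```

Because `ImageFolder` infers labels from subfolder names, keeping the `Lions` and `Cheetahs` folder names identical in `train`, `val`, and `test` keeps the label indices consistent across splits.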

Popular Backbone Architectures for Image Classification

Backbone architectures are the core building blocks of deep neural networks used in image classification. Some of the most widely used backbones include:

  • ResNet (Residual Networks): Introduces residual (skip) connections that allow very deep networks to be trained effectively. Popular variants include ResNet-18, ResNet-34, ResNet-50, and ResNet-101.
  • EfficientNet: Scales depth, width, and input resolution jointly with a compound scaling rule, achieving high accuracy with fewer parameters.
  • DenseNet: Connects each layer to every subsequent layer within a dense block, strengthening feature propagation and encouraging feature reuse.
  • Vision Transformers (ViT): Applies the transformer architecture to sequences of image patches, achieving state-of-the-art results when pretrained on large datasets.

Step-by-Step: Fine-Tune ResNet-18 on Kaggle Lions vs Cheetahs Dataset (PyTorch)


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Use a GPU if one is available; otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define data transforms for preprocessing
# Normalize with the ImageNet mean/std that the pretrained weights were trained with
# For training only, a random horizontal flip adds simple augmentation that helps reduce overfitting
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224x224 (standard for ResNet)
    transforms.RandomHorizontalFlip(p=0.5),  # Randomly mirror images (training only)
    transforms.ToTensor(),          # Convert PIL image to tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the training and validation datasets
train_dataset = datasets.ImageFolder('/path/to/lions_or_cheetahs/train', transform=train_transform)
val_dataset = datasets.ImageFolder('/path/to/lions_or_cheetahs/val', transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Load ResNet-18 pretrained on ImageNet
# (torchvision >= 0.13 uses `weights`; `pretrained=True` is deprecated)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Replace the last fully connected layer to match the number of classes (2)
model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()           # Reset gradients
        outputs = model(images)         # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()                 # Backward pass
        optimizer.step()                # Update weights
        running_loss += loss.item() * images.size(0)
    epoch_loss = running_loss / len(train_dataset)  # Average loss over the whole epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

# Evaluate on validation set
model.eval()
correct = 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)               # Get class predictions
        correct += (preds == labels).sum().item()   # Count correct predictions
print(f"Validation accuracy: {correct / len(val_dataset):.2%}")

[Figure: Training result]

Fine-Tuning ResNet-18 using PyTorch Lightning (Modular & Scalable)

import torch
from torch import nn
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# Define the LightningModule
class ResNet18Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)  # Load ResNet-18 pretrained on ImageNet
        self.model.fc = nn.Linear(self.model.fc.in_features, 2)  # Adjust final layer
        self.criterion = nn.CrossEntropyLoss()          # Loss function

    def forward(self, x):
        return self.model(x)  # Forward pass

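    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)  # Log training loss
        return loss  # Lightning runs backward() and the optimizer step on this loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = self.criterion(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_loss", val_loss)           # Log validation loss
        self.log("val_acc", acc, prog_bar=True)  # Log accuracy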

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)

# Prepare transforms and data loaders
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224x224 (standard for ResNet)
    transforms.RandomHorizontalFlip(p=0.5),  # Augmentation for training only
    transforms.ToTensor(),          # Convert PIL image to tensor (must come before Normalize)
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_ds = ImageFolder('/path/to/lions_or_cheetahs/train', transform=train_transform)
val_ds = ImageFolder('/path/to/lions_or_cheetahs/val', transform=test_transform)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)

# Train model
model = ResNet18Classifier()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_loader, val_loader)

Conclusion

Image classification is a cornerstone of computer vision applications. Training from scratch is powerful but resource-intensive, while fine-tuning models such as ResNet-18 pretrained on ImageNet is a more practical and efficient solution, especially for small datasets. Frameworks such as PyTorch and PyTorch Lightning make it easier for engineers to build, train, and scale models. The validation code in both the PyTorch and Lightning versions measures accuracy on held-out data, which helps verify model performance before the model is used for real-world predictions.

