Image Classification: Fine-Tuning ResNet-18 on Kaggle's Lions vs Cheetahs Dataset
Image classification is a fundamental task in computer vision where the goal is to assign a label or class to an input image. It is widely used in various domains such as medical imaging, autonomous driving, wildlife monitoring, and security. A typical image classification pipeline involves feeding an image into a neural network model, which processes the input and outputs class probabilities corresponding to predefined categories.
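Concretely, the whole pipeline fits in a few lines of PyTorch. The sketch below is a minimal illustration, assuming torchvision >= 0.13 and a placeholder image path:
import torch
from PIL import Image
from torchvision import models, transforms

# A minimal classification pipeline sketch (cat.jpg is a placeholder path)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

image = Image.open("cat.jpg").convert("RGB")   # placeholder image
batch = preprocess(image).unsqueeze(0)         # add batch dim: (1, 3, 224, 224)
with torch.no_grad():
    probs = torch.softmax(model(batch), dim=1) # class probabilities
print(probs.argmax(dim=1))                     # most likely ImageNet class index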
What is the ImageNet Dataset?
ImageNet is one of the most influential datasets in the history of computer vision. It contains over 14 million labeled images across more than 20,000 categories, with a popular subset of 1,000 categories used in the ImageNet Large Scale Visual Recognition Challenge (ILSVRC). Models trained on ImageNet learn powerful visual features that generalize well to many downstream tasks, making them a popular choice for transfer learning and fine-tuning.
Training from Scratch vs Fine-Tuning Pretrained Models
There are two common approaches to training an image classification model:
- Training from scratch: The model is initialized with random weights and trained entirely on the target dataset. This requires a large amount of labeled data and computational resources.
- Fine-tuning a pretrained model: A model pretrained on ImageNet is adapted to a new task. This typically involves replacing the final classification layer and continuing training on the new dataset. Fine-tuning is faster and often yields better performance on small datasets (see the sketch after this list).
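The difference between the two approaches is visible directly in how the model is constructed. A minimal torchvision sketch (the two-class head matches the lions-vs-cheetahs task covered below; assumes torchvision >= 0.13):
from torch import nn
from torchvision import models

# Training from scratch: random initial weights
scratch_model = models.resnet18(weights=None)

# Fine-tuning: start from ImageNet weights, then swap the classification head
finetune_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
finetune_model.fc = nn.Linear(finetune_model.fc.in_features, 2)  # 2 target classes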
How to Download the Kaggle Lions and Cheetahs Dataset
To use the dataset, first ensure you have the Kaggle API installed and configured. Follow these steps:
# 1. Install kaggle if not already installed
pip install kaggle
# 2. Download your API token from https://www.kaggle.com/account
# and place it as kaggle.json in ~/.kaggle/
mkdir -p ~/.kaggle
cp /path/to/kaggle.json ~/.kaggle/
chmod 600 ~/.kaggle/kaggle.json
# 3. Download the dataset using Kaggle CLI
kaggle datasets download -d mikoajfish99/lions-or-cheetahs-image-classification
# 4. Unzip the dataset
unzip lions-or-cheetahs-image-classification.zip -d /path/to/lions_or_cheetahs
# The unzipped directory has the following structure:
#   /lions_or_cheetahs/images/Lions and /lions_or_cheetahs/images/Cheetahs
# Before training, we need to split the dataset into train, val, and test sets.
# A common train:val:test ratio is 7:1:2 (a split script is sketched after this block).
# As a result, the final structure is
# lions_or_cheetahs/train/Lions
# lions_or_cheetahs/train/Cheetahs
# lions_or_cheetahs/val/Lions
# lions_or_cheetahs/val/Cheetahs
# lions_or_cheetahs/test/Lions
# lions_or_cheetahs/test/Cheetahs
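One way to create this split is a short Python helper. This is a sketch assuming the 7:1:2 ratio above; the root path is a placeholder and the random seed is arbitrary:
import random
import shutil
from pathlib import Path

random.seed(42)  # arbitrary seed so the split is reproducible
root = Path("/path/to/lions_or_cheetahs")

for class_dir in (root / "images").iterdir():  # Lions/ and Cheetahs/
    files = sorted(class_dir.glob("*"))
    random.shuffle(files)
    n_train = int(0.7 * len(files))  # 7:1:2 train/val/test split
    n_val = int(0.1 * len(files))
    parts = {
        "train": files[:n_train],
        "val": files[n_train:n_train + n_val],
        "test": files[n_train + n_val:],
    }
    for split, split_files in parts.items():
        dest = root / split / class_dir.name
        dest.mkdir(parents=True, exist_ok=True)
        for f in split_files:
            shutil.copy(f, dest / f.name)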
Popular Backbone Architectures for Image Classification
Backbone architectures are the core building blocks of deep neural networks used in image classification. Some of the most widely used backbones are listed below; each is available off the shelf in torchvision, as the snippet after the list shows.
- ResNet (Residual Networks): Introduced residual connections that allow very deep networks to be trained effectively. Popular variants include ResNet-18, ResNet-34, ResNet-50, and ResNet-101.
- EfficientNet: Scales depth, width, and resolution in a principled way, achieving high accuracy with fewer parameters.
- DenseNet: Connects each layer to every other layer in a feed-forward fashion to strengthen feature propagation.
- Vision Transformers (ViT): Applies transformer architectures to image patches, achieving state-of-the-art results on large datasets.
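Each of these families can be instantiated from torchvision.models in a single line. A sketch, assuming torchvision >= 0.13; the specific variants chosen are illustrative:
from torchvision import models

# One representative variant per family, with ImageNet weights
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
efficientnet = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
vit = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)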
Step-by-Step: Fine-Tune ResNet-18 on Kaggle Lions vs Cheetahs Dataset (PyTorch)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
# Define data transforms for preprocessing.
# Normalization must use the same ImageNet channel statistics the model saw during pretraining.
# For training, a random horizontal flip adds augmentation that helps reduce overfitting.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),       # Resize image to 224x224 (standard for ResNet)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),               # Convert PIL image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Load the training and validation datasets
train_dataset = datasets.ImageFolder('/path/to/lions_or_cheetahs/train', transform=train_transform)
val_dataset = datasets.ImageFolder('/path/to/lions_or_cheetahs/val', transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# Load ResNet-18 pretrained on ImageNet
# (pretrained=True is deprecated in recent torchvision; the weights enum replaces it)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Replace the last fully connected layer to match number of classes (2)
model.fc = nn.Linear(model.fc.in_features, 2)
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
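# Optional variant (an assumption, not part of the original tutorial): on very
# small datasets it can help to freeze the pretrained backbone and train only
# the new head. Uncomment to try it:
# for param in model.parameters():
#     param.requires_grad = False
# for param in model.fc.parameters():
#     param.requires_grad = True
# optimizer = optim.Adam(model.fc.parameters(), lr=0.0001)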
# Move the model to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()              # Reset gradients
        outputs = model(images)            # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()                    # Backward pass
        optimizer.step()                   # Update weights
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")
# Evaluate on validation set
model.eval()
correct = 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)              # Get class predictions
        correct += torch.sum(preds == labels).item()  # Count correct predictions
print(f"Validation accuracy: {correct / len(val_dataset):.2%}")
Fine-Tuning ResNet-18 using PyTorch Lightning (Modular & Scalable)
import torch
from torch import nn
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import pytorch_lightning as pl
# Define the LightningModule
class ResNet18Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Load ResNet-18 pretrained on ImageNet
        self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.model.fc = nn.Linear(self.model.fc.in_features, 2)  # Adjust final layer
        self.criterion = nn.CrossEntropyLoss()                   # Loss function

    def forward(self, x):
        return self.model(x)  # Forward pass

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss)  # Log training loss
        return loss                   # Lightning runs backprop on the returned loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc, prog_bar=True)  # Log accuracy

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)
# Prepare data loaders
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),       # Resize image to 224x224 (standard for ResNet)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),               # Convert PIL image to tensor (must precede Normalize)
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
train_ds = ImageFolder('/path/to/lions_or_cheetahs/train', transform=train_transform)
val_ds = ImageFolder('/path/to/lions_or_cheetahs/val', transform=test_transform)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)
# Train model
model = ResNet18Classifier()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_loader, val_loader)
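After fitting, the Trainer can run a standalone validation pass, and Lightning's automatic checkpoints can be restored later. A sketch; the checkpoint path below is a placeholder:
# Run a validation pass with the trained weights
trainer.validate(model, val_loader)

# Lightning checkpoints automatically under lightning_logs/; a saved checkpoint
# can be restored later (placeholder path):
# restored = ResNet18Classifier.load_from_checkpoint(
#     "lightning_logs/version_0/checkpoints/epoch=4.ckpt")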
Conclusion
Image classification is a cornerstone of computer vision applications. Training from scratch is powerful but resource-intensive, while fine-tuning a model such as ResNet-18 pretrained on ImageNet is a more practical and efficient option, especially for small datasets. Modern frameworks like PyTorch and PyTorch Lightning let engineers build, train, and scale models more effectively, and the evaluation code in both versions helps verify model performance and make real-world predictions.
References
- He et al., "Deep Residual Learning for Image Recognition," 2015. arXiv:1512.03385
- ImageNet Dataset: www.image-net.org
- PyTorch Documentation: pytorch.org
- Lightning Documentation: lightning.ai
- Kaggle Lions and Cheetahs Dataset: kaggle.com/datasets/mikoajfish99/lions-or-cheetahs-image-classification