
Image Classification with ResNet-18: Training, Validation, and Inference using PyTorch

Image Classification with ResNet-18: Advanced Training Strategies and Inference Pipeline

This article is a follow-up to the previous guide, "Image Classification: Fine-tuning ResNet-18 on Kaggle Dataset (Pytorch + Lightning)". I recommend reviewing the previous post before proceeding.

1. Hyperparameter Configuration

The performance of a deep learning model is highly influenced by the choice of hyperparameters. Below are some key hyperparameters that are commonly tuned:

  • Learning Rate: Controls the step size of each parameter update. For fine-tuning a pretrained model, values between 1e-3 and 1e-5 are common.
  • Batch Size: Number of images processed in a single iteration. Adjust based on available GPU memory.
  • Epochs: Number of full passes through the entire training dataset.
  • Optimizer: Algorithm used to update model parameters (e.g., Adam, SGD).
  • Scheduler: Adjusts the learning rate as training progresses (e.g., decaying it every few epochs).

# Example hyperparameters
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.0001
MODEL_PATH = "best_model.pt"
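The list above mentions an optimizer and a scheduler, but the training loop later in this post uses Adam without a scheduler. As a minimal sketch of how these hyperparameters could be wired together, assuming `StepLR` as the scheduler (one of several in `torch.optim.lr_scheduler`; `step_size=3` and `gamma=0.1` are illustrative values, not from the original), a small `nn.Linear` stands in for ResNet-18 so the snippet runs on its own:

```python
import torch
import torch.nn as nn

LEARNING_RATE = 0.0001
EPOCHS = 10

torch.manual_seed(0)
model = nn.Linear(4, 2)  # stand-in for ResNet-18 in this sketch

# Adam, as in the training code; an SGD alternative would be e.g.
# torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Assumed scheduler: multiply the learning rate by 0.1 every 3 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

lr_history = []
for epoch in range(EPOCHS):
    lr_history.append(scheduler.get_last_lr()[0])  # learning rate used this epoch
    # One dummy update in place of a real training epoch
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the schedule once per epoch

print(lr_history)
```

With these settings the learning rate starts at 1e-4 and drops by a factor of 10 after every third epoch, reaching 1e-7 in the tenth epoch. Calling `scheduler.step()` after `optimizer.step()` is the order PyTorch expects.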

2. Saving the Best and Final Model

During training, the validation accuracy is monitored at the end of each epoch. The model that achieves the highest validation accuracy is saved as the best model, while the final model from the last epoch is also saved separately.


import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Define data transforms for preprocessing
# Normalize with the ImageNet mean/std that the pretrained ResNet-18 was trained with
# Random horizontal flip (training only) augments the data to reduce overfitting
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224x224 (standard for ResNet)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),          # Convert PIL image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.0001
MODEL_PATH = "best_model.pt"

# Load the training and validation datasets
train_dataset = datasets.ImageFolder('./Data/lions_or_cheetahs/train', transform=train_transform)
val_dataset = datasets.ImageFolder('./Data/lions_or_cheetahs/val', transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # ImageNet-pretrained weights
model.fc = nn.Linear(model.fc.in_features, 2)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

best_acc = 0.0
for epoch in range(EPOCHS):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()

    acc = correct / len(val_dataset)
    print(f"Epoch {epoch+1}/{EPOCHS}, Val Accuracy: {acc:.4f}")

    # Save best model
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), MODEL_PATH)

# Save final model
torch.save(model.state_dict(), "last_model.pt")
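The loop above saves only the `state_dict()` weights, which is all that inference needs. If training might have to resume later, one common pattern (an extension, not part of the original code) is to save a dictionary that also holds the optimizer state, the epoch, and the best accuracy so far. A minimal sketch, with a small `nn.Linear` standing in for ResNet-18 and hypothetical helper names `save_checkpoint` / `load_checkpoint`:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model/optimizer; in the article these are ResNet-18 and Adam
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def save_checkpoint(path, model, optimizer, epoch, best_acc):
    # Store everything needed to resume training, not just the weights
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "best_acc": best_acc,
    }, path)


def load_checkpoint(path, model, optimizer):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    # Resume from the epoch after the one that was saved
    return ckpt["epoch"] + 1, ckpt["best_acc"]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "last_checkpoint.pt")
    save_checkpoint(path, model, optimizer, epoch=9, best_acc=0.95)

    # Fresh objects, as if the script had been restarted
    new_model = nn.Linear(4, 2)
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)
    start_epoch, best_acc = load_checkpoint(path, new_model, new_optimizer)

print(start_epoch, best_acc)  # 10 0.95
```

The training loop would then run `for epoch in range(start_epoch, EPOCHS)` and keep comparing against the restored `best_acc`.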

[Figure: Training result]

3. Loading the Trained Model and Implementing Inference

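Once training is complete, the saved model can be reloaded for inference. During inference, the model must be switched to evaluation mode with model.eval() so that its BatchNorm layers use the running statistics collected during training, and gradient tracking should be disabled with torch.no_grad().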
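To see why `model.eval()` matters for ResNet-18, which is built from convolution and BatchNorm layers, the sketch below uses a tiny conv + BatchNorm stack as a stand-in (an assumption for illustration, not the article's model) and checks whether one image's output changes depending on the other images in its batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in for ResNet-18's conv + BatchNorm pattern
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
images = torch.randn(4, 3, 16, 16)  # a fake batch of four "images"


def first_image_alone_vs_in_batch(module):
    with torch.no_grad():
        alone = module(images[:1])      # first image processed by itself
        in_batch = module(images)[:1]   # same image processed with three others
    return torch.allclose(alone, in_batch, atol=1e-5)


net.train()
# Train mode: BatchNorm normalizes with the current batch's statistics
# (and updates its running averages), so the result depends on the batch
train_consistent = first_image_alone_vs_in_batch(net)

net.eval()
# Eval mode: BatchNorm uses the stored running statistics, so each image
# is processed independently of the rest of the batch
eval_consistent = first_image_alone_vs_in_batch(net)

print("train mode consistent:", train_consistent)
print("eval mode consistent:", eval_consistent)
```

This prints `False` for train mode and `True` for eval mode, which is why a single-image `predict` call should always run on a model in evaluation mode.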


import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# Load trained model (weights=None: the fine-tuned weights come from the checkpoint)
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()

# Define preprocessing for inference
infer_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

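# Class names in the index order ImageFolder assigned during training
# (sub-folder names sorted alphabetically: index 0 = cheetah, index 1 = lion)
CLASS_NAMES = ["Cheetah", "Lion"]

# Inference function
def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    image = infer_transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        output = model(image)
        predicted = torch.argmax(output, dim=1)
    return CLASS_NAMES[predicted.item()]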

print(predict("test_cheetah.jpg"))

4. Conclusion

This extended tutorial walked through the rest of the workflow: configuring hyperparameters, saving the model with the best validation accuracy alongside the final model, and running inference with a fine-tuned ResNet-18 in PyTorch. The same structure can be reused for other image classification tasks and serves as a starting point for real-time applications or production deployment.
