
Image Classification with ResNet-18: Training, Validation, and Inference using PyTorch

Image Classification with ResNet-18: Advanced Training Strategies and Inference Pipeline

This article is a follow-up to the previous guide, "Image Classification: Fine-tuning ResNet-18 on Kaggle Dataset (Pytorch + Lightning)". I recommend reviewing the previous post before proceeding.

1. Hyperparameter Configuration

The performance of a deep learning model depends strongly on the choice of hyperparameters. The key hyperparameters that are commonly tuned are:

  • Learning rate: Controls the step size of each parameter update. Values between 1e-3 and 1e-5 are common.
  • Batch size: Number of images processed in a single iteration. Adjust it to fit the available GPU memory.
  • Epochs: Number of full passes through the training dataset.
  • Optimizer: Algorithm used to update the model parameters (e.g., Adam, SGD).
  • Learning-rate scheduler: Adjusts the learning rate as training progresses (e.g., decaying it at fixed epoch intervals).

# Example hyperparameters
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.0001
MODEL_PATH = "best_model.pt"
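The optimizer and scheduler from the list above can be wired to these constants roughly as follows. This is a minimal sketch, not the article's training setup: it assumes a `StepLR` schedule (learning rate multiplied by 0.1 every 3 epochs, values chosen only for illustration) and uses a small `nn.Linear` layer with random tensors as a stand-in for ResNet-18 and the real DataLoader.

```python
import torch
import torch.nn as nn

EPOCHS = 10
LEARNING_RATE = 0.0001

torch.manual_seed(0)
# Stand-in for the fine-tuned ResNet-18 (512 features -> 2 classes)
model = nn.Linear(512, 2)
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Assumed schedule for illustration: multiply the learning rate by 0.1 every 3 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

lr_history = []
for epoch in range(EPOCHS):
    # Record the learning rate used during this epoch
    lr_history.append(optimizer.param_groups[0]["lr"])

    # Placeholder for one epoch of training: a single random batch
    images = torch.randn(4, 512)
    labels = torch.randint(0, 2, (4,))
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()

    # Step the scheduler once per epoch, after the optimizer update
    scheduler.step()

for epoch, lr in enumerate(lr_history, start=1):
    print(f"Epoch {epoch}: lr = {lr:.0e}")
```

With these settings, epochs 1 to 3 run at 1e-4, epochs 4 to 6 at 1e-5, and so on. Calling `scheduler.step()` after `optimizer.step()` is the order PyTorch expects.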

2. Saving the Best and Final Model

During training, validation accuracy is measured at the end of each epoch. Whenever it exceeds the best value seen so far, the model weights are saved as the best model; the weights from the final epoch are saved separately.


import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Define data transforms for preprocessing
# Normalize with the ImageNet mean/std that ResNet-18 was pretrained with
# Random horizontal flips (training only) augment the data to reduce overfitting
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224x224 (standard for ResNet)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),          # Convert PIL image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.0001
MODEL_PATH = "best_model.pt"

# Load the datasets (ImageFolder expects one subfolder per class)
train_dataset = datasets.ImageFolder('./Data/lions_or_cheetahs/train', transform=train_transform)
val_dataset = datasets.ImageFolder('./Data/lions_or_cheetahs/val', transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


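# Load ImageNet-pretrained weights (torchvision >= 0.13; older versions use pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Replace the 1000-class ImageNet head with a 2-class head
model.fc = nn.Linear(model.fc.in_features, 2)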
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

best_acc = 0.0
for epoch in range(EPOCHS):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()

    acc = correct / len(val_dataset)
    print(f"Epoch {epoch+1}/{EPOCHS}, Val Accuracy: {acc:.4f}")

    # Save best model
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), MODEL_PATH)

# Save final model
torch.save(model.state_dict(), "last_model.pt")

[Figure: Training result]
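The best/last selection logic in the training loop can also be checked in isolation. The sketch below is illustrative only: it uses a small `nn.Linear` stand-in, fakes each training update by shifting the weights, and feeds in hypothetical validation accuracies (not results from this article). It then verifies that the "best" file holds the weights from the best epoch and the "last" file those from the final epoch. Files are written to a temporary directory.

```python
import copy
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 2)  # stand-in for the ResNet-18 classifier

# Hypothetical validation accuracies for five epochs (illustrative only)
val_accs = [0.81, 0.87, 0.85, 0.90, 0.88]

tmp_dir = tempfile.mkdtemp()
best_path = os.path.join(tmp_dir, "best_model.pt")
last_path = os.path.join(tmp_dir, "last_model.pt")

best_acc, best_epoch = 0.0, -1
snapshots = []  # in-memory copy of the weights after each epoch, for checking
for epoch, acc in enumerate(val_accs):
    # Simulate a training update by nudging every parameter
    with torch.no_grad():
        for p in model.parameters():
            p.add_(0.01)
    snapshots.append(copy.deepcopy(model.state_dict()))

    # Same rule as the article: save only when accuracy strictly improves
    if acc > best_acc:
        best_acc, best_epoch = acc, epoch
        torch.save(model.state_dict(), best_path)

# Weights from the final epoch, saved regardless of accuracy
torch.save(model.state_dict(), last_path)

best_state = torch.load(best_path, map_location="cpu", weights_only=True)
last_state = torch.load(last_path, map_location="cpu", weights_only=True)
print(f"best epoch: {best_epoch + 1}, best val acc: {best_acc:.2f}")
```

Because the comparison is strict (`>`), a later epoch that only ties the best accuracy does not overwrite the saved file.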

3. Loading the Trained Model and Implementing Inference

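Once training is complete, the saved weights can be reloaded for inference. Before running inference, the model must be switched to evaluation mode with model.eval(), so that layers such as batch normalization and dropout use their inference behavior rather than their training behavior.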
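ResNet-18 contains BatchNorm layers, which is the main reason `model.eval()` matters for it. In training mode, a BatchNorm layer normalizes with statistics computed from the current batch, so its output for one image depends on which other images share the batch; in evaluation mode it uses the running statistics accumulated during training. The sketch below demonstrates this with a single standalone `nn.BatchNorm2d` layer and random tensors; it is an illustration, not part of the inference pipeline.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)  # the same layer type ResNet-18 uses

x = torch.randn(1, 3, 4, 4)      # the image we care about
other = torch.randn(1, 3, 4, 4)  # an unrelated image sharing the batch

def output_for_x(batch_partner):
    # Run x through the layer together with another image; return x's output
    batch = torch.cat([x, batch_partner])
    with torch.no_grad():
        return bn(batch)[0]

bn.train()  # batch statistics: x's output depends on its batch partner
train_same = torch.allclose(output_for_x(other), output_for_x(other * 5))

bn.eval()   # running statistics: x's output is independent of the batch
eval_same = torch.allclose(output_for_x(other), output_for_x(other * 5))

print("train mode, identical output for x:", train_same)
print("eval mode, identical output for x:", eval_same)
```

In training mode the result for `x` changes when its batch partner changes; in evaluation mode it does not, which is the behavior a single-image `predict` call needs.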


import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# Load trained model
model = models.resnet18(weights=None)  # same architecture; weights come from the checkpoint
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()

# Define preprocessing for inference
infer_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Inference function
def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    image = infer_transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        output = model(image)
        predicted = torch.argmax(output, 1)
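    # ImageFolder indexes classes alphabetically by folder name; check
    # train_dataset.classes to confirm that index 1 is the lion class
    return "Lion" if predicted.item() == 1 else "Cheetah"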

print(predict("test_cheetah.jpg"))
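The `predict` function hard-codes index 1 as "Lion". That mapping is correct only if the lion folder sorts after the cheetah folder, because `ImageFolder` assigns class indices by sorting the class folder names alphabetically (in the training script, `train_dataset.classes` holds that sorted list). A less fragile option is to derive the names with the same sorting rule and also report a softmax confidence. The sketch below assumes hypothetical folder names `cheetahs` and `lions` (created in a temporary directory) and hand-written logits in place of real model outputs; `sorted_class_names` and `decode` are illustrative helpers, not torchvision functions.

```python
import os
import tempfile

import torch

def sorted_class_names(root):
    # Same ordering rule ImageFolder uses: subfolder names sorted alphabetically
    return sorted(entry.name for entry in os.scandir(root) if entry.is_dir())

def decode(logits, class_names):
    # Convert raw logits to probabilities, then pick the most likely class
    probs = torch.softmax(logits, dim=1)
    confidence, index = probs.max(dim=1)
    return [(class_names[i], c) for i, c in zip(index.tolist(), confidence.tolist())]

# Hypothetical dataset layout: one subfolder per class
root = tempfile.mkdtemp()
for name in ["lions", "cheetahs"]:
    os.makedirs(os.path.join(root, name))

class_names = sorted_class_names(root)

# Hand-written logits standing in for model(image) on two images
logits = torch.tensor([[2.0, 0.5], [-1.0, 1.0]])
results = decode(logits, class_names)
for label, conf in results:
    print(f"{label}: {conf:.3f}")
```

Saving `class_names` next to the model weights keeps the index-to-label mapping available at inference time without access to the training folders.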

4. Conclusion

This extended tutorial walked through the rest of the workflow: configuring hyperparameters, saving the model with the best validation accuracy alongside the final one, and running inference with a fine-tuned, ImageNet-pretrained ResNet-18 in PyTorch. The same structure can be reused for other image classification tasks by changing the dataset paths and the number of output classes, and it provides a starting point for real-time applications or production deployment.
