Image Classification with ResNet-18: Advanced Training Strategies and Inference Pipeline
1. Hyperparameter Configuration
The performance of a deep learning model is highly influenced by the choice of hyperparameters. Below are some key hyperparameters that are commonly tuned:
- Learning Rate: Controls the size of each parameter update. When fine-tuning a pretrained network, values between 1e-3 and 1e-5 are common.
- Batch Size: Number of images processed in a single iteration. Larger batches need more GPU memory, so adjust it to fit the available memory.
- Epochs: Number of full passes through the entire training dataset.
- Optimizer: Algorithm that updates the model parameters from their gradients (e.g., Adam, SGD).
- Scheduler: Adjusts the learning rate according to a schedule as training progresses (e.g., decaying it every few epochs).
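The training script in this post does not actually use a scheduler. As a minimal sketch, assuming a step-decay schedule is wanted, PyTorch's StepLR can be attached to the same Adam optimizer. The step_size=3 and gamma=0.1 values are illustrative, not tuned settings, and a small nn.Linear stands in for ResNet-18.

```python
import torch
import torch.nn as nn

# A tiny stand-in model; in the tutorial this would be the ResNet-18.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Illustrative schedule: multiply the learning rate by 0.1 every 3 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

lrs = []
for epoch in range(7):
    # ... the per-batch training loop (forward, backward, optimizer.step()) goes here ...
    optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used in this epoch
    scheduler.step()  # advance the schedule once per epoch

print(lrs)
```

Calling scheduler.step() once per epoch, after the optimizer updates, is the order PyTorch expects. Here epochs 0 to 2 run at 1e-4, epochs 3 to 5 at 1e-5, and epoch 6 at 1e-6.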
```python
# Example hyperparameters
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.0001
MODEL_PATH = "best_model.pt"
```
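To make the batch size and epoch definitions concrete: a DataLoader yields ceil(N / BATCH_SIZE) batches per epoch, and the last batch may be smaller than the rest. The sketch below uses a hypothetical 100-sample TensorDataset in place of the image folders.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 32
EPOCHS = 10

# Hypothetical dataset: 100 random "samples" with dummy labels.
dataset = TensorDataset(torch.randn(100, 3), torch.zeros(100, dtype=torch.long))
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

steps_per_epoch = len(loader)                     # number of batches (iterations) per epoch
expected = math.ceil(len(dataset) / BATCH_SIZE)   # ceil(100 / 32) = 4
total_steps = steps_per_epoch * EPOCHS            # optimizer updates over the whole run
batch_sizes = [inputs.shape[0] for inputs, _ in loader]  # sizes of the batches in one epoch

print(steps_per_epoch, expected, total_steps, batch_sizes)
```

With drop_last=True the final partial batch of 4 samples would be skipped, giving 3 steps per epoch instead.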
2. Saving the Best and Final Model
During training, the validation accuracy is monitored at the end of each epoch. The model that achieves the highest validation accuracy is saved as the best model, while the final model from the last epoch is also saved separately.
```python
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.0001
MODEL_PATH = "best_model.pt"

# Define data transforms for preprocessing.
# Normalize with the ImageNet channel means and standard deviations,
# the statistics the pretrained weights were trained with.
# For training only, random horizontal flips augment the data and help reduce overfitting.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224x224 (standard for ResNet)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),  # Convert PIL image to a float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the datasets (ImageFolder expects one subfolder per class)
train_dataset = datasets.ImageFolder('./Data/lions_or_cheetahs/train', transform=train_transform)
val_dataset = datasets.ImageFolder('./Data/lions_or_cheetahs/val', transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# ImageNet-pretrained ResNet-18 (weights API, torchvision >= 0.13) with a new 2-class head
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 2)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

best_acc = 0.0
for epoch in range(EPOCHS):
    # Training
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
    acc = correct / len(val_dataset)
    print(f"Epoch {epoch+1}/{EPOCHS}, Val Accuracy: {acc:.4f}")

    # Save best model
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), MODEL_PATH)

# Save final model
torch.save(model.state_dict(), "last_model.pt")
```
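The script saves only model.state_dict(), which is enough for inference. If training may need to resume later, a common pattern is to save the optimizer state, the epoch and the best accuracy together in one dictionary. The helper names below (save_checkpoint, load_checkpoint) are hypothetical, and a small nn.Linear stands in for ResNet-18.

```python
import os
import tempfile
import torch
import torch.nn as nn

def save_checkpoint(path, model, optimizer, epoch, best_acc):
    """Hypothetical helper: store everything needed to resume training."""
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "best_acc": best_acc,
    }, path)

def load_checkpoint(path, model, optimizer):
    """Hypothetical helper: restore model and optimizer in place; return (epoch, best_acc)."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"], ckpt["best_acc"]

# Usage with a small stand-in model instead of ResNet-18.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    save_checkpoint(path, model, optimizer, epoch=3, best_acc=0.9)

    # A freshly initialized model/optimizer pair, as after restarting the script.
    restored = nn.Linear(4, 2)
    restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-4)
    start_epoch, best_acc = load_checkpoint(path, restored, restored_opt)

print(start_epoch, best_acc, torch.equal(model.weight, restored.weight))
```

Training would then continue with range(start_epoch + 1, EPOCHS) so the completed epochs are not repeated.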
3. Loading the Trained Model and Implementing Inference
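Once training is complete, the saved weights can be reloaded for inference. The model must first be switched to evaluation mode with model.eval(), so that its batch normalization layers use the running statistics collected during training rather than statistics of the current batch. (model.eval() does not turn off gradient tracking; the torch.no_grad() block inside predict does that.)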
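The effect of model.eval() can be seen on a single batch normalization layer, the kind ResNet-18 uses throughout. In this small sketch with random data, training mode normalizes a sample with the statistics of whatever batch it arrives in, so the same sample gets different outputs in different batches; evaluation mode uses the stored running statistics, so the output no longer depends on the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)   # same train/eval behavior as the BatchNorm2d layers in ResNet-18
x = torch.randn(8, 3)    # a random batch of 8 samples with 3 features

# Training mode: normalize with the statistics of the current batch.
bn.train()
with torch.no_grad():
    train_in_full_batch = bn(x)[0]       # sample 0 normalized together with 8 samples
    train_in_small_batch = bn(x[:2])[0]  # sample 0 normalized together with 2 samples

# Evaluation mode: normalize with the running statistics accumulated above.
bn.eval()
with torch.no_grad():
    eval_in_full_batch = bn(x)[0]
    eval_alone = bn(x[:1])[0]            # a batch of one sample, as in single-image inference

print(torch.allclose(train_in_full_batch, train_in_small_batch))  # False
print(torch.allclose(eval_in_full_batch, eval_alone))             # True
```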
```python
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# Assumed to match train_dataset.classes (ImageFolder sorts class folder names alphabetically)
CLASS_NAMES = ["Cheetah", "Lion"]

# Load trained model: same architecture without pretrained weights, then the saved state_dict
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()  # batch norm uses running statistics from training

# Define preprocessing for inference (must match the validation transform)
infer_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Inference function
def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    image = infer_transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        output = model(image)
    predicted = torch.argmax(output, 1)
    return CLASS_NAMES[predicted.item()]

print(predict("test_cheetah.jpg"))
```
4. Conclusion
This extended tutorial walked through a complete PyTorch workflow: configuring hyperparameters, fine-tuning an ImageNet-pretrained ResNet-18, saving the model with the best validation accuracy, and running inference on a single image. The same structure can be reused for other image classification tasks by changing the dataset folders and the number of output classes, and the saved state_dict together with the predict function is a starting point for real-time applications or production deployment.
References
- He et al., "Deep Residual Learning for Image Recognition," arXiv:1512.03385, 2015.
- PyTorch Documentation: pytorch.org
- ImageNet Dataset: image-net.org
- Kaggle Lion vs Cheetah Dataset: kaggle.com
- Lightning Documentation (for structured training): lightning.ai