Running AlexNet on VOC2012 using PyTorch for image classification
After scouring the internet for a couple of days for code to run AlexNet on VOC2012 data for classification, and coming up empty, I wrote the integrated code myself and am sharing it here for anyone to reuse.
Due to lack of time, I am posting the sorting and training code separately; I will post a single unified script when I get time. Meanwhile, if you run into any error, just drop a comment (or email anurag@anuragmeena.com) with details (including your PyTorch and torchvision versions) and I will look into it.
I had a choice between writing a custom data loader and converting the data into a structure that an existing data loader understands. I chose the latter: image_sort.py converts the VOC2012 data into the folder structure expected by torchvision's ImageFolder. Adjust the relative paths in image_sort.py if your layout differs (it reads the list files from Main/ and the images from ../JPEGImages/), then run it.
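main.py expects the following folder hierarchy. Note that image_sort.py, as posted, sorts only the val split (into classes_val/ by default); run it for the train lists as well and move both outputs into place:
classes/
|-- train/  (training set)
|   |-- <one folder per class name>
|-- val/  (validation set)
    |-- <one folder per class name>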
After this, open main.py. It contains the code that initializes the network and lets you configure it: which model to use, whether to download pre-trained weights, where the input data lives, and how many output classes you have.
Download Links
image_sort.py
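## Run this script from the folder that contains Main/ and whose parent contains
## JPEGImages/ (i.e. VOC2012/ImageSets in the standard VOC layout).
# Creates ImageFolder-style class folders from the VOC per-class list files.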
import shutil
from shutil import copyfile
import os
# Paths are relative to the working directory; in the standard VOC2012 layout
# this script is run from VOC2012/ImageSets (Main/ holds the list files).
LIST_DIR = "Main/"
IMAGE_DIR = "../JPEGImages/"
OUTPUT_DIR = "classes_val/"
# Aggregate split lists hold bare image ids and name no class, so they must
# never become class folders (an empty class folder breaks ImageFolder).
SPLIT_LIST_FILES = ("train.txt", "val.txt", "trainval.txt")


def find_class_lists(file_names, split="val"):
    """Map each class name to its ``<class>_<split>.txt`` list file.

    Args:
        file_names: names of the files found in the list directory.
        split: which split to select (``"val"``, ``"train"``, ...).

    Returns:
        dict ``{class_name: file_name}``. The aggregate lists in
        ``SPLIT_LIST_FILES`` and list files of other splits are left out.
    """
    suffix = "_" + split + ".txt"
    class_lists = {}
    for file_name in sorted(file_names):
        if file_name in SPLIT_LIST_FILES or not file_name.endswith(suffix):
            continue
        class_name = file_name[:-len(suffix)]
        if class_name:
            class_lists[class_name] = file_name
    return class_lists


def parse_positive_ids(lines):
    """Return the image ids labelled ``1`` in a per-class list file.

    Each line reads ``<image id> <label>``. Splitting on arbitrary whitespace
    copes with the double space before positive labels, a last line without
    a trailing newline, and Windows (CRLF) line endings. Blank lines and
    lines with any other label (e.g. ``-1`` or ``0``) are skipped.
    """
    positives = []
    for line in lines:
        fields = line.split()
        if len(fields) >= 2 and fields[1] == "1":
            positives.append(fields[0])
    return positives


def sort_images(list_dir=LIST_DIR, image_dir=IMAGE_DIR, output_dir=OUTPUT_DIR,
                split="val"):
    """Copy every positive image of *split* into ``output_dir/<class>/``.

    One folder is created per class that has a ``<class>_<split>.txt`` list
    in *list_dir*. ``<image id>.jpg`` is copied from *image_dir* into the
    folder of every class that lists it as positive, so an image may end up
    in several class folders.
    """
    class_lists = find_class_lists(os.listdir(list_dir), split)
    print("_________________________________________________________________")
    for class_name, file_name in class_lists.items():
        class_dir = os.path.join(output_dir, class_name)
        if not os.path.isdir(class_dir):
            os.makedirs(class_dir)
        print("Opening Text file: " + file_name)
        with open(os.path.join(list_dir, file_name), "r") as f:
            image_ids = parse_positive_ids(f)
        for image_id in image_ids:
            print(image_id + ".jpg")
            shutil.copy(os.path.join(image_dir, image_id + ".jpg"), class_dir)


if __name__ == "__main__":
    sort_images()
main.py
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,
                is_inception=False, device=None):
    """Train *model* and return it loaded with the best validation weights.

    Each epoch runs a ``'train'`` phase (train mode, gradients on, optimizer
    steps) followed by a ``'val'`` phase (eval mode, no gradients). The
    weights of the epoch with the highest validation accuracy are restored
    before returning.

    Args:
        model: the ``nn.Module`` to train.
        dataloaders: dict with ``'train'`` and ``'val'`` DataLoaders yielding
            ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the parameters to update.
        num_epochs: number of passes over the training set.
        is_inception: if True, the model returns ``(outputs, aux_outputs)``
            in train mode and the auxiliary loss is added with weight 0.4.
        device: device batches are moved to; defaults to the device of the
            model's first parameter (CPU if the model has none).

    Returns:
        ``(model, val_acc_history)``; the history holds one validation
        accuracy (a Python float) per epoch.
    """
    if device is None:
        params = list(model.parameters())
        device = params[0].device if params else torch.device("cpu")
    since = time.time()
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Track history only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        # Inception also returns auxiliary logits in train mode.
                        outputs, aux_outputs = model(inputs)
                        loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                # .item() keeps the counter a plain int (no device tensors).
                running_corrects += torch.sum(preds == labels).item()
            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / float(num_samples)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            if phase == 'val':
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                val_acc_history.append(epoch_acc)
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)
    return model, val_acc_history


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of *model* when *feature_extracting* is True."""
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a torchvision model whose output layer has *num_classes* units.

    The output layer(s) are replaced after the optional freezing, so with
    *feature_extract* True only the new layer(s) remain trainable.

    Returns:
        ``(model, input_size)``: the input crop size this function pairs
        with the model (299 for inception, 224 otherwise).

    Raises:
        ValueError: if *model_name* is not one of the supported names.
    """
    if model_name == "resnet":
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "alexnet":
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "vgg":
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "squeezenet":
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224
    elif model_name == "densenet":
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "inception":
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Both the auxiliary and the main classifier need the new class count.
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299
    else:
        print("Invalid model name, exiting...")
        raise ValueError("Invalid model name: {!r}".format(model_name))
    return model_ft, input_size


def run(data_dir="./classes/", model_name="alexnet", num_classes=20, batch_size=8,
        num_epochs=1, feature_extract=True, lr=0.001, momentum=0.9,
        checkpoint_path=None):
    """Fine-tune *model_name* on the ImageFolder data in *data_dir* and save it.

    Expects ``data_dir/train/<class>/`` and ``data_dir/val/<class>/``. The
    best-on-validation weights are saved to *checkpoint_path* (default
    ``<data_dir>/abc.pth``).

    Returns:
        ``(model, val_acc_history)`` as returned by :func:`train_model`.
    """
    print("PyTorch Version: ", torch.__version__)
    print("Torchvision Version: ", torchvision.__version__)
    if checkpoint_path is None:
        checkpoint_path = os.path.join(data_dir, "abc.pth")
    model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
    print(model_ft)
    # ImageNet channel statistics, matching the pretrained weights.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    print("Initializing Datasets and Dataloaders...")
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in ['train', 'val']}
    dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                       shuffle=True, num_workers=4)
                        for x in ['train', 'val']}
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_ft = model_ft.to(device)
    # Only parameters that still require grad are optimized; with
    # feature_extract this is just the replaced output layer(s).
    print("Params to learn:")
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
    # lr=0.001 (not 0.9): with momentum 0.9 a learning rate of 0.9 diverges.
    optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft,
                                 num_epochs=num_epochs,
                                 is_inception=(model_name == "inception"),
                                 device=device)
    print("Checkpoint save started")
    torch.save(model_ft.state_dict(), checkpoint_path)
    print("Checkpoint saved")
    return model_ft, hist


if __name__ == '__main__':
    run()