October 22, 2024
python

Vision Transformer Model not generalizing well on independent validation dataset


I am training a Wave Vision Transformer (Wave-ViT) model. The code for Wave-ViT is available at the link below.

https://github.com/YehLi/ImageNetModel/blob/main/classification/wavevit.py

I did not change the code in wavevit.py or torch_wavelets.py. The only change I made is in the pipeline that feeds data to the model. My original dataset contains around 38,000 MRI images of size 256x256 in RGB format. I augmented it by rotating each image by 90, 180, and 270 degrees and saving the results, so each image has 3 rotated copies. This increased the dataset to about 156,000 images of the same size and format.
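
For reference, the rotation step can be sketched as follows. This is a minimal illustration, not my exact script: the function name `rotated_copies` and the random stand-in image are made up for the example, and it assumes each image is a (256, 256, 3) uint8 array.

```python
import numpy as np


def rotated_copies(image):
    """Return the image together with its 90, 180 and 270 degree rotations."""
    # np.rot90 rotates in the plane of the first two axes (H, W);
    # the channel axis is left untouched.
    return [np.rot90(image, k=k) for k in range(4)]


# Hypothetical stand-in for one 256x256 RGB image.
image = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
copies = rotated_copies(image)
```

Because the images are square, every rotated copy keeps the (256, 256, 3) shape, so the original and its three rotations can be stored in the same array.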

I then saved the images and labels as numpy.memmap files (uint8 for the images), because my code ran out of memory (OOM) when I tried to load all of them into a single NumPy array at once.
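
A minimal sketch of how such a file can be written one image at a time, so that the full dataset never has to sit in RAM. The helper name `write_images_to_memmap`, the temporary path and the tiny demo data are assumptions for illustration only.

```python
import os
import tempfile

import numpy as np


def write_images_to_memmap(path, images, image_shape=(256, 256, 3)):
    """Write an iterable of uint8 images into a memory-mapped file."""
    n = len(images)
    # mode="w+" creates (or overwrites) the file with the requested shape.
    mm = np.memmap(path, dtype="uint8", mode="w+", shape=(n, *image_shape))
    for i, img in enumerate(images):
        mm[i] = img  # only one image is held in memory at a time
    mm.flush()  # make sure everything is written to disk
    return n


# Hypothetical demo: three constant images written to a temporary file.
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "demo_images.dat")
demo = [np.full((256, 256, 3), i, dtype=np.uint8) for i in range(3)]
n = write_images_to_memmap(path, demo)

# Reading back requires the same dtype and shape, since the file has no header.
loaded = np.memmap(path, dtype="uint8", mode="r", shape=(n, 256, 256, 3))
```

The file is raw bytes, so the dtype and shape used for reading must match the ones used for writing exactly.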

I load the train and test images and labels from the memmap files like this.

import numpy as np


def load_memmap_data(train_memmap_file, train_label_memmap_file, test_memmap_file,
                     test_label_memmap_file, num_train_images, num_test_images):
    # Open the arrays read-only; data is read from disk on access, not loaded up front
    train_images = np.memmap(train_memmap_file, dtype="uint8", mode="r", shape=(num_train_images, 256, 256, 3))
    train_labels = np.memmap(train_label_memmap_file, dtype="int32", mode="r", shape=(num_train_images,))

    test_images = np.memmap(test_memmap_file, dtype="uint8", mode="r", shape=(num_test_images, 256, 256, 3))
    test_labels = np.memmap(test_label_memmap_file, dtype="int32", mode="r", shape=(num_test_images,))

    return train_images, train_labels, test_images, test_labels


# Paths of the memory-mapped files for the train/test datasets
train_memmap_file="train_images.dat"
train_label_memmap_file="train_labels.dat"
test_memmap_file="test_images.dat"
test_label_memmap_file="test_labels.dat"


train_images, train_labels, test_images, test_labels = load_memmap_data( 
    train_memmap_file=train_memmap_file, 
    train_label_memmap_file=train_label_memmap_file, 
    test_memmap_file=test_memmap_file, 
    test_label_memmap_file=test_label_memmap_file,
    num_train_images=num_train_images,
    num_test_images=num_test_images
    )
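
The memmaps hold uint8 values in (N, H, W, C) layout, while the model expects float input in (N, C, H, W) layout. In my pipeline transforms.ToTensor() does this conversion per image. The sketch below is a NumPy-only equivalent for a whole batch, shown just to make the expected shapes and value range explicit; the function name `to_model_input` and the demo batch are hypothetical.

```python
import numpy as np


def to_model_input(batch_uint8):
    """(N, H, W, C) uint8 in [0, 255] -> (N, C, H, W) float32 in [0.0, 1.0]."""
    batch = np.asarray(batch_uint8, dtype=np.float32) / 255.0
    # Move the channel axis in front of the spatial axes.
    return np.ascontiguousarray(batch.transpose(0, 3, 1, 2))


# Hypothetical batch of 4 images whose first (red) channel is fully saturated.
demo = np.zeros((4, 256, 256, 3), dtype=np.uint8)
demo[..., 0] = 255
x = to_model_input(demo)
```

Any data used for validation has to go through the same conversion and scaling as the training data, otherwise the model sees inputs in a different range.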

My optimizer and the call to the train function of the Trainer class look like this.

model = WaveViT()
model = nn.DataParallel(model)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()
trainer = Trainer(model, optimizer, loss_fn, exp_name="waveViT-256-aug", device=device)
trainer.train(train_images, train_labels, test_images, test_labels, epochs=100, config=None, steps_per_epoch=steps_per_epoch, augment=False)

My Trainer class looks like this. It takes the images and labels, applies the transformation chosen for the current epoch, and trains and evaluates the model.

class Trainer:
    def __init__(self, model, optimizer, loss_fn, exp_name, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.exp_name = exp_name
        self.device = device

    def train(self, train_images, train_labels, test_images, test_labels, epochs, config=None, steps_per_epoch=0, augment=False):
        train_losses, test_losses, test_accuracies,train_accuracies,test_precision,train_precision,test_recall,train_recall,test_f1,train_f1 = [], [], [],[],[],[],[],[],[],[]
        best_test_loss = float('inf')  # Initialize with a large value
        best_accuracy = 0.0  # Initialize with the worst possible accuracy
        scaler = GradScaler()
        # Early stopping variables
        best_epoch = 0
        epochs_no_improvement = 0  # Counter for epochs without improvement
        # Train the model
        # Random rotation in [-40, 40] degrees with shear, plus a random vertical flip
        transform_1 = transforms.Compose([
            transforms.RandomAffine(degrees=(-40, 40), shear=15),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
        ])

        # Random crop resized to 224x224, vertical flip, translation, occasional elastic deformation
        transform_2 = transforms.Compose([
            transforms.RandomResizedCrop(size=224, scale=(0.95, 1.0)),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomAffine(degrees=0, translate=(0.15, 0.15)),
            transforms.RandomApply([transforms.ElasticTransform(alpha=30.0)], p=0.3),
            transforms.ToTensor(),
        ])

        # No augmentation, only conversion to a tensor
        transform_3 = transforms.Compose([transforms.ToTensor()])
        for i in range(epochs):
            print("\nTraining epoch\n")
            print("Preparing data loaders...")
            if i % 2 == 0:
                transform = transform_1  # Even epochs
            elif i % 2 == 0 and i % 7 == 0:
                # Unreachable as written: every even i is already caught by the branch above
                transform = transform_3
            else:
                transform = transform_2  # Odd epochs
            trainloader, testloader = prepare_data(
                batch_size=64,
                x_train=train_images,
                y_train=train_labels,
                x_test=test_images,
                y_test=test_labels,
                transform=transform,
            )
           
            accuracy_train,train_loss,precision_train,recall_train,f1_train = self.train_epoch(trainloader,steps_per_epoch,augment,scaler)
            
            accuracy_test, test_loss,precision_test,recall_test,f1_test = self.evaluate(testloader)
            print("\nEvaluation Completed\n")
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            test_accuracies.append(accuracy_test)
            train_accuracies.append(accuracy_train)
            test_precision.append(precision_test)
            train_precision.append(precision_train)
            test_recall.append(recall_test)
            train_recall.append(recall_train)
            test_f1.append(f1_test)
            train_f1.append(f1_train)
            is_best_loss = test_loss < best_test_loss
            is_best_accuracy = accuracy_test > best_accuracy
            if is_best_loss:
                best_test_loss = test_loss
                best_epoch = i + 1
                epochs_no_improvement = 0  # Reset counter
                save_checkpoint(self.exp_name + "-Best-Test-Loss", self.model, best_epoch)
            else:
                epochs_no_improvement += 1
                  
            if is_best_accuracy:
                best_accuracy = max(accuracy_test, best_accuracy)  # Update best accuracy
                save_checkpoint(self.exp_name + "-Best-Test-Accuracy", self.model, i+1)

            if epochs_no_improvement >= 10:
                print(f"Early stopping triggered at epoch {i + 1}: no test-loss improvement for 10 epochs.")
                break  # Stop training if no improvement
   
        save_experiment(self.exp_name, config, self.model, train_losses, test_losses, test_accuracies,train_accuracies,test_precision,train_precision,test_recall,train_recall,test_f1,train_f1)
        plot_metrics(train_losses, test_losses, train_accuracies, test_accuracies,
                 train_precision, test_precision, train_recall, test_recall,
                 train_f1, test_f1, self.exp_name)

            
    def train_epoch(self, trainloader, steps_per_epoch, augment,scaler):
        self.model.train()
        total_loss = 0
        trainloader_iter = itertools.cycle(trainloader)
        correct = 0
        y_true = []  # To store all true labels
        y_pred = []  # To store all predicted labels
        # Wrap the loop with tqdm for the progress bar
        with tqdm(total=steps_per_epoch, desc="Training", unit="step") as pbar:
            for i in range(steps_per_epoch):
                batch = next(trainloader_iter)
                # Move the batch to the device
                images, labels = [t.to(self.device) for t in batch]
                images = images.to(torch.float32)

                with autocast():
                    result = self.model(images)
                    loss = self.loss_fn(result, labels)

                self.optimizer.zero_grad()
                scaler.scale(loss).backward()
                # Update the model's parameters
                scaler.step(self.optimizer)
                scaler.update()

                total_loss += loss.item() * len(images)
                predictions = torch.argmax(result, dim=1)
                y_pred.extend(predictions.cpu().numpy())
                y_true.extend(labels.cpu().numpy())
                correct += torch.sum(predictions == labels).item()
                # Advance the progress bar once every 25% of the steps
                quarter = max(1, steps_per_epoch // 4)  # Avoids modulo by zero when steps_per_epoch < 4
                if (i + 1) % quarter == 0:
                    pbar.update(quarter)
            
        # Convert lists to tensors for calculation
        y_true_tensor = torch.tensor(y_true)
        y_pred_tensor = torch.tensor(y_pred)

        # Calculating precision, recall, and F1 score using PyTorch
        TP = ((y_pred_tensor == 1) & (y_true_tensor == 1)).sum().item()
        FP = ((y_pred_tensor == 1) & (y_true_tensor == 0)).sum().item()
        FN = ((y_pred_tensor == 0) & (y_true_tensor == 1)).sum().item()

        precision = TP / (TP + FP) if TP + FP > 0 else 0
        recall = TP / (TP + FN) if TP + FN > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0    
        # Both averages assume steps_per_epoch covers the dataset exactly once
        avg_loss = total_loss / len(trainloader.dataset)
        accuracy = correct / len(trainloader.dataset)  # Accuracy as a fraction in [0, 1]

        return accuracy, avg_loss,precision,recall,f1

    @torch.no_grad()
    def evaluate(self, testloader):
        self.model.eval()
        total_loss = 0
        correct = 0
        y_true = []
        y_pred = []
        with torch.no_grad():
            for batch in testloader:
                # Move the batch to the device
                batch = [t.to(self.device) for t in batch]
                images, labels = batch
                images = images.to(torch.float32)
                with autocast():
                    result = self.model(images)
                    loss = self.loss_fn(result, labels)
                
                total_loss += loss.item() * len(images)
                predictions = torch.argmax(result, dim=1)
                y_pred.extend(predictions.cpu().numpy())
                y_true.extend(labels.cpu().numpy())
                correct += torch.sum(predictions == labels).item()
        # Convert lists to tensors for calculation
        y_true_tensor = torch.tensor(y_true)
        y_pred_tensor = torch.tensor(y_pred)

        # Calculating precision, recall, and F1 score using PyTorch
        TP = ((y_pred_tensor == 1) & (y_true_tensor == 1)).sum().item()
        FP = ((y_pred_tensor == 1) & (y_true_tensor == 0)).sum().item()
        FN = ((y_pred_tensor == 0) & (y_true_tensor == 1)).sum().item()

        precision = TP / (TP + FP) if TP + FP > 0 else 0
        recall = TP / (TP + FN) if TP + FN > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        accuracy = correct / len(testloader.dataset)
        avg_loss = total_loss / len(testloader.dataset)
        return accuracy, avg_loss,precision,recall,f1
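
To make the per-epoch transform choice in Trainer.train explicit, the small sketch below replays the same if/elif/else branches for 100 epochs and counts how often each transform is picked. The helper name `pick_transform_name` is made up for this illustration; the conditions are copied from the loop above.

```python
def pick_transform_name(epoch):
    """Mirror the branch order used in Trainer.train."""
    if epoch % 2 == 0:
        return "transform_1"
    elif epoch % 2 == 0 and epoch % 7 == 0:
        # Never reached: any even epoch already matched the first branch.
        return "transform_3"
    else:
        return "transform_2"


schedule = [pick_transform_name(i) for i in range(100)]
counts = {name: schedule.count(name)
          for name in ("transform_1", "transform_2", "transform_3")}
print(counts)  # {'transform_1': 50, 'transform_2': 50, 'transform_3': 0}
```

So training alternates between transform_1 (256x256 output) on even epochs and transform_2 (224x224 output, because of RandomResizedCrop(size=224)) on odd epochs, and transform_3 is never used.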

The model does very well during the run, when it is trained and tested on the train and test sets from the same dataset.

Epoch: 1 
Training Metrics: Accuracy: 0.7357, Loss: 0.5236, Precision: 0.6928, Recall: 0.8479, F1 Score: 0.7625 
Testing Metrics: Accuracy: 0.7672, Loss: 0.4838, Precision: 0.7271, Recall: 0.8556, F1 Score: 0.7861
....
Epoch: 4 
Training Metrics: Accuracy: 0.8031, Loss: 0.4078, Precision: 0.7644, Recall: 0.8772, F1 Score: 0.8169 
Testing Metrics: Accuracy: 0.7494, Loss: 0.4712, Precision: 0.8186, Recall: 0.6408, F1 Score: 0.7189
...
Epoch: 8 
Training Metrics: Accuracy: 0.8529, Loss: 0.3148, Precision: 0.8324, Recall: 0.8845, F1 Score: 0.8577 
Testing Metrics: Accuracy: 0.8280, Loss: 0.4027, Precision: 0.8015, Recall: 0.8720, F1 Score: 0.8352
...
Epoch: 18 
Training Metrics: Accuracy: 0.9284, Loss: 0.1706, Precision: 0.9237, Recall: 0.9346, F1 Score: 0.9292 
Testing Metrics: Accuracy: 0.8008, Loss: 0.5767, Precision: 0.8357, Recall: 0.7488, F1 Score: 0.7899

The model does not give me a better accuracy or loss after epoch 8. I saved the model at epoch 8; after that, the test accuracy stays around 0.79-0.80.
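
As a cross-check, the checkpoint rule from the Trainer (keep the epoch with the lowest test loss, and separately the one with the highest test accuracy) can be replayed on the four epochs logged above. This is only a sketch over the printed numbers, not the Trainer code itself; the tuple layout is my own.

```python
# (epoch, test_loss, test_accuracy) copied from the log above.
logged = [
    (1, 0.4838, 0.7672),
    (4, 0.4712, 0.7494),
    (8, 0.4027, 0.8280),
    (18, 0.5767, 0.8008),
]

# Epoch with the lowest test loss and epoch with the highest test accuracy.
best_loss_epoch = min(logged, key=lambda row: row[1])[0]
best_acc_epoch = max(logged, key=lambda row: row[2])[0]
print(best_loss_epoch, best_acc_epoch)  # 8 8
```

Among the logged epochs, both criteria select epoch 8. By epoch 18 the training loss has dropped to 0.1706 while the test loss has risen to 0.5767.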

When validated on an independent dataset, this model performed poorly.

Validation Metrics:
 - Accuracy: 0.4890
 - Precision: 0.4878
 - Recall: 0.4416
 - F1 Score: 0.4636
 - Confusion Matrix:
[[1341 1159]
 [1396 1104]]
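
The reported metrics can be recomputed from the confusion matrix. The sketch below assumes the usual layout with rows as true classes and columns as predicted classes, and class 1 as the positive class; that layout is an assumption on my side, but it reproduces the numbers above.

```python
# Confusion matrix from the validation run: rows = true class, columns = predicted class.
confusion = [[1341, 1159],
             [1396, 1104]]

tn, fp = confusion[0]
fn, tp = confusion[1]
total = tn + fp + fn + tp

accuracy = (tp + tn) / total
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print(f"{accuracy:.4f} {precision:.4f} {recall:.4f} {f1:.4f}")
# 0.4890 0.4878 0.4416 0.4636
```

With 2500 samples per class, an accuracy of 0.4890 is essentially chance level for a two-class problem.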

I also validated it on the same dataset that was used for training, and the accuracy stays the same (even though I gave it the same images I used in training). I also tried the ImageNet-pretrained weights for this WaveViT (the original WaveViT was trained on ImageNet, and the saved model is available on GitHub), but the result is the same.

It would be a great help if someone could help me resolve this behavior of the model.

Why did the validation accuracy not improve, even on the same dataset that was used for training and testing?

I hope I have explained everything. Please let me know if you need more clarifications.

Thanks


