-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
168 additions
and
100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# (C) Copyright SaashaJoshi 2024. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,83 +1,123 @@ | ||
import pytest | ||
""" | ||
Test Suite for the MNIST Data Loader | ||
This module contains a series of unit tests for the MNIST data loader | ||
implemented in the `mnist_data_loader` module. It verifies that the | ||
data loader functions as expected, checking various aspects such as | ||
loading the dataset, batching of images and labels, normalization, | ||
resizing images, and filtering by labels. | ||
The tests include: | ||
- Loading both training and testing datasets. | ||
- Validating the dimensions of images in batches. | ||
- Ensuring proper normalization of image pixel values. | ||
- Resizing images up to the next power-of-2 dimensions (rounding up). | ||
- Filtering batches based on specified labels. | ||
- Confirming that DataLoader instances are created successfully. | ||
These tests utilize the pytest framework for structured testing and | ||
assertion checking. | ||
""" | ||
|
||
import math | ||
import torch | ||
from piqture.data_loader import load_mnist_dataset | ||
|
||
# Fixture for loading MNIST dataset
@pytest.fixture
def mnist_data():
    """Return the result of ``load_mnist_dataset`` called with ``load="both"``.

    Tests that request this fixture unpack the value as
    ``(train_loader, test_loader)``.
    """
    return load_mnist_dataset(
        img_size=28,
        batch_size_train=64,
        batch_size_test=1000,
        normalize_min=0,
        normalize_max=1,
        split_ratio=0.8,
        load="both",
    )
import torchvision.transforms.functional as F | ||
from piqture.data_loader.mnist_data_loader import load_mnist_dataset | ||
from piqture.transforms import MinMaxNormalization | ||
|
||
|
||
def test_load_mnist_dataset():
    """
    Test that the dataset loads without errors and returns DataLoader objects
    """
    train_loader, test_loader = load_mnist_dataset(
        batch_size_train=64, batch_size_test=1000, load="both"
    )

    # Check that the loaders are instances of DataLoader
    assert isinstance(
        train_loader, torch.utils.data.DataLoader
    ), "Train loader should be a DataLoader"
    assert isinstance(
        test_loader, torch.utils.data.DataLoader
    ), "Test loader should be a DataLoader"


def test_mnist_dataloaders(mnist_data):
    """
    Test that the ``mnist_data`` fixture yields DataLoader objects
    """
    train_loader, test_loader = mnist_data

    # Check that the loaders are instances of DataLoader
    assert isinstance(
        train_loader, torch.utils.data.DataLoader
    ), "Train loader should be a DataLoader"
    assert isinstance(
        test_loader, torch.utils.data.DataLoader
    ), "Test loader should be a DataLoader"
|
||
def _assert_first_batch(loader, expected_batch_size, split_name):
    """Check the size and per-image dimensions of the first batch of *loader*.

    The assertions treat the image batch axes as
    (batch, channel, height, width) and expect 1 x 28 x 28 images.
    """
    # next(iter(...)) rather than for/break: an empty loader raises here
    # instead of letting the test pass without checking anything.
    images, labels = next(iter(loader))

    assert (
        images.shape[0] == expected_batch_size
    ), f"{split_name} batch size should be {expected_batch_size}"
    assert images.shape[1] == 1, "Each image should have 1 channel"
    assert images.shape[2] == 28, "Each image should have height of 28 pixels"
    assert images.shape[3] == 28, "Each image should have width of 28 pixels"
    assert (
        len(labels) == expected_batch_size
    ), f"There should be {expected_batch_size} labels in the batch"


def test_dataloader_batches():
    """
    Test that the DataLoader batches have the correct image and label shape
    """
    train_loader, test_loader = load_mnist_dataset(
        batch_size_train=64, batch_size_test=1000, load="both"
    )

    # Only the first batch of each loader needs to be checked.
    _assert_first_batch(train_loader, 64, "Train")
    _assert_first_batch(test_loader, 1000, "Test")
|
||
def test_normalization(mnist_data):
    """
    Test normalization is applied correctly
    """
    train_loader, _ = mnist_data

    # Fetch the first batch explicitly: with a for/break loop an empty loader
    # would skip the assertions and the test would pass vacuously.
    image_batch, _ = next(iter(train_loader))

    min_val = image_batch.min().item()
    max_val = image_batch.max().item()

    assert 0 <= min_val < 1, "Image pixels should be normalized between 0 and 1 (min value)"
    assert 0 < max_val <= 1, "Image pixels should be normalized between 0 and 1 (max value)"
|
||
def test_load_train_only():
    """
    Test loading only the train set
    """
    train_loader = load_mnist_dataset(load="train", batch_size_train=64)

    assert isinstance(
        train_loader, torch.utils.data.DataLoader
    ), "Train loader should be a DataLoader"

    # next(iter(...)) so that an empty loader fails instead of skipping the check.
    image_batch, _ = next(iter(train_loader))
    assert len(image_batch) == 64, "Train batch size should be 64"


def test_load_test_only():
    """
    Test loading only the test set
    """
    test_loader = load_mnist_dataset(load="test", batch_size_test=1000)

    assert isinstance(
        test_loader, torch.utils.data.DataLoader
    ), "Test loader should be a DataLoader"

    image_batch, _ = next(iter(test_loader))
    assert len(image_batch) == 1000, "Test batch size should be 1000"


def test_resizing_images():
    """
    Test resizing of images.
    """
    train_loader = load_mnist_dataset(load="train", batch_size_train=64)

    # Only the first image of the first batch is needed.
    images, _ = next(iter(train_loader))
    image = images[0]  # pylint: disable=E1136

    assert image.dim() == 3, "Image tensor should have 3 dimensions (C, H, W)"
    # Read H and W from the trailing axes so the channel axis is never
    # mistaken for a spatial one.
    height, width = image.shape[-2:]

    # Round each side up to the next power of 2 (e.g. 28 -> 32).
    new_height = 2 ** math.ceil(math.log2(height))
    new_width = 2 ** math.ceil(math.log2(width))

    # Resize image using torchvision's functional transforms
    image_resized = F.resize(image, (new_height, new_width))

    # For a (C, H, W) tensor, height is axis 1 and width is axis 2.
    assert (
        image_resized.shape[1] == new_height
    ), f"Image height should be {new_height} after resizing."
    assert (
        image_resized.shape[2] == new_width
    ), f"Image width should be {new_width} after resizing."
|
||
|
||
def test_normalization_after_resizing():
    """
    Test normalization after resizing.
    """
    train_loader = load_mnist_dataset(load="train", batch_size_train=64)

    # next(iter(...)) so that an empty loader fails instead of skipping the checks.
    images, _ = next(iter(train_loader))
    image = images[0]  # pylint: disable=E1136

    # Round each side up to the next power of 2 before resizing.
    height, width = image.squeeze().size()
    new_height = 2 ** math.ceil(math.log2(height))
    new_width = 2 ** math.ceil(math.log2(width))
    image_resized = F.resize(image, (new_height, new_width))

    # Apply MinMaxNormalization
    normalizer = MinMaxNormalization(normalize_min=0, normalize_max=1)
    image_normalized = normalizer(image_resized)

    # Check normalization
    min_val = image_normalized.min().item()
    max_val = image_normalized.max().item()
    assert 0 <= min_val < 1, "Normalized image pixels should be between 0 and 1 (min value)"
    assert 0 < max_val <= 1, "Normalized image pixels should be between 0 and 1 (max value)"


def test_label_filtering():
    """
    Test custom label filtering in collate function
    """
    labels = [0, 1]  # Only keep images with labels 0 or 1
    train_loader = load_mnist_dataset(load="train", batch_size_train=64, labels=labels)

    # Inspect the first batch only; an empty loader raises rather than passing.
    _, label_batch = next(iter(train_loader))
    assert all(
        label in labels for label in label_batch
    ), "All labels should be in the specified label list"