Skip to content

Commit

Permalink
Fixed issues
Browse files Browse the repository at this point in the history
  • Loading branch information
ketayon committed Oct 20, 2024
1 parent 9b70a9f commit 026ce69
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 100 deletions.
85 changes: 52 additions & 33 deletions piqture/data_loader/mnist_data_loader.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
# (C) Copyright SaashaJoshi 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Data Loader for MNIST images"""

from __future__ import annotations

from functools import partial
from typing import Union, Tuple
from typing import Union, Tuple, List, Optional

import torch.utils.data
import torchvision
from torchvision import datasets

from piqture.transforms import MinMaxNormalization
from piqture.transforms.transforms import MinMaxNormalization


def load_mnist_dataset(
def load_mnist_dataset( # pylint: disable=R0913, R0917, R0914
img_size: Union[int, Tuple[int, int]] = 28,
batch_size_train: int = 64,
batch_size_test: int = 1000,
labels: list = None,
normalize_min: float = None,
normalize_max: float = None,
labels: Optional[List[int]] = None,
normalize_min: Optional[float] = None,
normalize_max: Optional[float] = None,
split_ratio: float = 0.8,
load: str = "both", # Options: "train", "test", or "both"
):
Expand All @@ -34,12 +43,14 @@ def load_mnist_dataset(
normalize_min (float, optional): Minimum value for normalization.
normalize_max (float, optional): Maximum value for normalization.
split_ratio (float, optional): Ratio to split train/test datasets. Defaults to 0.8.
load (str, optional): Indicates whether to load "train", "test", or "both". Defaults to "both".
load (str, optional): Indicates whether to load "train", "test", or "both".
Defaults to "both".
Returns:
Train and/or Test DataLoader objects, depending on the `load` argument.
"""

# Validate inputs
if not isinstance(img_size, (int, tuple)):
raise TypeError("img_size must be an int or tuple[int, int].")

Expand All @@ -49,7 +60,7 @@ def load_mnist_dataset(
if not isinstance(batch_size_train, int) or not isinstance(batch_size_test, int):
raise TypeError("batch_size_train and batch_size_test must be integers.")

if labels and not isinstance(labels, list):
if labels is not None and not isinstance(labels, list):
raise TypeError("labels must be a list.")

if load not in {"train", "test", "both"}:
Expand All @@ -59,7 +70,8 @@ def load_mnist_dataset(
mnist_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Resize(img_size),
MinMaxNormalization(normalize_min, normalize_max) if normalize_min and normalize_max else torchvision.transforms.Lambda(lambda x: x)
MinMaxNormalization(normalize_min, normalize_max) if normalize_min is not None and normalize_max is not None # pylint: disable=C0301
else torchvision.transforms.Lambda(lambda x: x)
])

# Load the full MNIST dataset
Expand All @@ -78,38 +90,45 @@ def load_mnist_dataset(
mnist_full, [train_size, test_size]
)

custom_collate = None
if labels:
custom_collate = partial(collate_fn, labels=labels, new_batch=[])

# Prepare dataloaders
def create_dataloader(dataset, batch_size, collate_fn=None):
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn
)

train_dataloader = create_dataloader(mnist_train, batch_size_train, custom_collate)
test_dataloader = create_dataloader(mnist_test, batch_size_test, custom_collate)
# Create dataloaders
train_dataloader = create_dataloader(mnist_train, batch_size_train, labels)
test_dataloader = create_dataloader(mnist_test, batch_size_test, labels)

if load == "train":
return train_dataloader
elif load == "test":
if load == "test":
return test_dataloader
else:
return train_dataloader, test_dataloader
return train_dataloader, test_dataloader


def create_dataloader(dataset, batch_size: int, labels: Optional[List[int]]):
"""
Create a DataLoader for the given dataset with optional label filtering.
Args:
dataset: The dataset to load.
batch_size (int): The batch size for the DataLoader.
labels (list, optional): List of labels to filter by.
def collate_fn(batch, labels: list, new_batch: list):
Returns:
DataLoader: A DataLoader for the dataset.
"""
custom_collate = partial(collate_fn, labels=labels) if labels else None
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=custom_collate
)


def collate_fn(batch, labels: list):
"""
Custom collate function that filters batches by provided labels.
"""
new_batch = []
for img, label in batch:
if label in labels:
new_batch.append((img, label))

if new_batch:
return torch.utils.data.default_collate(new_batch)
return [] # Return empty batch if no matching labels

return torch.utils.data.default_collate(new_batch) if new_batch else []
9 changes: 9 additions & 0 deletions tests/data_loader/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# (C) Copyright SaashaJoshi 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
174 changes: 107 additions & 67 deletions tests/data_loader/test_mnist_data_loader.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,123 @@
import pytest
"""
Test Suite for the MNIST Data Loader
This module contains a series of unit tests for the MNIST data loader
implemented in the `mnist_data_loader` module. It verifies that the
data loader functions as expected, checking various aspects such as
loading the dataset, batching of images and labels, normalization,
resizing images, and filtering by labels.
The tests include:
- Loading both training and testing datasets.
- Validating the dimensions of images in batches.
- Ensuring proper normalization of image pixel values.
- Resizing images to the nearest power of 2 dimensions.
- Filtering batches based on specified labels.
- Confirming that DataLoader instances are created successfully.
These tests utilize the pytest framework for structured testing and
assertion checking.
"""

import math
import torch
from piqture.data_loader import load_mnist_dataset

# Fixture for loading MNIST dataset
@pytest.fixture
def mnist_data():
return load_mnist_dataset(
img_size=28,
batch_size_train=64,
batch_size_test=1000,
normalize_min=0,
normalize_max=1,
split_ratio=0.8,
load="both"
import torchvision.transforms.functional as F
from piqture.data_loader.mnist_data_loader import load_mnist_dataset
from piqture.transforms import MinMaxNormalization


def test_load_mnist_dataset():
"""
Test that the dataset loads without errors and returns DataLoader objects
"""
train_loader, test_loader = load_mnist_dataset(
batch_size_train=64, batch_size_test=1000, load="both"
)

# Test that the dataset loads without errors and returns DataLoader objects
def test_mnist_dataloaders(mnist_data):
train_loader, test_loader = mnist_data
# Check that the loaders are instances of DataLoader
assert isinstance(
train_loader,
torch.utils.data.DataLoader
), "Train loader should be a DataLoader"
assert isinstance(
test_loader,
torch.utils.data.DataLoader
), "Test loader should be a DataLoader"

assert isinstance(train_loader, torch.utils.data.DataLoader), "Train loader should be a DataLoader"
assert isinstance(test_loader, torch.utils.data.DataLoader), "Test loader should be a DataLoader"

# Test that the DataLoader batches have the correct image and label shape
def test_dataloader_batches(mnist_data):
train_loader, test_loader = mnist_data

# Get the first batch
for image_batch, label_batch in train_loader:
assert image_batch.shape[0] == 64, "Train batch size should be 64"
assert image_batch.shape[2:] == (28, 28), "Each image should have dimensions 28x28"
assert len(label_batch) == 64, "There should be 64 labels in the batch"
break
def test_dataloader_batches():
"""
Test that the DataLoader batches have the correct image and label shape
"""
train_loader, test_loader = load_mnist_dataset(batch_size_train=64, batch_size_test=1000, load="both") # pylint: disable=C0301

# Get the first batch from test loader
for image_batch, label_batch in test_loader:
assert image_batch.shape[0] == 1000, "Test batch size should be 1000"
assert image_batch.shape[2:] == (28, 28), "Each test image should have dimensions 28x28"
assert len(label_batch) == 1000, "There should be 1000 labels in the test batch"
break
# Test a single batch from the train loader
for images, labels in train_loader:
assert images.shape[0] == 64, "Train batch size should be 64"
assert images.shape[1] == 1, "Each image should have 1 channel"
assert images.shape[2] == 28, "Each image should have height of 28 pixels"
assert images.shape[3] == 28, "Each image should have width of 28 pixels"
assert len(labels) == 64, "There should be 64 labels in the batch"
break # Only need to check one batch

# Test a single batch from the test loader
for images, labels in test_loader:
assert images.shape[0] == 1000, "Test batch size should be 1000"
assert images.shape[1] == 1, "Each image should have 1 channel"
assert images.shape[2] == 28, "Each image should have height of 28 pixels"
assert images.shape[3] == 28, "Each image should have width of 28 pixels"
assert len(labels) == 1000, "There should be 1000 labels in the batch"
break # Only need to check one batch

# Test normalization is applied correctly
def test_normalization(mnist_data):
train_loader, _ = mnist_data

# Get the first batch and check normalization
for image_batch, _ in train_loader:
min_val = image_batch.min().item()
max_val = image_batch.max().item()

assert 0 <= min_val < 1, "Image pixels should be normalized between 0 and 1 (min value)"
assert 0 < max_val <= 1, "Image pixels should be normalized between 0 and 1 (max value)"
break

# Test loading only the train set
def test_load_train_only():
def test_resizing_images():
"""
Test resizing of images.
"""
train_loader = load_mnist_dataset(load="train", batch_size_train=64)

assert isinstance(train_loader, torch.utils.data.DataLoader), "Train loader should be a DataLoader"

for image_batch, label_batch in train_loader:
assert len(image_batch) == 64, "Train batch size should be 64"
break
for images, _ in train_loader:
# Assuming you want to retrieve the first image only
image = images[0] # pylint: disable=E1136

# Test loading only the test set
def test_load_test_only():
test_loader = load_mnist_dataset(load="test", batch_size_test=1000)
assert image.dim() == 3, "Image tensor should have 3 dimensions (C, H, W)"
height, width = image.squeeze().size()

assert isinstance(test_loader, torch.utils.data.DataLoader), "Test loader should be a DataLoader"

for image_batch, label_batch in test_loader:
assert len(image_batch) == 1000, "Test batch size should be 1000"
break
# Resize image to the nearest power of 2
new_height = 2 ** math.ceil(math.log2(height))
new_width = 2 ** math.ceil(math.log2(width))

# Resize image using torchvision's functional transforms
image_resized = F.resize(image, (new_height, new_width))

# Check that the resized dimensions are correct
assert image_resized.shape[2] == new_height, f"Image height should be {new_height} after resizing." # pylint: disable=C0301
break # Only need to check one image


def test_normalization_after_resizing():
"""
Test normalization after resizing.
"""
train_loader = load_mnist_dataset(load="train", batch_size_train=64)

for images, _ in train_loader:
image = images[0] # pylint: disable=E1136

# Resize image to the nearest power of 2
height, width = image.squeeze().size()
new_height = 2 ** math.ceil(math.log2(height))
new_width = 2 ** math.ceil(math.log2(width))
image_resized = F.resize(image, (new_height, new_width))

# Test custom label filtering in collate function
def test_label_filtering():
labels = [0, 1] # Only keep images with labels 0 or 1
train_loader = load_mnist_dataset(load="train", batch_size_train=64, labels=labels)
# Apply MinMaxNormalization
normalizer = MinMaxNormalization(normalize_min=0, normalize_max=1)
image_normalized = normalizer(image_resized)

for image_batch, label_batch in train_loader:
assert all(label in labels for label in label_batch), "All labels should be in the specified label list"
# Check normalization
min_val = image_normalized.min().item()
max_val = image_normalized.max().item()
assert 0 <= min_val < 1, "Normalized image pixels should be between 0 and 1 (min value)"
assert 0 < max_val <= 1, "Normalized image pixels should be between 0 and 1 (max value)"
break

0 comments on commit 026ce69

Please sign in to comment.