import os
from typing import Iterable
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import pil_loader
[docs]
class MultiLabeledImagesDataset(Dataset):
"""
A PyTorch dataset class for multi-labeled image data.
Attributes
----------
img_labels : pd.DataFrame
A pandas DataFrame containing the image annotations.
img_dir : str
The directory path containing the images.
transform : torchvision.transforms
A transform to apply to the image data.
Methods
-------
__len__()
Returns the total number of samples in the dataset.
__getitem__(idx)
Returns the image and corresponding labels at the given index.
labels()
Returns a list of the target labels.
"""
def __init__(self, annotations_file, img_dir, name_column, target_columns, transform=None):
"""
Initialize the MultiLabeledImagesDataset.
Parameters
----------
annotations_file : str
The file path to the annotations file in CSV format.
img_dir : str
The directory path containing the images.
name_column : str
The name of the column in the annotations file that contains the image names.
target_columns : str or list[str]
The column name(s) of the target labels in the annotations file.
transform : torchvision.transforms
A transform to apply to the image data. Default is None.
"""
self.img_labels = pd.read_csv(annotations_file, dtype={name_column: str})
if isinstance(target_columns, Iterable):
selected_columns = [name_column] + list(target_columns)
else:
selected_columns = [name_column, target_columns]
self.img_labels = self.img_labels[selected_columns]
self.img_dir = img_dir
self.transform = transform
[docs]
def __len__(self):
"""
Return the total number of samples in the dataset.
Returns
-------
int
The total number of samples.
"""
return len(self.img_labels)
[docs]
def __getitem__(self, idx):
"""
Get the image and corresponding labels at the given index.
Parameters
----------
idx : int
The index of the sample to retrieve.
Returns
-------
tuple
A tuple containing the image and corresponding labels.
"""
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = pil_loader(img_path)
labels = torch.from_numpy(self.img_labels.iloc[idx, 1:].to_numpy(dtype=np.int8))
if self.transform:
image = self.transform(image)
return image, labels
[docs]
def labels(self):
"""
Return a list of the target labels.
Returns
-------
list
A list of target labels.
"""
return list(self.img_labels.columns[1:])
[docs]
class SemiSupervisedImagesDataset(MultiLabeledImagesDataset):
"""
A PyTorch dataset class for semi-supervised multi-labeled image data, inheriting from MultiLabeledImagesDataset.
Attributes
----------
img_labels : pd.DataFrame
A pandas DataFrame containing the image annotations.
img_dir : str
The directory path containing the images.
transform : torchvision.transforms
A transform to apply to the image data.
unlabeled_idx : numpy.ndarray
An array containing the indices of unlabeled samples.
Methods
-------
__init__(annotations_file, img_dir, name_column, target_columns, unlabeled_samples, transform=None)
Initialize the SemiSupervisedImagesDataset.
__getitem__(idx)
Get the image, corresponding labels, and unlabeled flag at the given index.
separate_unlabeled(x_raw, y_raw, is_unlabeled)
Separate the labeled and unlabeled samples from the given data.
"""
[docs]
def __init__(self, annotations_file, img_dir, name_column, target_columns, unlabeled_samples, transform=None):
"""
Initialize the SemiSupervisedImagesDataset.
Parameters
----------
annotations_file : str
The file path to the annotations file in CSV format.
img_dir : str
The directory path containing the images.
name_column : str
The name of the column in the annotations file that contains the image names.
target_columns : str or list[str]
The column name(s) of the target labels in the annotations file.
unlabeled_samples : int or float
The number of unlabeled samples to include. If float, it represents the fraction of unlabeled samples.
transform : torchvision.transforms
A transform to apply to the image data. Default is None.
Raises
------
ValueError
If the value of the parameter 'unlabeled_samples' is invalid.
"""
super().__init__(annotations_file, img_dir, name_column, target_columns, transform=transform)
if isinstance(unlabeled_samples, int):
self.unlabeled_idx = np.random.permutation(np.arange(0, len(self.img_labels)))[:unlabeled_samples]
elif isinstance(unlabeled_samples, float) and unlabeled_samples <= 1.0:
self.unlabeled_idx = np.random.permutation(np.arange(0, len(self.img_labels)))[:int(len(self.img_labels) *
unlabeled_samples)]
else:
raise ValueError("Invalid value of the parameter: unlabeled samples.")
self.img_labels.loc[:, 'Unlabeled'] = 0
self.img_labels.loc[self.unlabeled_idx, 'Unlabeled'] = 1
[docs]
def __getitem__(self, idx):
"""
Get the image, corresponding labels, and unlabeled flag at the given index.
Parameters
----------
idx : int
The index of the sample to retrieve.
Returns
-------
tuple
A tuple containing the image, corresponding labels, and unlabeled flag.
"""
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = pil_loader(img_path)
labels = torch.from_numpy(self.img_labels.iloc[idx, 1:-1].to_numpy(dtype=np.int8))
is_unlabeled = torch.from_numpy(self.img_labels.iloc[idx, -1:].to_numpy(dtype=np.int8))
if self.transform:
image = self.transform(image)
return image, labels, is_unlabeled
[docs]
@staticmethod
def separate_unlabeled(x_raw, y_raw, is_unlabeled):
"""
Separate the labeled and unlabeled samples from the given data.
Parameters
----------
x_raw : torch.Tensor
The input data.
y_raw : torch.Tensor
The target labels.
is_unlabeled : torch.Tensor
The unlabeled flags indicating whether a sample is labeled (0) or unlabeled (1).
Returns
-------
tuple
A tuple containing the labeled data, labeled target labels, unlabeled data, and unlabeled target labels.
"""
unlabeled_idx = torch.where(is_unlabeled == 1)
labeled_idx = torch.where(is_unlabeled == 0)
x, y = x_raw[labeled_idx[0]], y_raw[labeled_idx[0]]
x_unlab, y_unlab = x_raw[unlabeled_idx[0]], y_raw[unlabeled_idx[0]]
return x, y, x_unlab, y_unlab
[docs]
def create_dataloader(path_to_csv, path_to_images, image_names_column, target_columns,
batch_size, num_workers, transformation, unlabeled_samples=None):
"""
Create a PyTorch DataLoader for loading the multi-labeled image dataset.
Parameters
----------
path_to_csv : str
The file path to the annotations file in CSV format.
path_to_images : str
The directory path containing the images.
image_names_column : str
The name of the column in the annotations file that contains the image names.
target_columns : str or list[str]
The column name(s) of the target labels in the annotations file.
batch_size : int
The batch size for the DataLoader.
num_workers : int
The number of worker processes to use for data loading.
transformation : torchvision.transforms
A transform to apply to the image data.
unlabeled_samples : int or float, optional
The number of unlabeled samples to include. If float, it represents the fraction of unlabeled samples.
Default is None.
Returns
-------
torch.utils.data.DataLoader
A PyTorch DataLoader for the multi-labeled image dataset.
Raises
------
ValueError
If the value of the parameter 'unlabeled_samples' is invalid.
"""
if unlabeled_samples is None:
data = MultiLabeledImagesDataset(path_to_csv, path_to_images, image_names_column, target_columns,
transformation)
else:
data = SemiSupervisedImagesDataset(path_to_csv, path_to_images, image_names_column, target_columns,
unlabeled_samples, transformation)
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
return dataloader