basic_dataset

Module Contents

  • BaseDataset – Base dataset over paired inputs x and targets y.

  • Subset – Data subset that can apply a different augmentation for each client.

  • CIFARSubset – Data subset that can apply a different augmentation for each client.

  • FedDataset

class BaseDataset(x, y)

Bases: torch.utils.data.Dataset

Base dataset over paired inputs x and targets y.

__len__()
__getitem__(index)
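A minimal sketch of a dataset with the same shape as BaseDataset, assuming x and y are aligned sequences of equal length and that __getitem__ returns the pair (x[index], y[index]). The name PairDataset is hypothetical. It is written as a plain class so it runs without PyTorch installed; PyTorch treats any object implementing __getitem__ and __len__ as a map-style dataset, whereas the module's class subclasses torch.utils.data.Dataset directly.

```python
from typing import Any, Sequence, Tuple


class PairDataset:
    """Hypothetical map-style dataset over paired inputs x and targets y."""

    def __init__(self, x: Sequence[Any], y: Sequence[Any]) -> None:
        # Assumption: y[i] is the target of x[i], so both must have equal length.
        if len(x) != len(y):
            raise ValueError("x and y must have the same length")
        self.x = x
        self.y = y

    def __len__(self) -> int:
        # Number of samples in the dataset.
        return len(self.y)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # Return the (input, target) pair stored at position `index`.
        return self.x[index], self.y[index]


ds = PairDataset([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [0, 1, 0])
print(len(ds))  # 3
print(ds[1])    # ([0.3, 0.4], 1)
```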
class Subset(dataset, indices, transform=None, target_transform=None)

Bases: torch.utils.data.Dataset

Data subset that can apply a different augmentation for each client.

Parameters:
  • dataset (Dataset) – The whole dataset.

  • indices (List[int]) – Indices of the samples in dataset that make up the subset.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version.

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

__getitem__(index)

Get the sample at the given index.

Parameters:

index (int) – Position of the sample within the subset.

Returns:

(image, target), where target is the index of the target class.

__len__()
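The subset behaviour described above can be sketched as follows, under stated assumptions: the parent dataset returns (sample, target) pairs when indexed, __getitem__ maps a position inside the subset to the parent index indices[index], and it then applies transform to the sample and target_transform to the target. IndexedSubset and the toy parent list are hypothetical, and the real Subset may store or look up samples differently (for example by copying them up front).

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


class IndexedSubset:
    """Hypothetical subset view that applies its own transforms."""

    def __init__(
        self,
        dataset: Any,
        indices: Sequence[int],
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self.dataset = dataset
        self.indices: List[int] = list(indices)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        # The subset is as long as its index list.
        return len(self.indices)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # `index` is a position inside the subset; map it to the parent dataset.
        sample, target = self.dataset[self.indices[index]]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target


# Toy parent dataset: samples are integers, targets are their parity.
parent = [(i, i % 2) for i in range(10)]

# Two clients share the parent data but use different augmentations.
client_a = IndexedSubset(parent, [0, 1, 2], transform=lambda s: s * 10)
client_b = IndexedSubset(parent, [7, 8, 9], transform=lambda s: -s)

print(client_a[1])  # (10, 1)
print(client_b[0])  # (-7, 1)
```

Because each client gets its own view with its own transform, the parent data is shared while augmentation differs per client.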
class CIFARSubset(dataset, indices, transform=None, target_transform=None, to_image=True)

Bases: Subset

Data subset that can apply a different augmentation for each client.

Parameters:
  • dataset (Dataset) – The whole dataset.

  • indices (List[int]) – Indices of the samples in dataset that make up the subset.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version.

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

class FedDataset

Bases: object

preprocess()

Define the dataset partitioning process.
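The partitioning process is left to subclasses. One common choice in federated learning is an IID split, where sample indices are shuffled and dealt out to clients in near-equal shards. The helper below is a hypothetical sketch of that strategy (the name iid_partition and its parameters are invented for illustration), not the module's own partitioning code; non-IID schemes such as label-skewed splits would replace it.

```python
import random
from typing import Dict, List


def iid_partition(num_samples: int, num_clients: int, seed: int = 0) -> Dict[int, List[int]]:
    """Hypothetical helper: shuffle indices and split them into near-equal shards."""
    if num_clients <= 0:
        raise ValueError("num_clients must be positive")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    # Client k takes every num_clients-th shuffled index starting at k,
    # so shard sizes differ by at most one.
    return {k: indices[k::num_clients] for k in range(num_clients)}


parts = iid_partition(num_samples=10, num_clients=3, seed=42)
print({k: len(v) for k, v in parts.items()})  # {0: 4, 1: 3, 2: 3}
```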

abstract get_dataset(id, type='train')

Get the dataset object for the given client.

Parameters:
  • id (int) – Client ID of the partial dataset to retrieve.

  • type (str, optional) – Type of dataset, chosen from ["train", "val", "test"]. Defaults to "train".

Raises:

NotImplementedError

abstract get_dataloader(id, batch_size, type='train')

Get the data loader for the given client.

__len__()
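A hypothetical concrete subclass may make the FedDataset contract easier to follow: preprocess() builds the per-client partition once, get_dataset(id, type) returns one client's split, and get_dataloader(id, batch_size, type) wraps that split for batched iteration. Everything here is an assumption for illustration: FedDatasetSketch re-declares the interface with abc so the sketch stands alone, ToyFedDataset and its in-memory round-robin split are invented, and a small generator stands in for torch.utils.data.DataLoader so the example runs without PyTorch. The sketch does not model __len__, whose return value is not documented here.

```python
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Tuple

Sample = Tuple[float, int]  # hypothetical (feature, label) pair


class FedDatasetSketch(ABC):
    """Hypothetical re-declaration of the FedDataset interface."""

    def preprocess(self) -> None:
        """Define the dataset partitioning process (no-op by default)."""

    # Parameter names `id` and `type` mirror the documented signatures.
    @abstractmethod
    def get_dataset(self, id: int, type: str = "train") -> List[Sample]:
        raise NotImplementedError()

    @abstractmethod
    def get_dataloader(self, id: int, batch_size: int, type: str = "train") -> Iterator[List[Sample]]:
        raise NotImplementedError()


class ToyFedDataset(FedDatasetSketch):
    """Hypothetical subclass holding every split in memory."""

    def __init__(self, splits: Dict[str, List[Sample]], num_clients: int) -> None:
        self.splits = splits  # e.g. {"train": [...], "test": [...]}
        self.num_clients = num_clients
        self.partitions: Dict[str, List[List[Sample]]] = {}
        self.preprocess()

    def preprocess(self) -> None:
        # Round-robin split: client k gets samples k, k + num_clients, ... of each split.
        for name, samples in self.splits.items():
            self.partitions[name] = [samples[k::self.num_clients] for k in range(self.num_clients)]

    def get_dataset(self, id: int, type: str = "train") -> List[Sample]:
        if type not in self.partitions:
            raise ValueError(f"unknown dataset type {type!r}")
        return self.partitions[type][id]

    def get_dataloader(self, id: int, batch_size: int, type: str = "train") -> Iterator[List[Sample]]:
        # Stand-in for torch.utils.data.DataLoader: yield in-order batches of batch_size.
        data = self.get_dataset(id, type)
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]


train = [(float(i), i % 2) for i in range(8)]
fed = ToyFedDataset({"train": train, "test": train[:4]}, num_clients=2)
print(fed.get_dataset(1))  # [(1.0, 1), (3.0, 1), (5.0, 1), (7.0, 1)]
print([len(b) for b in fed.get_dataloader(0, batch_size=3)])  # [3, 1]
```

Declaring get_dataset and get_dataloader with abstractmethod means a subclass that forgets to override them cannot be instantiated, while the NotImplementedError bodies match the documented behaviour if they are ever called directly.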