fedjax.datasets

fedjax datasets.

  • fedjax.datasets.cifar100 – Federated cifar100.

  • fedjax.datasets.emnist – Federated EMNIST.

  • fedjax.datasets.shakespeare – Federated Shakespeare.

  • fedjax.datasets.stackoverflow – Federated stackoverflow.

CIFAR-100

Federated cifar100.

fedjax.datasets.cifar100.cite()

Returns BibTeX citation for the dataset.

fedjax.datasets.cifar100.load_data(mode='sqlite', cache_dir=None)

Loads partially preprocessed cifar100 splits.

Features:

  • x: [N, 32, 32, 3] uint8 pixels.

  • y: [N] int32 labels in the range [0, 100).

Additional preprocessing (e.g. centering and normalizing) depends on whether a split is used for training or eval. For example:

import functools
from fedjax.datasets import cifar100
# Load partially preprocessed splits.
train, test = cifar100.load_data()
# Preprocessing for training.
train_for_train = train.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=True))
# Preprocessing for eval.
train_for_eval = train.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=False))
test = test.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=False))

Features after this preprocessing:

  • x: [N, 32, 32, 3] float32 preprocessed pixels.

  • y: [N] int32 labels in the range [0, 100).

Alternatively, you can apply the same preprocessing as TensorFlow Federated following tff.simulation.baselines.cifar100.create_image_classification_task. For example:

from fedjax.datasets import cifar100
train, test = cifar100.load_data()
train = train.preprocess_batch(cifar100.preprocess_batch_tff)
test = test.preprocess_batch(cifar100.preprocess_batch_tff)

Features after this preprocessing:

  • x: [N, 24, 24, 3] float32 preprocessed pixels.

  • y: [N] int32 labels in the range [0, 100).

Note: preprocess_batch and preprocess_batch_tff are just convenience wrappers around preprocess_image() and preprocess_image_tff(), respectively, for use with fedjax.FederatedData.preprocess_batch().

Parameters:
  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData]

Returns:

A (train, test) tuple of federated data.

fedjax.datasets.cifar100.load_split(split, mode='sqlite', cache_dir=None)

Loads a cifar100 split.

Features:

  • image: [N, 32, 32, 3] uint8 pixels.

  • coarse_label: [N] int64 coarse labels in the range [0, 20).

  • label: [N] int64 labels in the range [0, 100).

Parameters:
  • split (str) – Name of the split. One of SPLITS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

FederatedData

Returns:

FederatedData.
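Comparing the feature lists of load_split() and load_data(), load_data() appears to rename image to x and label to y, cast the labels from int64 to int32, and drop coarse_label. The sketch below illustrates that inferred mapping on a fake batch; cifar_split_to_data is a hypothetical helper, not part of fedjax, and the mapping is an inference from the documented features rather than the library's code.

```python
import numpy as np


def cifar_split_to_data(examples):
    """Hypothetical helper: map load_split() features to load_data() features.

    The renaming (image -> x, label -> y) and the int64 -> int32 cast are
    inferred from the two documented feature lists; coarse_label is dropped.
    """
    return {
        "x": examples["image"],                   # [N, 32, 32, 3] uint8, unchanged
        "y": examples["label"].astype(np.int32),  # [N] int64 -> int32, still in [0, 100)
    }


# Fake batch of two examples with the load_split() layout.
raw = {
    "image": np.zeros((2, 32, 32, 3), dtype=np.uint8),
    "coarse_label": np.array([3, 19], dtype=np.int64),
    "label": np.array([7, 99], dtype=np.int64),
}
features = cifar_split_to_data(raw)
print(features["x"].shape, features["y"].dtype)  # (2, 32, 32, 3) int32
```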

fedjax.datasets.cifar100.preprocess_image(image, is_train)

Augments and preprocesses CIFAR-100 images by cropping, flipping, and normalizing.

Preprocessing procedure and values taken from pytorch-cifar.

Parameters:
  • image (ndarray) – [N, 32, 32, 3] uint8 pixels.

  • is_train (bool) – Whether we are preprocessing for training or eval.

Return type:

ndarray

Returns:

Processed [N, 32, 32, 3] float32 pixels.
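The crop/flip/normalize procedure can be sketched in NumPy as follows. This is an illustration under stated assumptions, not fedjax's implementation: it assumes the pytorch-cifar style of zero-padding by 4 pixels before a random 32x32 crop and flipping horizontally with probability 0.5 during training, and it takes the per-channel mean and std as caller-supplied placeholders because the actual constants are not listed here. preprocess_image_sketch is a hypothetical name.

```python
import numpy as np


def preprocess_image_sketch(image, is_train, mean, std, rng=None, pad=4):
    """Sketch of pytorch-cifar style augmentation and normalization.

    Assumptions: `pad`-pixel zero padding plus a random crop back to the
    original size and a 50% horizontal flip when is_train is True; mean/std
    are per-channel constants in [0, 1] units supplied by the caller.
    """
    x = image.astype(np.float32) / 255.0  # uint8 [0, 255] -> float32 [0, 1]
    if is_train:
        rng = np.random.default_rng() if rng is None else rng
        n, h, w, _ = x.shape
        padded = np.pad(x, ((0, 0), (pad, pad), (pad, pad), (0, 0)))  # zeros
        out = np.empty_like(x)
        for i in range(n):
            top = rng.integers(0, 2 * pad + 1)
            left = rng.integers(0, 2 * pad + 1)
            crop = padded[i, top:top + h, left:left + w]
            if rng.random() < 0.5:
                crop = crop[:, ::-1]  # flip along the width axis
            out[i] = crop
        x = out
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    return (x - mean) / std  # broadcasts over the trailing channel axis


rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
mean, std = (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)  # placeholder constants
train_x = preprocess_image_sketch(images, True, mean, std, rng=rng)
eval_x = preprocess_image_sketch(images, False, mean, std)
print(train_x.shape, train_x.dtype)  # (4, 32, 32, 3) float32
```

Eval mode skips the random crop and flip, so it is deterministic and only rescales and normalizes.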

EMNIST

Federated EMNIST.

fedjax.datasets.emnist.cite()

Returns BibTeX citation for the dataset.

fedjax.datasets.emnist.domain_id(client_id)

Returns domain id for client id.

Domain ids are based on the NIST data source, where examples were collected from two sources: Bethesda high school (HIGH_SCHOOL) and Census Bureau in Suitland (CENSUS). For more details, see the NIST documentation.

Parameters:

client_id (bytes) – Client id of the format [16-byte hex hash]:f[4-digit integer]_[2-digit integer] or f[4-digit integer]_[2-digit integer].

Return type:

int

Returns:

Domain id that is 0 (HIGH_SCHOOL) or 1 (CENSUS).
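The two client id formats above share the f[4-digit integer]_[2-digit integer] suffix, whose 4-digit part is the NIST writer number. A minimal sketch of parsing that format is shown next. The writer-number ranges that map to HIGH_SCHOOL or CENSUS are not stated here, so the sketch stops at extracting the fields; parse_client_id is a hypothetical helper, and accepting a hex prefix of any length is an assumption.

```python
import re

# Documented formats:
#   [hex hash]:f[4-digit integer]_[2-digit integer]
#   f[4-digit integer]_[2-digit integer]
# Assumption: the hash prefix is any non-empty run of hex digits.
_CLIENT_ID = re.compile(rb"^(?:([0-9a-fA-F]+):)?f(\d{4})_(\d{2})$")


def parse_client_id(client_id: bytes):
    """Hypothetical helper: split a client id into (hash, writer, index)."""
    match = _CLIENT_ID.match(client_id)
    if match is None:
        raise ValueError(f"Unexpected client id: {client_id!r}")
    prefix, writer, index = match.groups()
    return prefix, int(writer), int(index)


print(parse_client_id(b"f0123_45"))  # (None, 123, 45)
```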

fedjax.datasets.emnist.load_data(only_digits=False, mode='sqlite', cache_dir=None)

Loads processed EMNIST train and test splits.

Features:

  • x: [N, 28, 28, 1] float32 flipped image pixels.

  • y: [N] int32 classification label.

  • domain_id: [N] int32 domain id (see domain_id()).

Parameters:
  • only_digits (bool) – Whether to only load the digits data.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData]

Returns:

Train and test splits as FederatedData.

fedjax.datasets.emnist.load_split(split, only_digits=False, mode='sqlite', cache_dir=None)

Loads an unprocessed federated EMNIST split.

Features:

  • pixels: [N, 28, 28] float32 image pixels.

  • label: [N] int32 classification label.

Parameters:
  • split (str) – Name of the split. One of SPLITS.

  • only_digits (bool) – Whether to only load the digits data.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

FederatedData

Returns:

FederatedData.
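The load_split() features (pixels, label) and the load_data() features (x, y, domain_id) differ in shape and content. The sketch below shows one plausible mapping between them. It rests on two assumptions that the text above does not confirm: that "flipped" means intensity inversion (1 - pixels), and that every example of a client carries the domain id returned by domain_id() for that client. emnist_split_to_data is a hypothetical helper, not part of fedjax.

```python
import numpy as np


def emnist_split_to_data(examples, client_domain_id):
    """Hypothetical helper: load_split() features -> load_data() features.

    Assumes "flipped" means intensity inversion (1 - pixels) and that every
    example of a client carries that client's domain id.
    """
    pixels = examples["pixels"]  # [N, 28, 28] float32
    return {
        "x": (1.0 - pixels)[..., np.newaxis],     # [N, 28, 28, 1] float32
        "y": examples["label"].astype(np.int32),  # [N] int32
        # [N] int32, the same value repeated for every example of this client.
        "domain_id": np.full(pixels.shape[0], client_domain_id, dtype=np.int32),
    }


raw = {
    "pixels": np.ones((3, 28, 28), dtype=np.float32),
    "label": np.array([0, 5, 9], dtype=np.int32),
}
features = emnist_split_to_data(raw, client_domain_id=1)
print(features["x"].shape)  # (3, 28, 28, 1)
```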

Shakespeare

Federated Shakespeare.

fedjax.datasets.shakespeare.cite()

Returns BibTeX citation for the dataset.

fedjax.datasets.shakespeare.load_data(sequence_length=80, mode='sqlite', cache_dir=None)

Loads preprocessed shakespeare splits.

Preprocessing is applied with fedjax.FederatedData.preprocess_client(), passing this module's preprocess_client().

Features (M below is possibly different from N in load_split):

  • x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)

  • y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)

Parameters:
  • sequence_length (int) – The fixed sequence length after preprocessing.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData]

Returns:

A (train, test) tuple of federated data.

fedjax.datasets.shakespeare.load_split(split, mode='sqlite', cache_dir=None)

Loads a shakespeare split.

Features:

  • snippets: [N] bytes array of snippet text.

Parameters:
  • split (str) – Name of the split. One of SPLITS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

FederatedData

Returns:

FederatedData.

fedjax.datasets.shakespeare.preprocess_client(client_id, examples, sequence_length)

Turns snippets into sequences of integer labels.

Features (M below is possibly different from N in load_split):

  • x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)

  • y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)

All snippets in a client dataset are first joined into a single sequence (with BOS/EOS added around each snippet), and then split into input/output pairs of sequence_length-long chunks for language model training, where each output chunk is its input chunk shifted by one token. For example, with sequence_length=3, [b'ABCD', b'E'] becomes:

Input sequences:  [[BOS, A, B], [C, D, EOS],   [BOS, E, PAD]]
Output sequences: [[A, B, C],   [D, EOS, BOS], [E, EOS, PAD]]

Note: This is not equivalent to the preprocessing in the TensorFlow Federated text generation tutorial (the processing logic there drops roughly 1/sequence_length of the tokens).
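The join/shift/chunk procedure can be sketched in plain Python. For readability it uses the symbolic tokens from the example above (BOS, EOS, PAD, and one token per character) instead of the int32 labels that preprocess_client() actually emits, whose exact values are not assumed here; to_lm_chunks is a hypothetical name. Padding to a multiple of sequence_length, rather than dropping the remainder, is what keeps every token.

```python
from typing import List, Sequence, Tuple

# Symbolic placeholders; fedjax itself emits int32 labels in
# [0, shakespeare.VOCAB_SIZE).
PAD, BOS, EOS = "PAD", "BOS", "EOS"


def to_lm_chunks(
    snippets: Sequence[bytes], sequence_length: int
) -> Tuple[List[List[str]], List[List[str]]]:
    """Join snippets with BOS/EOS, shift by one, and cut into fixed-length chunks."""
    tokens: List[str] = []
    for snippet in snippets:
        tokens.append(BOS)
        tokens.extend(chr(c) for c in snippet)  # iterating bytes yields ints
        tokens.append(EOS)
    # The output sequence is the input sequence shifted left by one token.
    inputs, outputs = tokens[:-1], tokens[1:]
    # Pad both to a multiple of sequence_length so no token is dropped.
    num_chunks = -(-len(inputs) // sequence_length)  # ceiling division
    padding = [PAD] * (num_chunks * sequence_length - len(inputs))
    inputs, outputs = inputs + padding, outputs + padding
    x = [inputs[i:i + sequence_length] for i in range(0, len(inputs), sequence_length)]
    y = [outputs[i:i + sequence_length] for i in range(0, len(outputs), sequence_length)]
    return x, y


x, y = to_lm_chunks([b"ABCD", b"E"], sequence_length=3)
print(x)  # [['BOS', 'A', 'B'], ['C', 'D', 'EOS'], ['BOS', 'E', 'PAD']]
print(y)  # [['A', 'B', 'C'], ['D', 'EOS', 'BOS'], ['E', 'EOS', 'PAD']]
```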

Parameters:
  • client_id (bytes) – Not used.

  • examples (Mapping[str, ndarray]) – Unprocessed examples (e.g. from load_split()).

  • sequence_length (int) – The fixed sequence length after preprocessing.

Return type:

Mapping[str, ndarray]

Returns:

Processed examples.

Stack Overflow

Federated stackoverflow.

fedjax.datasets.stackoverflow.cite()

Returns BibTeX citation for the dataset.

fedjax.datasets.stackoverflow.load_data(mode='sqlite', cache_dir=None)

Loads partially preprocessed stackoverflow splits.

Features:

  • domain_id: [N] int32 domain id derived from type (question = 0; answer = 1).

  • tokens: [N] bytes array. Space separated list of tokens.

To convert tokens into padded/truncated integer labels, use a StackoverflowTokenizer. For example:

from fedjax.datasets import stackoverflow
# Load partially preprocessed splits.
train, held_out, test = stackoverflow.load_data()
# Apply tokenizer during batching.
tokenizer = stackoverflow.StackoverflowTokenizer()
train_max_length, eval_max_length = 20, 30
train_for_train = train.preprocess_batch(
    tokenizer.as_preprocess_batch(train_max_length))
train_for_eval = train.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
held_out = held_out.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
test = test.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))

Features after tokenization:

  • domain_id: Same as before.

  • x: [N, max_length] int32 array of padded/truncated input labels.

  • y: [N, max_length] int32 array of padded/truncated output labels.

Parameters:
  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData, FederatedData]

Returns:

A (train, held_out, test) tuple of federated data.

fedjax.datasets.stackoverflow.load_split(split, mode='sqlite', cache_dir=None)

Loads a stackoverflow split.

All bytes arrays are stored with dtype=object (the np.object alias is deprecated and was removed in NumPy 1.24).

Features:

  • creation_date: [N] bytes array. Textual timestamp, e.g. b'2018-02-28 19:06:18.34 UTC'.

  • title: [N] bytes array. The title of a post.

  • score: [N] int64 array. The score of a post.

  • tags: [N] bytes array. '|'-separated list of tags, e.g. b'mysql|join'.

  • tokens: [N] bytes array. Space separated list of tokens.

  • type: [N] bytes array. Either b'question' or b'answer'.
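load_data() documents domain_id as derived from type (question = 0; answer = 1). The sketch below shows that derivation, plus splitting the '|'-separated tags, on dtype=object arrays shaped like these features. type_to_domain_id and split_tags are hypothetical helpers, not fedjax functions, and rejecting unexpected type values is a choice made here.

```python
import numpy as np


def type_to_domain_id(types: np.ndarray) -> np.ndarray:
    """Map the raw `type` feature to domain_id: question -> 0, answer -> 1."""
    is_answer = types == b"answer"  # elementwise on dtype=object arrays
    is_question = types == b"question"
    if not np.all(is_answer | is_question):
        raise ValueError("type must be b'question' or b'answer'")
    return is_answer.astype(np.int32)


def split_tags(tags: np.ndarray) -> list:
    """Split each '|'-separated tags entry into a list of bytes tags."""
    return [entry.split(b"|") for entry in tags]


types = np.array([b"question", b"answer", b"answer"], dtype=object)
tags = np.array([b"mysql|join", b"python", b"c++|templates"], dtype=object)
print(type_to_domain_id(types))  # [0 1 1]
print(split_tags(tags)[0])       # [b'mysql', b'join']
```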

Parameters:
  • split (str) – Name of the split. One of SPLITS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

FederatedData

Returns:

FederatedData.

class fedjax.datasets.stackoverflow.StackoverflowTokenizer(vocab=None, default_vocab_size=10000, num_oov_buckets=1)

Tokenizer for the tokens feature in stackoverflow.

See load_data() for examples.

__init__(vocab=None, default_vocab_size=10000, num_oov_buckets=1)

Initializes a tokenizer.

Parameters:
  • vocab (Optional[List[str]]) – Optional vocabulary. If specified, default_vocab_size is ignored. If None, default_vocab_size is used to load the standard vocabulary. This vocabulary should NOT have special tokens PAD, EOS, BOS, and OOV. The special tokens are added and handled automatically by the tokenizer. The preprocessed examples will have vocabulary size len(vocab) + 3 + num_oov_buckets.

  • default_vocab_size (Optional[int]) – Number of words in the default vocabulary. This is only used when vocab is not specified. The preprocessed examples will have vocabulary size default_vocab_size + 3 + num_oov_buckets with 3 special labels: 0 (PAD), 1 (BOS), 2 (EOS), and num_oov_buckets OOV labels starting at default_vocab_size + 3.

  • num_oov_buckets (int) – Number of out of vocabulary buckets.
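The label layout described above (0 PAD, 1 BOS, 2 EOS, vocabulary words next, then num_oov_buckets OOV labels starting at vocabulary size + 3) can be illustrated with a small NumPy sketch. Two parts are assumptions rather than documented behavior: each row is built as [BOS] + word labels + [EOS], truncated to max_length + 1 and then shifted to form x and y; and OOV words are bucketed with zlib.crc32, a stand-in for whatever hashing the real tokenizer uses. tokenize_sketch is a hypothetical name, not the StackoverflowTokenizer API.

```python
import zlib
from typing import Sequence, Tuple

import numpy as np

PAD, BOS, EOS = 0, 1, 2  # special labels, as documented


def tokenize_sketch(
    tokens: Sequence[bytes],
    vocab: Sequence[str],
    num_oov_buckets: int,
    max_length: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Map space-separated token strings to padded/truncated (x, y) labels."""
    # Vocabulary words take labels 3 .. len(vocab) + 2.
    table = {word: i + 3 for i, word in enumerate(vocab)}
    oov_start = len(vocab) + 3  # OOV labels start right after the vocabulary

    def label(word: str) -> int:
        if word in table:
            return table[word]
        # Hypothetical bucketing: crc32 hash modulo the number of OOV buckets.
        return oov_start + zlib.crc32(word.encode("utf-8")) % num_oov_buckets

    x = np.full((len(tokens), max_length), PAD, dtype=np.int32)
    y = np.full((len(tokens), max_length), PAD, dtype=np.int32)
    for row, line in enumerate(tokens):
        ids = [BOS] + [label(w) for w in line.decode("utf-8").split()] + [EOS]
        ids = ids[: max_length + 1]  # truncate so each shifted view fits
        inputs, outputs = ids[:-1], ids[1:]
        x[row, : len(inputs)] = inputs
        y[row, : len(outputs)] = outputs
    return x, y


examples = np.array([b"hello world", b"hello"], dtype=object)
x, y = tokenize_sketch(examples, vocab=["hello"], num_oov_buckets=1, max_length=4)
print(x.tolist())  # [[1, 3, 4, 0], [1, 3, 0, 0]]
print(y.tolist())  # [[3, 4, 2, 0], [3, 2, 0, 0]]
```

With vocab=["hello"] and one OOV bucket, the total vocabulary size is len(vocab) + 3 + num_oov_buckets = 5, so every label falls in [0, 5).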