fedjax.datasets

fedjax datasets.

fedjax.datasets.cifar100

Federated cifar100.

fedjax.datasets.emnist

Federated EMNIST.

fedjax.datasets.shakespeare

Federated Shakespeare.

fedjax.datasets.stackoverflow

Federated stackoverflow.

CIFAR-100

Federated cifar100.

fedjax.datasets.cifar100.cite()[source]

Returns BibTeX citation for the dataset.

fedjax.datasets.cifar100.load_data(mode='sqlite', cache_dir=None)[source]

Loads partially preprocessed cifar100 splits.

Features:

  • x: [N, 32, 32, 3] uint8 pixels.

  • y: [N] int32 labels in the range [0, 100).

Additional preprocessing (e.g. centering and normalizing) depends on whether a split is used for training or eval. For example:

import functools
from fedjax.datasets import cifar100
# Load partially preprocessed splits.
train, test = cifar100.load_data()
# Preprocessing for training.
train_for_train = train.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=True))
# Preprocessing for eval.
train_for_eval = train.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=False))
test = test.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=False))

Features after final preprocessing:

  • x: [N, 32, 32, 3] float32 preprocessed pixels.

  • y: [N] int32 labels in the range [0, 100).

Note: cifar100.preprocess_batch is just a convenience wrapper around preprocess_image() so that it can be used with fedjax.FederatedData.preprocess_batch().

Parameters
  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type

Tuple[FederatedData, FederatedData]

Returns

A (train, test) tuple of federated data.

fedjax.datasets.cifar100.load_split(split, mode='sqlite', cache_dir=None)[source]

Loads a cifar100 split.

Features:

  • image: [N, 32, 32, 3] uint8 pixels.

  • coarse_label: [N] int64 coarse labels in the range [0, 20).

  • label: [N] int64 labels in the range [0, 100).

Parameters
  • split (str) – Name of the split. One of SPLITS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type

FederatedData

Returns

FederatedData.
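
For instance, here is a minimal sketch of inspecting one client's raw features, assuming the standard fedjax API where FederatedData.clients() yields (client_id, client_dataset) pairs and ClientDataset.all_examples() returns a client's examples as a single batch:

import itertools
from fedjax.datasets import cifar100
train = cifar100.load_split('train')
# Look at the raw features of the first client.
for client_id, client_dataset in itertools.islice(train.clients(), 1):
    examples = client_dataset.all_examples()
    print(examples['image'].shape)         # (N, 32, 32, 3) uint8
    print(examples['coarse_label'].shape)  # (N,) int64
    print(examples['label'].shape)         # (N,) int64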

fedjax.datasets.cifar100.preprocess_image(image, is_train)[source]

Augments and preprocesses CIFAR-100 images by cropping, flipping, and normalizing.

Preprocessing procedure and values taken from pytorch-cifar.

Parameters
  • image (ndarray) – [N, 32, 32, 3] uint8 pixels.

  • is_train (bool) – Whether we are preprocessing for training or eval.

Return type

ndarray

Returns

Processed [N, 32, 32, 3] float32 pixels.
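
A minimal sketch of calling preprocess_image directly (the all-zeros batch below is a stand-in for real pixels):

import numpy as np
from fedjax.datasets import cifar100
# A dummy uint8 batch standing in for raw CIFAR-100 pixels.
images = np.zeros([4, 32, 32, 3], dtype=np.uint8)
# Training mode applies crop/flip augmentation before normalizing.
train_images = cifar100.preprocess_image(images, is_train=True)
# Eval mode only normalizes.
eval_images = cifar100.preprocess_image(images, is_train=False)
assert train_images.dtype == np.float32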

EMNIST

Federated EMNIST.

fedjax.datasets.emnist.cite()[source]

Returns BibTeX citation for the dataset.

fedjax.datasets.emnist.domain_id(client_id)[source]

Returns domain id for client id.

Domain ids are based on the NIST data source, where examples were collected from two sources: Bethesda high school (HIGH_SCHOOL) and Census Bureau in Suitland (CENSUS). For more details, see the NIST documentation.

Parameters

client_id (bytes) – Client id of the format [16-byte hex hash]:f[4-digit integer]_[2-digit integer] or f[4-digit integer]_[2-digit integer].

Return type

int

Returns

Domain id that is 0 (HIGH_SCHOOL) or 1 (CENSUS).
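
For example, with a client id in the second format above (the id below is made up for illustration):

from fedjax.datasets import emnist
# A made-up client id in the f[4-digit]_[2-digit] format.
domain = emnist.domain_id(b'f0123_45')
assert domain in (0, 1)  # 0 = HIGH_SCHOOL, 1 = CENSUS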

fedjax.datasets.emnist.load_data(only_digits=False, mode='sqlite', cache_dir=None)[source]

Loads processed EMNIST train and test splits.

Features:

  • x: [N, 28, 28, 1] float32 flipped image pixels.

  • y: [N] int32 classification label.

  • domain_id: [N] int32 domain id (see domain_id()).

Parameters
  • only_digits (bool) – Whether to only load the digits data.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type

Tuple[FederatedData, FederatedData]

Returns

Train and test splits as FederatedData.
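
As a sketch of typical usage, assuming the standard fedjax FederatedData.clients() iterator and ClientDataset.batch() batching API:

import itertools
from fedjax.datasets import emnist
train, test = emnist.load_data(only_digits=False)
# Iterate over a couple of clients and batch their examples.
for client_id, client_dataset in itertools.islice(train.clients(), 2):
    for batch in client_dataset.batch(batch_size=32):
        print(client_id, batch['x'].shape, batch['y'].shape)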

fedjax.datasets.emnist.load_split(split, only_digits=False, mode='sqlite', cache_dir=None)[source]

Loads an unprocessed federated EMNIST split.

Features:

  • pixels: [N, 28, 28] float32 image pixels.

  • label: [N] int32 classification label.

Parameters
  • split (str) – Name of the split. One of SPLITS.

  • only_digits (bool) – Whether to only load the digits data.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type

FederatedData

Returns

FederatedData.

Shakespeare

Federated Shakespeare.

fedjax.datasets.shakespeare.cite()[source]

Returns BibTeX citation for the dataset.

fedjax.datasets.shakespeare.load_data(sequence_length=80, mode='sqlite', cache_dir=None)[source]

Loads preprocessed shakespeare splits.

Preprocessing is done using fedjax.FederatedData.preprocess_client() and preprocess_client().

Features (M below is possibly different from N in load_split):

  • x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)

  • y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)

Parameters
  • sequence_length (int) – The fixed sequence length after preprocessing.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type

Tuple[FederatedData, FederatedData, FederatedData]

Returns

A (train, held_out, test) tuple of federated data.
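
A short sketch of loading the three splits with a custom sequence length (all_examples() is assumed from the standard fedjax ClientDataset API):

import itertools
from fedjax.datasets import shakespeare
# Load the three preprocessed splits with a custom sequence length.
train, held_out, test = shakespeare.load_data(sequence_length=40)
for client_id, client_dataset in itertools.islice(train.clients(), 1):
    batch = client_dataset.all_examples()
    print(batch['x'].shape, batch['y'].shape)  # (M, 40), (M, 40)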

fedjax.datasets.shakespeare.load_split(split, mode='sqlite', cache_dir=None)[source]

Loads a shakespeare split.

Features:

  • snippets: [N] bytes array of snippet text.

Parameters
  • split (str) – Name of the split. One of SPLITS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type

FederatedData

Returns

FederatedData.

fedjax.datasets.shakespeare.preprocess_client(client_id, examples, sequence_length)[source]

Turns snippets into sequences of integer labels.

Features (M below is possibly different from N in load_split):

  • x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)

  • y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)

All snippets in a client dataset are first joined into a single sequence (with BOS/EOS added), and then split into pairs of sequence_length chunks for language model training. For example, with sequence_length=3, [b’ABCD’, b’E’] becomes:

Input sequences:  [[BOS, A, B], [C, D, EOS],   [BOS, E, PAD]]
Output sequences: [[A, B, C],   [D, EOS, BOS], [E, EOS, PAD]]

Note: This is not equivalent to the TensorFlow Federated text generation tutorial (the processing logic there loses roughly a 1/sequence_length fraction of the tokens).

Parameters
  • client_id (bytes) – Not used.

  • examples (Mapping[str, ndarray]) – Unprocessed examples (e.g. from load_split()).

  • sequence_length (int) – The fixed sequence length after preprocessing.

Return type

Mapping[str, ndarray]

Returns

Processed examples.
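
A small sketch reproducing the b'ABCD'/b'E' example above directly:

import numpy as np
from fedjax.datasets import shakespeare
examples = {'snippets': np.array([b'ABCD', b'E'], dtype=object)}
# client_id is unused, so any bytes value works here.
processed = shakespeare.preprocess_client(b'unused', examples, sequence_length=3)
print(processed['x'].shape, processed['y'].shape)  # (3, 3), (3, 3)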

Stack Overflow

Federated stackoverflow.

fedjax.datasets.stackoverflow.cite()[source]

Returns BibTeX citation for the dataset.

fedjax.datasets.stackoverflow.load_data(mode='sqlite', cache_dir=None)[source]

Loads partially preprocessed stackoverflow splits.

Features:

  • domain_id: [N] int32 domain id derived from type (question = 0; answer = 1).

  • tokens: [N] bytes array. Space separated list of tokens.

To convert tokens into padded/truncated integer labels, use a StackoverflowTokenizer. For example:

from fedjax.datasets import stackoverflow
# Load partially preprocessed splits.
train, held_out, test = stackoverflow.load_data()
# Apply tokenizer during batching.
tokenizer = stackoverflow.StackoverflowTokenizer()
train_max_length, eval_max_length = 20, 30
train_for_train = train.preprocess_batch(
    tokenizer.as_preprocess_batch(train_max_length))
train_for_eval = train.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
held_out = held_out.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
test = test.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))

Features after tokenization:

  • domain_id: Same as before.

  • x: [N, max_length] int32 array of padded/truncated input labels.

  • y: [N, max_length] int32 array of padded/truncated output labels.

Parameters
  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type

Tuple[FederatedData, FederatedData, FederatedData]

Returns

A (train, held_out, test) tuple of federated data.

fedjax.datasets.stackoverflow.load_split(split, mode='sqlite', cache_dir=None)[source]

Loads a stackoverflow split.

All bytes arrays are stored with dtype=np.object.

Features:

  • creation_date: [N] bytes array. Textual timestamp, e.g. b’2018-02-28 19:06:18.34 UTC’.

  • title: [N] bytes array. The title of a post.

  • score: [N] int64 array. The score of a post.

  • tags: [N] bytes array. ‘|’ separated list of tags, e.g. b’mysql|join’.

  • tokens: [N] bytes array. Space separated list of tokens.

  • type: [N] bytes array. Either b’question’ or b’answer’.

Parameters
  • split (str) – Name of the split. One of SPLITS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type

FederatedData

Returns

FederatedData.

class fedjax.datasets.stackoverflow.StackoverflowTokenizer(vocab=None, default_vocab_size=10000, num_oov_buckets=1)[source]

Tokenizer for the tokens feature in stackoverflow.

See load_data() for examples.

__init__(vocab=None, default_vocab_size=10000, num_oov_buckets=1)[source]

Initializes a tokenizer.

Parameters
  • vocab (Optional[List[str]]) – Optional vocabulary. If specified, default_vocab_size is ignored. If None, default_vocab_size is used to load the standard vocabulary. This vocabulary should NOT have special tokens PAD, EOS, BOS, and OOV. The special tokens are added and handled automatically by the tokenizer. The preprocessed examples will have vocabulary size len(vocab) + 3 + num_oov_buckets.

  • default_vocab_size (Optional[int]) – Number of words in the default vocabulary. This is only used when vocab is not specified. The preprocessed examples will have vocabulary size default_vocab_size + 3 + num_oov_buckets with 3 special labels: 0 (PAD), 1 (BOS), 2 (EOS), and num_oov_buckets OOV labels starting at default_vocab_size + 3.

  • num_oov_buckets (int) – Number of out of vocabulary buckets.
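
For instance, with a custom two-word vocabulary and one OOV bucket, the preprocessed examples use a vocabulary of 2 + 3 + 1 = 6 labels:

from fedjax.datasets import stackoverflow
# Custom vocabulary without special tokens; PAD/BOS/EOS/OOV are added automatically.
tokenizer = stackoverflow.StackoverflowTokenizer(vocab=['cat', 'dog'])
# Final vocabulary size: len(vocab) + 3 special labels + num_oov_buckets = 2 + 3 + 1 = 6.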

as_preprocess_batch(max_length)[source]

Creates a preprocess_batch function.

Parameters

max_length (int) – The length to pad x/y sequences to. Sequences longer than this are also truncated to this length.

Return type

Callable[[Mapping[str, ndarray]], Mapping[str, ndarray]]

Returns

A function that can be used with FederatedData.preprocess_batch().
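
A minimal sketch of applying the resulting function to a raw batch (the tokens below are illustrative):

import numpy as np
from fedjax.datasets import stackoverflow
tokenizer = stackoverflow.StackoverflowTokenizer()
preprocess = tokenizer.as_preprocess_batch(max_length=10)
raw_batch = {
    'domain_id': np.array([0], dtype=np.int32),
    'tokens': np.array([b'how do i join two tables'], dtype=object),
}
processed = preprocess(raw_batch)
print(processed['x'].shape, processed['y'].shape)  # (1, 10), (1, 10)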

create_token_to_ids_fn(max_length)[source]

Creates a TensorFlow function that converts the tokens feature into integer ids.

Parameters

max_length (int) – The length to pad x/y sequences to. Sequences longer than this are also truncated to this length.

Returns

A function that uses TensorFlow ops to convert tokens to integer ids.