Federated datasets


This tutorial introduces datasets in FedJAX and how to work with them. By completing this tutorial, we’ll learn about the best practices for working with datasets.

NOTE: For datasets, we operate over NumPy arrays NOT JAX arrays.

# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
# !pip install tensorflow_datasets
import functools
import itertools
import fedjax
import numpy as np

What are datasets in federated learning?

In the context of federated learning (FL), data is decentralized across clients, with each client having their own local set of examples. In light of this, we refer to two levels of organization for datasets:

  • Federated dataset: A collection of clients, each with their own local datasets and metadata

  • Client dataset: The set of local examples for a particular client

We can think of federated data as a mapping from client ids to client datasets and client datasets as a list of examples.

federated_data = {
  'client0': ['a', 'b', 'c'],
  'client1': ['d', 'e'],
}

Federated datasets

FedJAX comes packaged with multiple federated datasets, and we will look at the Shakespeare dataset as an example. The Shakespeare dataset is created from The Complete Works of Shakespeare, treating each character in each play as a “client” and their dialogue lines as the examples.

FedJAX organizes federated datasets as Python modules. load_data() from a dataset module loads all predefined splits together as fedjax.FederatedData objects, the interface for accessing all federated datasets. In the case of the Shakespeare dataset, load_data() returns two splits: train and test.

# We cap max sentence length to 8.
train_fd, test_fd = fedjax.datasets.shakespeare.load_data(sequence_length=8)
Downloading 'https://storage.googleapis.com/gresearch/fedjax/shakespeare/shakespeare_train.sqlite' to '/tmp/.cache/fedjax/shakespeare_train.sqlite'
100%, elapsed: 0s
Downloading 'https://storage.googleapis.com/gresearch/fedjax/shakespeare/shakespeare_test.sqlite' to '/tmp/.cache/fedjax/shakespeare_test.sqlite'
100%, elapsed: 0s

fedjax.FederatedData provides methods for accessing metadata about the federated dataset, like the total number of clients, client ids, and number of examples for each client.

As seen in the output below, there are 715 clients in total in the Shakespeare dataset. Each client has a unique client id that can be used to query metadata about that client, such as the number of examples it has.

print('num_clients =', train_fd.num_clients())

# train_fd.client_ids() is a generator of client ids.
# itertools has efficient and convenient functions for working with generators.
for client_id in itertools.islice(train_fd.client_ids(), 3):
  print('client_id =', client_id)
  print('# examples =', train_fd.client_size(client_id))
num_clients = 715
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 5
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 13
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 15

As we notice, the client ids start with a random set of bits. This is to ensure that one can easily slice a FederatedData to obtain a random subset. While the client ids returned above look sorted, there is no such guarantee in general.

# Slicing is based on the lexicographic order of client ids.
train_fd_0 = train_fd.slice(start=b'0', stop=b'1')
print('num_clients whose id starts with 0 =', train_fd_0.num_clients())
num_clients whose id starts with 0 = 47

Client datasets

We can query the dataset for a client from a federated dataset using their client ID and fedjax.FederatedData.get_client(). The output of fedjax.FederatedData.get_client() is a fedjax.ClientDataset. A fedjax.ClientDataset object

  • Stores all the examples for a given client and any preprocessing that should be applied on batches.

  • Provides methods for batching, shuffling, and iterating over preprocessed examples.

In other words, ClientDataset = examples + preprocessor.

client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
client_dataset = train_fd.get_client(client_id)
print(client_dataset)
<fedjax.core.client_datasets.ClientDataset object at 0x7ff730f5afd0>

FedJAX assumes that an individual client dataset is small and can easily fit in memory. This assumption is also reflected in many of FedJAX’s design decisions. The examples in a client dataset can be viewed as a table, where the rows are the individual examples, and the columns are the features (labels are viewed as a feature in this context).

FedJAX uses a column based representation when loading a dataset into memory.

  • Each column is a NumPy array x of rank at least 1, where x[i, ...] is the value of this feature for the i-th example.

  • The complete set of examples is a dict-like object, from str feature names, to the corresponding column values.

Traditionally, a row based representation is used for representing the entire dataset, and a column based representation is used for a single batch.

In the context of federated learning, an individual client dataset is small enough to easily fit into memory so the same representation is used for the entire dataset and a batch.
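
To make the distinction concrete, here is a toy illustration (not fedjax code) of the same three examples in a row-based representation versus the column-based representation described above.

# Row-based: a list of per-example dicts.
row_based = [
    {'x': np.array([1, 2]), 'y': 0},
    {'x': np.array([3, 4]), 'y': 1},
    {'x': np.array([5, 6]), 'y': 0},
]
# Column-based (what FedJAX uses): one dict of feature columns, where
# column_based['x'][i] is the 'x' value of the i-th example.
column_based = {
    'x': np.array([[1, 2], [3, 4], [5, 6]]),
    'y': np.array([0, 1, 0]),
}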

The preprocessed examples of a single client can be viewed by calling all_examples().

client_dataset.all_examples()
{'x': array([[ 1, 55, 67, 84, 67, 47,  7, 67],
        [48, 16, 13, 32, 33, 14, 11, 78],
        [76, 78, 33, 19, 16, 66, 47,  3],
        [16, 27, 67, 23, 26, 47,  3, 27],
        [16,  7,  4, 67, 16, 51, 48, 68],
        [ 7, 26, 47, 27, 42, 16,  7,  4],
        [67, 72, 16, 48, 67, 27, 23, 71],
        [67, 65, 29, 79, 76, 51, 74, 12],
        [75, 54, 74, 19, 16, 66, 47,  3],
        [16, 67,  8, 67, 71, 47,  7, 61],
        [16, 14,  4, 67, 47, 16, 48, 67],
        [84, 67, 47,  7, 67, 48, 16, 12],
        [78, 29, 75, 78, 33, 16, 66, 47],
        [ 3, 16, 75, 73, 29, 11, 75, 76],
        [32, 19, 65, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 16],
        [28, 68,  7,  4, 16, 75, 76, 32],
        [30, 74, 54, 65,  0,  0,  0,  0]], dtype=int32),
 'y': array([[55, 67, 84, 67, 47,  7, 67, 48],
        [16, 13, 32, 33, 14, 11, 78, 76],
        [78, 33, 19, 16, 66, 47,  3, 16],
        [27, 67, 23, 26, 47,  3, 27, 16],
        [ 7,  4, 67, 16, 51, 48, 68,  7],
        [26, 47, 27, 42, 16,  7,  4, 67],
        [72, 16, 48, 67, 27, 23, 71, 67],
        [65, 29, 79, 76, 51, 74, 12, 75],
        [54, 74, 19, 16, 66, 47,  3, 16],
        [67,  8, 67, 71, 47,  7, 61, 16],
        [14,  4, 67, 47, 16, 48, 67, 84],
        [67, 47,  7, 67, 48, 16, 12, 78],
        [29, 75, 78, 33, 16, 66, 47,  3],
        [16, 75, 73, 29, 11, 75, 76, 32],
        [19, 65, 16, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 28],
        [68,  7,  4, 16, 75, 76, 32, 30],
        [74, 54, 65,  2,  0,  0,  0,  0]], dtype=int32)}

For Shakespeare, we are training a character-level language model, where the task is next character prediction, so the features are:

  • x is a list of right-shifted sentences, e.g. sentence[:-1]

  • y is a list of left-shifted sentences, e.g. sentence[1:]

This way, y[i][j] corresponds to the next character after x[i][j].

examples = client_dataset.all_examples()
print('x', examples['x'][0])
print('y', examples['y'][0])
x [ 1 55 67 84 67 47  7 67]
y [55 67 84 67 47  7 67 48]

However, you probably noticed that x and y are arrays of integers, not text. This is because fedjax.datasets.shakespeare.load_data() does some minimal preprocessing, such as a simple character lookup that maps characters to integer IDs. Later, we’ll go over how this preprocessing was applied and how to add your own custom preprocessing.

For comparison, here’s the unprocessed version of the same client dataset.

# Unlike load_data(), load_split() always loads a single unprocessed split.
raw_fd = fedjax.datasets.shakespeare.load_split('train')
raw_fd.get_client(client_id).all_examples()
Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'
{'snippets': array([b'Re-enter POSTHUMUS, and seconds the Britons; they rescue\nCYMBELINE, and exeunt. Then re-enter LUCIUS and IACHIMO,\n                     with IMOGEN\n'],
       dtype=object)}

Accessing client datasets from fedjax.FederatedData

fedjax.FederatedData.get_client() works well for querying a single client’s data when exploring the dataset. However, we often want to query multiple client datasets at the same time. In most FL algorithms, tens to hundreds of clients participate in each training round. If we are not careful, our code can spend a lot of time loading data, leaving the accelerators (GPUs or TPUs) idle.

In light of this, fedjax.FederatedData offers more efficient methods for querying multiple client datasets.

We’ll go through each access method from the MOST efficient to the LEAST efficient.

clients() and shuffled_clients()

These provide the fastest, sequential-read-friendly access. As stated earlier, the client ids are prefixed with random bits, so even sequential reads will go over clients in a pseudo-random order.

# clients() and shuffled_clients() are sequential read friendly.
clients = train_fd.clients()
shuffled_clients = train_fd.shuffled_clients(buffer_size=100, seed=0)
print('clients =', clients)
print('shuffled_clients =', shuffled_clients)
clients = <generator object SQLiteFederatedData.clients at 0x7ff730f3d950>
shuffled_clients = <generator object SQLiteFederatedData.shuffled_clients at 0x7ff7310a8950>

They are generators, so we iterate over them to get the individual client datasets as tuples of (client_id, client_dataset).

clients() returns clients in an unspecified deterministic order. It is useful for going over the entire federated dataset for evaluation.

# We use itertools.islice to select first three clients.
for client_id, client_dataset in itertools.islice(clients, 3):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 53
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 234
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 79

shuffled_clients() provides a stream of infinitely repeating shuffled client datasets, using buffered shuffling. It is suitable for training rounds where a nearly random shuffling is good enough.

print('shuffled_clients()')
for client_id, client_dataset in itertools.islice(shuffled_clients, 3):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))
shuffled_clients()
client_id = b'0a18c2501d441fef:THE_TRAGEDY_OF_KING_LEAR_FLUTE'
# examples = 115
client_id = b'136c5586b7271525:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_GLOUCESTER'
# examples = 3804
client_id = b'0d642a9b4bb27187:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_MESSENGER'
# examples = 775

get_clients()

Slower than clients() since it requires random reads, but it uses prefetching to hide the latency of random access. This also returns a generator of tuples of (client_id, client_dataset), in the order of the input client_ids.

client_ids = [
    b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN',
    b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN',
    b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
]
for client_id, client_dataset in train_fd.get_clients(client_ids):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 483
client_id = b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN'
# examples = 5
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
# examples = 19

get_client()

Slowest way of accessing client datasets. We usually reserve this method only for interactive exploration of a small number of clients.

client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
print('client_id =', client_id)
print('# examples =', len(train_fd.get_client(client_id)))
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 483

Batching client datasets

Next we’ll go over different methods of iterating over a fedjax.ClientDataset as batched examples. All the following methods can be invoked in 2 ways:

  • Using a hyperparams object: This is the recommended way in library code. batch_fn(hparams).

  • Using keyword arguments: The keyword arguments are used to construct a new hyperparams object, or override an existing one. batch_fn(batch_size=2) or batch_fn(hparams, batch_size=2) to override batch_size.

For the most part, we’ll use method 2 for this colab, but method 1 is more suitable for writing library code.

client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
client_dataset = train_fd.get_client(client_id)
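
To illustrate the two invocation styles described above, here is a short sketch using padded_batch(). This assumes fedjax exports the PaddedBatchHParams class at the top level, as listed in the fedjax API reference.

# 1. Hyperparams object: the recommended style for library code.
hparams = fedjax.PaddedBatchHParams(batch_size=4)
batches_from_hparams = list(client_dataset.padded_batch(hparams))
# 2. Keyword arguments: construct the hyperparams on the fly.
batches_from_kwargs = list(client_dataset.padded_batch(batch_size=4))
# Keyword arguments can also override fields of an existing hyperparams object.
smaller_batches = list(client_dataset.padded_batch(hparams, batch_size=2))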

batch() for illustrations

Produces preprocessed batches in a fixed sequential order.

The final batch may contain fewer than batch_size examples. If used directly, that may result in a large number of JIT recompilations. Therefore we should use padded_batch() or shuffle_repeat_batch() instead in most scenarios.

Note here we are not talking about padding within an example, often done in processing sequence examples (e.g. the 0 labels below), but rather padding with “empty” examples in a batch.

batches = list(client_dataset.batch(batch_size=8))
batches[-1]
{'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
        [28, 68,  7,  4, 16, 75, 76, 32],
        [30, 74, 54, 65,  0,  0,  0,  0]], dtype=int32),
 'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
        [68,  7,  4, 16, 75, 76, 32, 30],
        [74, 54, 65,  2,  0,  0,  0,  0]], dtype=int32)}

padded_batch() for evaluation

Produces preprocessed padded batches in a fixed sequential order for evaluation.

When the number of examples in the dataset is not a multiple of batch_size, the final batch may be smaller than batch_size. This may lead to a large number of JIT recompilations. This can be circumvented by padding the final batch to a small number of fixed sizes controlled by num_batch_size_buckets.

# use list() to consume generator and store in memory.
padded_batches = list(client_dataset.padded_batch(batch_size=8, num_batch_size_buckets=3))
print('# batches =', len(padded_batches))
padded_batches[0]
# batches = 3
{'x': array([[ 1, 55, 67, 84, 67, 47,  7, 67],
        [48, 16, 13, 32, 33, 14, 11, 78],
        [76, 78, 33, 19, 16, 66, 47,  3],
        [16, 27, 67, 23, 26, 47,  3, 27],
        [16,  7,  4, 67, 16, 51, 48, 68],
        [ 7, 26, 47, 27, 42, 16,  7,  4],
        [67, 72, 16, 48, 67, 27, 23, 71],
        [67, 65, 29, 79, 76, 51, 74, 12]], dtype=int32),
 'y': array([[55, 67, 84, 67, 47,  7, 67, 48],
        [16, 13, 32, 33, 14, 11, 78, 76],
        [78, 33, 19, 16, 66, 47,  3, 16],
        [27, 67, 23, 26, 47,  3, 27, 16],
        [ 7,  4, 67, 16, 51, 48, 68,  7],
        [26, 47, 27, 42, 16,  7,  4, 67],
        [72, 16, 48, 67, 27, 23, 71, 67],
        [65, 29, 79, 76, 51, 74, 12, 75]], dtype=int32),
 '__mask__': array([ True,  True,  True,  True,  True,  True,  True,  True])}

All batches contain an extra bool feature keyed by '__mask__'. batch['__mask__'][i] tells us whether the i-th example in this batch is an actual example (batch['__mask__'][i] == True), or a padding example (batch['__mask__'][i] == False).
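
As a quick sketch (not part of the original tutorial), the mask can be used to ignore padding examples, for instance when counting how many real examples each batch holds.

for i, batch in enumerate(padded_batches):
  # For this client (19 examples, batch_size=8): 8, 8, and 3 real examples.
  print('batch', i, '# real examples =', int(np.sum(batch['__mask__'])))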

We repeatedly halve the batch size, up to num_batch_size_buckets - 1 times, until we find the smallest bucket size that is still >= the size of the final batch. Therefore, if batch_size < 2^num_batch_size_buckets, fewer bucket sizes will actually be used. This can be seen in the final batch below, which contains only 4 examples even though the original batch size was 8.

padded_batches[-1]
{'__mask__': array([ True,  True,  True, False]),
 'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
        [28, 68,  7,  4, 16, 75, 76, 32],
        [30, 74, 54, 65,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]], dtype=int32),
 'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
        [68,  7,  4, 16, 75, 76, 32, 30],
        [74, 54, 65,  2,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]], dtype=int32)}
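
To make the bucketing arithmetic concrete, here is a small standalone sketch of the halving logic described above (a hypothetical helper, not part of the fedjax API).

def padded_size(final_batch_size, batch_size=8, num_batch_size_buckets=3):
  # Candidate bucket sizes: 8, 4, 2 for batch_size=8, num_batch_size_buckets=3.
  buckets = [batch_size // (2 ** i) for i in range(num_batch_size_buckets)]
  # Pick the smallest bucket that still fits the final batch.
  return min(b for b in buckets if b >= final_batch_size)

print(padded_size(3))  # 4: a final batch of 3 real examples is padded to 4.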

shuffle_repeat_batch() for training

Produces preprocessed batches in a shuffled and repeated order for training.

Shuffling is done without replacement, so for a dataset of N examples, the first ceil(N/batch_size) batches are guaranteed to cover the entire dataset. Unlike batch() or padded_batch(), batches from shuffle_repeat_batch() always contain exactly batch_size examples. Unlike TensorFlow, this holds even when drop_remainder=False.

By default the iteration stops after the first epoch.

print('# batches')
len(list(client_dataset.shuffle_repeat_batch(batch_size=8)))
# batches
3

The number of batches produced from the iteration can be controlled by the (num_epochs, num_steps, drop_remainder) combination:

If both num_epochs and num_steps are None, the shuffle-repeat process continues forever.

infinite_bs = client_dataset.shuffle_repeat_batch(
    batch_size=8, num_epochs=None, num_steps=None)
for i, b in zip(range(6), infinite_bs):
  print(i)
0
1
2
3
4
5

If num_epochs is set and num_steps is None, just enough batches are produced to go over the dataset num_epochs times. Further,

  • If drop_remainder is False (the default), the final batch is filled with additionally sampled examples to contain batch_size examples.

  • If drop_remainder is True, the final batch is dropped if it contains fewer than batch_size examples. This may result in examples being skipped when num_epochs=1.

print('# batches w/ drop_remainder=False')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None))))
print('# batches w/ drop_remainder=True')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None, drop_remainder=True))))
# batches w/ drop_remainder=False
3
# batches w/ drop_remainder=True
2

If num_steps is set and num_epochs is None, exactly this many batches are produced. drop_remainder has no effect in this case.

print('# batches w/ num_steps set and drop_remainder=True')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=None, num_steps=3, drop_remainder=True))))
# batches w/ num_steps set and drop_remainder=True
3

If both num_epochs and num_steps are set, iteration stops as soon as either limit is reached, i.e. the smaller of the two batch counts is produced.

print('# batches w/ num_epochs and num_steps set')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=6))))
# batches w/ num_epochs and num_steps set
3

If reproducible iteration order is desired, a fixed seed can be used. When seed is None, repeated iteration over the same object may produce batches in a different order.

# Random shuffling.
print(list(client_dataset.shuffle_repeat_batch(batch_size=2, seed=None))[0])
# Fixed shuffling.
print(list(client_dataset.shuffle_repeat_batch(batch_size=2, seed=0))[0])
{'x': array([[78, 29, 75, 78, 33, 16, 66, 47],
       [ 7, 26, 47, 27, 42, 16,  7,  4]], dtype=int32), 'y': array([[29, 75, 78, 33, 16, 66, 47,  3],
       [26, 47, 27, 42, 16,  7,  4, 67]], dtype=int32)}
{'x': array([[16, 14,  4, 67, 47, 16, 48, 67],
       [48, 16, 13, 32, 33, 14, 11, 78]], dtype=int32), 'y': array([[14,  4, 67, 47, 16, 48, 67, 84],
       [16, 13, 32, 33, 14, 11, 78, 76]], dtype=int32)}

Preprocessing

Going from the raw features to features in a batch of examples ready for use in training/evaluation often requires a few steps of preprocessing. Sometimes, we also want to add new preprocessing to an existing FederatedData.

Before going into the details of preprocessing, please note once again that all dataset related operations should be implemented in standard NumPy, not in JAX.

Preprocessing a batch of examples

The easiest way to add an additional preprocessing step is preprocess_batch(), which appends a function that transforms a batch of examples to a FederatedData’s chain of batch-level preprocessing transformations.

Below, we add a new z feature that stores the parity of y for our Shakespeare examples.

# A preprocessing function should return a new dict of examples instead of
# modifying its input.
def parity_feature(examples):
  return {'z': examples['y'] % 2, **examples}

# preprocess_batch returns a new FederatedData object that has one more
# preprocessing step at the very end than the original.
train_fd_z = train_fd.preprocess_batch(parity_feature)
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
next(iter(train_fd_z.get_client(client_id).padded_batch(batch_size=4)))
{'z': array([[1, 1, 0, 1, 1, 1, 1, 0],
        [0, 1, 0, 1, 0, 1, 0, 0],
        [0, 1, 1, 0, 0, 1, 1, 0],
        [1, 1, 1, 0, 1, 1, 1, 0]], dtype=int32),
 'x': array([[ 1, 55, 67, 84, 67, 47,  7, 67],
        [48, 16, 13, 32, 33, 14, 11, 78],
        [76, 78, 33, 19, 16, 66, 47,  3],
        [16, 27, 67, 23, 26, 47,  3, 27]], dtype=int32),
 'y': array([[55, 67, 84, 67, 47,  7, 67, 48],
        [16, 13, 32, 33, 14, 11, 78, 76],
        [78, 33, 19, 16, 66, 47,  3, 16],
        [27, 67, 23, 26, 47,  3, 27, 16]], dtype=int32),
 '__mask__': array([ True,  True,  True,  True])}

Preprocessing at the client level

Sometimes we also need to do some preprocessing for the entire client dataset. For example, in the Shakespeare dataset, the raw features are just text strings, so we need to turn them into a NumPy array of chunks of integer labels just to be able to meaningfully batch them at all.

In most circumstances, adding a preprocessing step at the client level is unnecessary and should be avoided, because the new preprocessing step added by preprocess_client is inserted into the middle of the chain, just before all the batch-level preprocessing steps registered by preprocess_batch. If not done carefully, custom client-level preprocessing can easily break the preprocessing chain.

Nevertheless, here’s an example of client level processing for certain rare cases.

# Load unpreprocessed data.
raw_fd = fedjax.datasets.shakespeare.load_split('train')
raw_fd.get_client(client_id).all_examples()
Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'
{'snippets': array([b'Re-enter POSTHUMUS, and seconds the Britons; they rescue\nCYMBELINE, and exeunt. Then re-enter LUCIUS and IACHIMO,\n                     with IMOGEN\n'],
       dtype=object)}

The actual client level preprocessing in the Shakespeare dataset is a bit involved, so let’s do something simpler: we shall join all the snippets, and then split and pad the integer byte values into bounded length sequences.

# We don't actually need client_id, but `FederatedData` supplies it so that
# different processing based on clients can be done.
def truncate_and_cast(client_id, examples, max_length=10):
  labels = list(b''.join(examples['snippets']))
  num_sequences = (len(labels) + max_length - 1) // max_length
  padded = np.zeros((num_sequences, max_length), dtype=np.int32)
  for i in range(num_sequences):
    chars = labels[i * max_length:(i + 1) * max_length]
    padded[i, :len(chars)] = chars
  return {'snippets': padded}


partial_fd = raw_fd.preprocess_client(truncate_and_cast)
partial_fd.get_client(client_id).all_examples()
{'snippets': array([[ 82, 101,  45, 101, 110, 116, 101, 114,  32,  80],
        [ 79,  83,  84,  72,  85,  77,  85,  83,  44,  32],
        [ 97, 110, 100,  32, 115, 101,  99, 111, 110, 100],
        [115,  32, 116, 104, 101,  32,  66, 114, 105, 116],
        [111, 110, 115,  59,  32, 116, 104, 101, 121,  32],
        [114, 101, 115,  99, 117, 101,  10,  67,  89,  77],
        [ 66,  69,  76,  73,  78,  69,  44,  32,  97, 110],
        [100,  32, 101, 120, 101, 117, 110, 116,  46,  32],
        [ 84, 104, 101, 110,  32, 114, 101,  45, 101, 110],
        [116, 101, 114,  32,  76,  85,  67,  73,  85,  83],
        [ 32,  97, 110, 100,  32,  73,  65,  67,  72,  73],
        [ 77,  79,  44,  10,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32, 119, 105, 116, 104,  32],
        [ 73,  77,  79,  71,  69,  78,  10,   0,   0,   0]], dtype=int32)}

Now, we can add another batch level preprocessor to produce x and y labels.

def snippets_to_xy(examples):
  snippets = examples['snippets']
  return {'x': snippets[:, :-1], 'y': snippets[:, 1:]}


partial_fd.preprocess_batch(snippets_to_xy).get_client(client_id).all_examples()
{'x': array([[ 82, 101,  45, 101, 110, 116, 101, 114,  32],
        [ 79,  83,  84,  72,  85,  77,  85,  83,  44],
        [ 97, 110, 100,  32, 115, 101,  99, 111, 110],
        [115,  32, 116, 104, 101,  32,  66, 114, 105],
        [111, 110, 115,  59,  32, 116, 104, 101, 121],
        [114, 101, 115,  99, 117, 101,  10,  67,  89],
        [ 66,  69,  76,  73,  78,  69,  44,  32,  97],
        [100,  32, 101, 120, 101, 117, 110, 116,  46],
        [ 84, 104, 101, 110,  32, 114, 101,  45, 101],
        [116, 101, 114,  32,  76,  85,  67,  73,  85],
        [ 32,  97, 110, 100,  32,  73,  65,  67,  72],
        [ 77,  79,  44,  10,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32, 119, 105, 116, 104],
        [ 73,  77,  79,  71,  69,  78,  10,   0,   0]], dtype=int32),
 'y': array([[101,  45, 101, 110, 116, 101, 114,  32,  80],
        [ 83,  84,  72,  85,  77,  85,  83,  44,  32],
        [110, 100,  32, 115, 101,  99, 111, 110, 100],
        [ 32, 116, 104, 101,  32,  66, 114, 105, 116],
        [110, 115,  59,  32, 116, 104, 101, 121,  32],
        [101, 115,  99, 117, 101,  10,  67,  89,  77],
        [ 69,  76,  73,  78,  69,  44,  32,  97, 110],
        [ 32, 101, 120, 101, 117, 110, 116,  46,  32],
        [104, 101, 110,  32, 114, 101,  45, 101, 110],
        [101, 114,  32,  76,  85,  67,  73,  85,  83],
        [ 97, 110, 100,  32,  73,  65,  67,  72,  73],
        [ 79,  44,  10,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32, 119, 105, 116, 104,  32],
        [ 77,  79,  71,  69,  78,  10,   0,   0,   0]], dtype=int32)}

In memory federated datasets

In many scenarios, it is desirable to create a small custom federated dataset from a collection of NumPy arrays for quick experimentation. FedJAX provides fedjax.InMemoryFederatedData to create small custom datasets. fedjax.InMemoryFederatedData takes a dictionary of numpy examples keyed by client id and creates a fedjax.FederatedData that is compatible with the rest of the library. We illustrate this below with a simple example.

# Obtain MNIST dataset from tensorflow and convert to numpy format.
import tensorflow_datasets as tfds
(ds_train, ds_test) = tfds.load('mnist',
                                split=['train', 'test'],
                                shuffle_files=True,
                                as_supervised=True,
                                with_info=False)
features, labels = list(ds_train.batch(60000).as_numpy_iterator())[0]
print('features shape', features.shape)
print('labels shape', labels.shape)

# Randomly split dataset into 100 clients and load them to a dictionary.
indices = np.random.randint(100, size=60000)
client_id_to_dataset_mapping = {}
for i in range(100):
  client_id_to_dataset_mapping[i] = {'x': features[indices==i, :, : , :],
                                     'y': labels[indices==i]}

# Create fedjax.InMemoryFederatedData.
iid_mnist_federated_data = fedjax.InMemoryFederatedData(
    client_id_to_dataset_mapping)

print('number of clients in iid_mnist_data',
      iid_mnist_federated_data.num_clients())
features shape (60000, 28, 28, 1)
labels shape (60000,)
number of clients in iid_mnist_data 100
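
The resulting object supports the same access and batching methods as any other fedjax.FederatedData. As a minimal sketch (not part of the original tutorial), we can iterate over the first client and inspect one padded batch.

for client_id, client_dataset in itertools.islice(
    iid_mnist_federated_data.clients(), 1):
  batch = next(iter(client_dataset.padded_batch(batch_size=128)))
  print('client_id =', client_id)
  print('x shape =', batch['x'].shape, ', y shape =', batch['y'].shape)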

Recap

In this tutorial, we have covered the following:

  1. Using fedjax.FederatedData.

  2. Different ways of batching client datasets.

  3. Different ways of processing client datasets.

  4. Creating small custom federated datasets.