Federated datasets
This tutorial introduces datasets in FedJAX and how to work with them. By completing this tutorial, we’ll learn about the best practices for working with datasets.
NOTE: For datasets, we operate over NumPy arrays NOT JAX arrays.
# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
# !pip install tensorflow_datasets
import functools
import itertools
import fedjax
import numpy as np
What are datasets in federated learning?
In the context of federated learning (FL), data is decentralized across clients, with each client having their own local set of examples. In light of this, we refer to two levels of organization for datasets:
Federated dataset: A collection of clients, each with their own local datasets and metadata
Client dataset: The set of local examples for a particular client
We can think of federated data as a mapping from client ids to client datasets and client datasets as a list of examples.
federated_data = {
'client0': ['a', 'b', 'c'],
'client1': ['d', 'e'],
}
Federated datasets
FedJAX comes packaged with multiple federated datasets, and we will look at the Shakespeare dataset as an example. The Shakespeare dataset is created from The Complete Works of Shakespeare, by treating each character in the play as a “client”, and their dialogue lines as the examples.
FedJAX organizes federated datasets as Python modules. load_data() from a dataset module loads all predefined splits together as fedjax.FederatedData objects, the interface for accessing all federated datasets. In the case of the Shakespeare dataset, load_data() returns two splits: train and test.
# We cap max sentence length to 8.
train_fd, test_fd = fedjax.datasets.shakespeare.load_data(sequence_length=8)
Downloading 'https://storage.googleapis.com/gresearch/fedjax/shakespeare/shakespeare_train.sqlite' to '/tmp/.cache/fedjax/shakespeare_train.sqlite'
100%, elapsed: 0s
Downloading 'https://storage.googleapis.com/gresearch/fedjax/shakespeare/shakespeare_test.sqlite' to '/tmp/.cache/fedjax/shakespeare_test.sqlite'
100%, elapsed: 0s
fedjax.FederatedData provides methods for accessing metadata about the federated dataset, like the total number of clients, client ids, and number of examples for each client.
As seen in the output, there are 715 total clients in the Shakespeare dataset. Each client has a unique client ID that can be used to query metadata about that client such as the number of examples that client has.
print('num_clients =', train_fd.num_clients())
# train_fd.client_ids() is a generator of client ids.
# itertools has efficient and convenient functions for working with generators.
for client_id in itertools.islice(train_fd.client_ids(), 3):
  print('client_id =', client_id)
  print('# examples =', train_fd.client_size(client_id))
num_clients = 715
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 5
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 13
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 15
As we notice, the client ids start with a random set of bits. This ensures that one can easily slice a FederatedData to obtain a random subset. While the client ids returned above look sorted, there is no such guarantee in general.
# Slicing is based on the lexicographic order of client ids.
train_fd_0 = train_fd.slice(start=b'0', stop=b'1')
print('num_clients whose id starts with 0 =', train_fd_0.num_clients())
num_clients whose id starts with 0 = 47
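To make the lexicographic slicing concrete, here is a minimal plain-Python sketch. It does not use FedJAX; the client ids are made up, and the half-open range [start, stop) is an assumption about how the slice boundaries are treated.

```python
import bisect

# Hypothetical client ids: a random hex prefix followed by a readable name.
client_ids = sorted([
    b'0a18:LEAR_FLUTE',
    b'00b2:HENRY_ALL',
    b'136c:HENRY_GLOUCESTER',
    b'1db8:SHREW_SEBASTIAN',
    b'f3e1:HAMLET_GHOST',
])


def slice_ids(sorted_ids, start, stop):
  """Returns ids in the half-open lexicographic range [start, stop)."""
  lo = bisect.bisect_left(sorted_ids, start)
  hi = bisect.bisect_left(sorted_ids, stop)
  return sorted_ids[lo:hi]


# Every id whose first character is '0' sorts between b'0' and b'1'.
starts_with_0 = slice_ids(client_ids, b'0', b'1')
print(starts_with_0)
```

Because the prefix bits are random, any such range picks out a roughly random subset of clients.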
Client datasets
We can query the dataset for a client from a federated dataset using their client ID and fedjax.FederatedData.get_client(). The output of fedjax.FederatedData.get_client() is a fedjax.ClientDataset. A fedjax.ClientDataset object:
Stores all the examples for a given client and any preprocessing that should be applied on batches.
Provides methods for batching, shuffling, and iterating over preprocessed examples.
In other words, ClientDataset = examples + preprocessor.
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
client_dataset = train_fd.get_client(client_id)
print(client_dataset)
<fedjax.core.client_datasets.ClientDataset object at 0x7ff730f5afd0>
FedJAX assumes that an individual client dataset is small and can easily fit in memory. This assumption is also reflected in many of FedJAX’s design decisions. The examples in a client dataset can be viewed as a table, where the rows are the individual examples, and the columns are the features (labels are viewed as a feature in this context).
FedJAX uses a column based representation when loading a dataset into memory.
Each column is a NumPy array x of rank at least 1, where x[i, ...] is the value of this feature for the i-th example.
The complete set of examples is a dict-like object, from str feature names, to the corresponding column values.
Traditionally, a row based representation is used for representing the entire dataset, and a column based representation is used for a single batch.
In the context of federated learning, an individual client dataset is small enough to easily fit into memory so the same representation is used for the entire dataset and a batch.
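The difference between the two representations can be sketched with plain NumPy. The toy rows and feature names below are hypothetical and only illustrate the layout.

```python
import numpy as np

# Row based: one dict per example (hypothetical toy data).
rows = [
    {'x': [1, 2], 'y': 0},
    {'x': [3, 4], 'y': 1},
    {'x': [5, 6], 'y': 0},
]

# Column based: one NumPy array per feature, indexed by example along axis 0.
columns = {
    name: np.array([row[name] for row in rows]) for name in rows[0]
}
print(columns['x'].shape)
print(columns['y'].shape)

# A batch is a slice of every column along the leading axis.
batch = {name: column[:2] for name, column in columns.items()}
print(batch)
```

With this layout, the whole client dataset and a single batch are the same kind of object, differing only in the size of the leading dimension.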
The preprocessed examples of a single client can be viewed by calling all_examples().
client_dataset.all_examples()
{'x': array([[ 1, 55, 67, 84, 67, 47, 7, 67],
[48, 16, 13, 32, 33, 14, 11, 78],
[76, 78, 33, 19, 16, 66, 47, 3],
[16, 27, 67, 23, 26, 47, 3, 27],
[16, 7, 4, 67, 16, 51, 48, 68],
[ 7, 26, 47, 27, 42, 16, 7, 4],
[67, 72, 16, 48, 67, 27, 23, 71],
[67, 65, 29, 79, 76, 51, 74, 12],
[75, 54, 74, 19, 16, 66, 47, 3],
[16, 67, 8, 67, 71, 47, 7, 61],
[16, 14, 4, 67, 47, 16, 48, 67],
[84, 67, 47, 7, 67, 48, 16, 12],
[78, 29, 75, 78, 33, 16, 66, 47],
[ 3, 16, 75, 73, 29, 11, 75, 76],
[32, 19, 65, 16, 16, 16, 16, 16],
[16, 16, 16, 16, 16, 16, 16, 16],
[16, 16, 16, 16, 16, 16, 16, 16],
[28, 68, 7, 4, 16, 75, 76, 32],
[30, 74, 54, 65, 0, 0, 0, 0]], dtype=int32),
'y': array([[55, 67, 84, 67, 47, 7, 67, 48],
[16, 13, 32, 33, 14, 11, 78, 76],
[78, 33, 19, 16, 66, 47, 3, 16],
[27, 67, 23, 26, 47, 3, 27, 16],
[ 7, 4, 67, 16, 51, 48, 68, 7],
[26, 47, 27, 42, 16, 7, 4, 67],
[72, 16, 48, 67, 27, 23, 71, 67],
[65, 29, 79, 76, 51, 74, 12, 75],
[54, 74, 19, 16, 66, 47, 3, 16],
[67, 8, 67, 71, 47, 7, 61, 16],
[14, 4, 67, 47, 16, 48, 67, 84],
[67, 47, 7, 67, 48, 16, 12, 78],
[29, 75, 78, 33, 16, 66, 47, 3],
[16, 75, 73, 29, 11, 75, 76, 32],
[19, 65, 16, 16, 16, 16, 16, 16],
[16, 16, 16, 16, 16, 16, 16, 16],
[16, 16, 16, 16, 16, 16, 16, 28],
[68, 7, 4, 16, 75, 76, 32, 30],
[74, 54, 65, 2, 0, 0, 0, 0]], dtype=int32)}
For Shakespeare, we are training a character-level language model, where the task is next character prediction, so the features are:
x is a list of right-shifted sentences, e.g. sentence[:-1]
y is a list of left-shifted sentences, e.g. sentence[1:]
This way, y[i][j] corresponds to the next character after x[i][j].
examples = client_dataset.all_examples()
print('x', examples['x'][0])
print('y', examples['y'][0])
x [ 1 55 67 84 67 47 7 67]
y [55 67 84 67 47 7 67 48]
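The shift can be reproduced with plain NumPy slicing. This is a minimal sketch on a single already-encoded sentence; the integer ids are taken from the first row above plus its next character, and the variable names are illustrative.

```python
import numpy as np

# Integer ids for one already-encoded sentence of 9 characters.
sentence = np.array([1, 55, 67, 84, 67, 47, 7, 67, 48])

x = sentence[:-1]  # Every character except the last.
y = sentence[1:]   # Every character except the first.

# y[j] is the character that follows x[j] in the original sentence.
for j in range(len(x)):
  assert y[j] == sentence[j + 1]
print('x', x)
print('y', y)
```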
However, you probably noticed that x and y are arrays of integers, not text. This is because fedjax.datasets.shakespeare.load_data() does some minimal preprocessing, such as a simple character lookup that maps characters to integer IDs. Later, we’ll go over how this preprocessing was applied and how to add your own custom preprocessing.
For comparison, here’s the unprocessed version of the same client dataset.
# Unlike load_data(), load_split() always loads a single unprocessed split.
raw_fd = fedjax.datasets.shakespeare.load_split('train')
raw_fd.get_client(client_id).all_examples()
Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'
{'snippets': array([b'Re-enter POSTHUMUS, and seconds the Britons; they rescue\nCYMBELINE, and exeunt. Then re-enter LUCIUS and IACHIMO,\n with IMOGEN\n'],
dtype=object)}
Accessing client datasets from fedjax.FederatedData
fedjax.FederatedData.get_client() works well for querying data for a single client when exploring the dataset. However, we often want to query for multiple client datasets at the same time. In most FL algorithms, tens to hundreds of clients participate in each training round. If we are not careful, our code can spend a lot of time loading data, leaving the accelerators (GPU or TPU) idle.
In light of this, fedjax.FederatedData offers more efficient methods for querying multiple client datasets.
We’ll go through each access method from the MOST efficient to the LEAST efficient.
clients() and shuffled_clients()
Fastest sequential read friendly access. As we stated earlier, the client ids are prefixed with random bits. Hence, even sequential reads will go over clients in a pseudo-random order.
# clients() and shuffled_clients() are sequential read friendly.
clients = train_fd.clients()
shuffled_clients = train_fd.shuffled_clients(buffer_size=100, seed=0)
print('clients =', clients)
print('shuffled_clients =', shuffled_clients)
clients = <generator object SQLiteFederatedData.clients at 0x7ff730f3d950>
shuffled_clients = <generator object SQLiteFederatedData.shuffled_clients at 0x7ff7310a8950>
They are generators, so we iterate over them to get the individual client datasets as tuples of (client_id, client_dataset).
clients() returns clients in an unspecified deterministic order. It is useful for going over the entire federated dataset for evaluation.
# We use itertools.islice to select first three clients.
for client_id, client_dataset in itertools.islice(clients, 3):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 53
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 234
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 79
shuffled_clients() provides a stream of infinitely repeating shuffled client datasets, using buffered shuffling. It is suitable for training rounds where a nearly random shuffling is good enough.
print('shuffled_clients()')
for client_id, client_dataset in itertools.islice(shuffled_clients, 3):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))
shuffled_clients()
client_id = b'0a18c2501d441fef:THE_TRAGEDY_OF_KING_LEAR_FLUTE'
# examples = 115
client_id = b'136c5586b7271525:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_GLOUCESTER'
# examples = 3804
client_id = b'0d642a9b4bb27187:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_MESSENGER'
# examples = 775
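For intuition, buffered shuffling over an endlessly repeated sequential stream can be sketched in plain Python. This is not FedJAX's implementation: the function name, the make_stream argument, and the toy client ids are all assumptions for illustration, and the stream is assumed to be non-empty.

```python
import itertools
import random


def buffered_shuffle(make_stream, buffer_size, seed=None):
  """Yields items from an endlessly repeated stream in nearly random order."""
  rng = random.Random(seed)
  # Restart the sequential stream every time it is exhausted.
  repeated = itertools.chain.from_iterable(
      make_stream() for _ in itertools.count())
  buffer = list(itertools.islice(repeated, buffer_size))
  for item in repeated:
    # Emit a random buffered item and store the new item in its slot.
    i = rng.randrange(len(buffer))
    buffer[i], item = item, buffer[i]
    yield item


client_ids = [b'c0', b'c1', b'c2', b'c3', b'c4', b'c5']
stream = buffered_shuffle(lambda: iter(client_ids), buffer_size=3, seed=0)
first_twelve = list(itertools.islice(stream, 12))
print(first_twelve)
```

Only buffer_size items are held in memory at a time, so the underlying reads stay sequential. The price is that the order is only approximately random: an item can never be emitted before the stream has read it.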
get_clients()
Slower than clients() since it requires random read but uses prefetching to hide the latency of random read access. This also returns a generator of tuples of (client_id, client_dataset), in the order of the input client_ids.
client_ids = [
b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN',
b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN',
b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
]
for client_id, client_dataset in train_fd.get_clients(client_ids):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 483
client_id = b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN'
# examples = 5
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
# examples = 19
get_client()
Slowest way of accessing client datasets. We usually reserve this method only for interactive exploration of a small number of clients.
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
print('client_id =', client_id)
print('# examples =', len(train_fd.get_client(client_id)))
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 483
Batching client datasets
Next we’ll go over different methods of iterating over a fedjax.ClientDataset as batched examples. All the following methods can be invoked in 2 ways:
Using a hyperparams object: This is the recommended way in library code. batch_fn(hparams).
Using keyword arguments: The keyword arguments are used to construct a new hyperparams object, or override an existing one. batch_fn(batch_size=2) or batch_fn(hparams, batch_size=2) to override batch_size.
For the most part, we’ll use method 2 for this colab, but method 1 is more suitable for writing library code.
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
client_dataset = train_fd.get_client(client_id)
batch() for illustrations
Produces preprocessed batches in a fixed sequential order.
The final batch may contain fewer than batch_size examples. If used directly, that may result in a large number of JIT recompilations. Therefore we should use padded_batch() or shuffle_repeat_batch() instead in most scenarios.
Note here we are not talking about padding within an example, often done in processing sequence examples (e.g. the 0 labels below), but rather padding with “empty” examples in a batch.
batches = list(client_dataset.batch(batch_size=8))
batches[-1]
{'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
[28, 68, 7, 4, 16, 75, 76, 32],
[30, 74, 54, 65, 0, 0, 0, 0]], dtype=int32),
'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
[68, 7, 4, 16, 75, 76, 32, 30],
[74, 54, 65, 2, 0, 0, 0, 0]], dtype=int32)}
padded_batch() for evaluation
Produces preprocessed padded batches in a fixed sequential order for evaluation.
When the number of examples in the dataset is not a multiple of batch_size, the final batch may be smaller than batch_size. This may lead to a large number of JIT recompilations. This can be circumvented by padding the final batch to a small number of fixed sizes controlled by num_batch_size_buckets.
# use list() to consume generator and store in memory.
padded_batches = list(client_dataset.padded_batch(batch_size=8, num_batch_size_buckets=3))
print('# batches =', len(padded_batches))
padded_batches[0]
# batches = 3
{'x': array([[ 1, 55, 67, 84, 67, 47, 7, 67],
[48, 16, 13, 32, 33, 14, 11, 78],
[76, 78, 33, 19, 16, 66, 47, 3],
[16, 27, 67, 23, 26, 47, 3, 27],
[16, 7, 4, 67, 16, 51, 48, 68],
[ 7, 26, 47, 27, 42, 16, 7, 4],
[67, 72, 16, 48, 67, 27, 23, 71],
[67, 65, 29, 79, 76, 51, 74, 12]], dtype=int32),
'y': array([[55, 67, 84, 67, 47, 7, 67, 48],
[16, 13, 32, 33, 14, 11, 78, 76],
[78, 33, 19, 16, 66, 47, 3, 16],
[27, 67, 23, 26, 47, 3, 27, 16],
[ 7, 4, 67, 16, 51, 48, 68, 7],
[26, 47, 27, 42, 16, 7, 4, 67],
[72, 16, 48, 67, 27, 23, 71, 67],
[65, 29, 79, 76, 51, 74, 12, 75]], dtype=int32),
'__mask__': array([ True, True, True, True, True, True, True, True])}
All batches contain an extra bool feature keyed by '__mask__'. batch['__mask__'][i] tells us whether the i-th example in this batch is an actual example (batch['__mask__'][i] == True), or a padding example (batch['__mask__'][i] == False).
We repeatedly halve the batch size up to num_batch_size_buckets - 1 times, until we find the smallest one that is also >= the size of the final batch. Therefore if batch_size < 2^num_batch_size_buckets, fewer bucket sizes will be actually used. We can see this in the final batch, which holds 3 actual examples padded to 4 rows even though the original batch size was 8.
padded_batches[-1]
{'__mask__': array([ True, True, True, False]),
'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
[28, 68, 7, 4, 16, 75, 76, 32],
[30, 74, 54, 65, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32),
'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
[68, 7, 4, 16, 75, 76, 32, 30],
[74, 54, 65, 2, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)}
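The halving rule and the mask can be sketched in plain NumPy. This is an illustration of the rule as described in this tutorial, not FedJAX's own code; pick_padded_size and pad_batch are hypothetical helper names.

```python
import numpy as np


def pick_padded_size(final_size, batch_size, num_batch_size_buckets):
  """Smallest allowed bucket size that still fits final_size examples."""
  size = batch_size
  for _ in range(num_batch_size_buckets - 1):
    half = size // 2
    if half < final_size:
      break  # Halving again would no longer fit the final batch.
    size = half
  return size


def pad_batch(batch, padded_size):
  """Pads every feature with zero rows and adds a '__mask__' feature."""
  num_real = len(next(iter(batch.values())))
  padded = {}
  for name, column in batch.items():
    pad_shape = (padded_size - num_real,) + column.shape[1:]
    padded[name] = np.concatenate([column, np.zeros(pad_shape, column.dtype)])
  padded['__mask__'] = np.arange(padded_size) < num_real
  return padded


# A final batch of 3 examples, with batch_size=8 and 3 buckets (8, 4, 2).
final_batch = {'x': np.arange(24, dtype=np.int32).reshape(3, 8)}
size = pick_padded_size(final_size=3, batch_size=8, num_batch_size_buckets=3)
padded = pad_batch(final_batch, size)
print(size)
print(padded['__mask__'])
```

Because only a handful of distinct batch shapes can occur, a JIT-compiled evaluation function is compiled at most once per bucket size.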
shuffle_repeat_batch() for training
Produces preprocessed batches in a shuffled and repeated order for training.
Shuffling is done without replacement, therefore for a dataset of N examples, the first ceil(N/batch_size) batches are guaranteed to cover the entire dataset. Unlike batch() or padded_batch(), batches from shuffle_repeat_batch() always contain exactly batch_size examples. Also unlike TensorFlow, that holds even when drop_remainder=False.
By default the iteration stops after the first epoch.
print('# batches')
len(list(client_dataset.shuffle_repeat_batch(batch_size=8)))
# batches
3
The number of batches produced from the iteration can be controlled by the (num_epochs, num_steps, drop_remainder) combination:
If both num_epochs and num_steps are None, the shuffle-repeat process continues forever.
infinite_bs = client_dataset.shuffle_repeat_batch(
batch_size=8, num_epochs=None, num_steps=None)
for i, b in zip(range(6), infinite_bs):
  print(i)
0
1
2
3
4
5
If num_epochs is set and num_steps is None, only as many batches as are needed to make this many passes over the dataset are produced. Further,
If drop_remainder is False (the default), the final batch is filled with additionally sampled examples to contain batch_size examples.
If drop_remainder is True, the final batch is dropped if it contains fewer than batch_size examples. This may result in examples being skipped when num_epochs=1.
print('# batches w/ drop_remainder=False')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None))))
print('# batches w/ drop_remainder=True')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None, drop_remainder=True))))
# batches w/ drop_remainder=False
3
# batches w/ drop_remainder=True
2
If num_steps is set and num_epochs is None, exactly this many batches are produced. drop_remainder has no effect in this case.
print('# batches w/ num_steps set and drop_remainder=True')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=None, num_steps=3, drop_remainder=True))))
# batches w/ num_steps set and drop_remainder=True
3
If both num_epochs and num_steps are set, iteration stops at whichever limit is reached first, so the smaller of the two batch counts is produced.
print('# batches w/ num_epochs and num_steps set')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=6))))
# batches w/ num_epochs and num_steps set
3
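The four cases can be summarized as a small counting function. This is a sketch of the rules as stated in this tutorial, not code from FedJAX; expected_num_batches is a hypothetical helper, and it is checked here against the 19-example client used in the cells of this section.

```python
import math


def expected_num_batches(num_examples, batch_size, num_epochs=1,
                         num_steps=None, drop_remainder=False):
  """Batch count implied by the rules in this section; None means infinite."""
  if num_epochs is None:
    # Either exactly num_steps batches, or an endless stream if that is None.
    return num_steps
  total = num_examples * num_epochs
  if drop_remainder:
    by_epochs = total // batch_size
  else:
    by_epochs = math.ceil(total / batch_size)
  return by_epochs if num_steps is None else min(by_epochs, num_steps)


n = 19  # Number of examples in the client dataset used above.
print(expected_num_batches(n, 8))
print(expected_num_batches(n, 8, drop_remainder=True))
print(expected_num_batches(n, 8, num_epochs=None, num_steps=3))
print(expected_num_batches(n, 8, num_epochs=1, num_steps=6))
print(expected_num_batches(n, 8, num_epochs=None, num_steps=None))
```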
If reproducible iteration order is desired, a fixed seed can be used. When seed is None, repeated iteration over the same object may produce batches in a different order.
# Random shuffling.
print(list(client_dataset.shuffle_repeat_batch(batch_size=2, seed=None))[0])
# Fixed shuffling.
print(list(client_dataset.shuffle_repeat_batch(batch_size=2, seed=0))[0])
{'x': array([[78, 29, 75, 78, 33, 16, 66, 47],
[ 7, 26, 47, 27, 42, 16, 7, 4]], dtype=int32), 'y': array([[29, 75, 78, 33, 16, 66, 47, 3],
[26, 47, 27, 42, 16, 7, 4, 67]], dtype=int32)}
{'x': array([[16, 14, 4, 67, 47, 16, 48, 67],
[48, 16, 13, 32, 33, 14, 11, 78]], dtype=int32), 'y': array([[14, 4, 67, 47, 16, 48, 67, 84],
[16, 13, 32, 33, 14, 11, 78, 76]], dtype=int32)}
Preprocessing
Going from the raw features to features in a batch of examples ready for use in training/evaluation often requires a few steps of preprocessing. Sometimes, we also want to add new preprocessing to an existing FederatedData.
Before going into the details of preprocessing, please note once again that all dataset related operations should be implemented in standard NumPy, not in JAX.
Preprocessing a batch of examples
The easiest way to add an additional preprocessing step is by appending a function that transforms a batch of examples to a FederatedData’s list of preprocessing transformations on batches.
Below, we add a new z feature that stores the parity of y for our Shakespeare examples.
# A preprocessing function should return a new dict of examples instead of
# modifying its input.
def parity_feature(examples):
  return {'z': examples['y'] % 2, **examples}
# preprocess_batch returns a new FederatedData object that has one more
# preprocessing step at the very end than the original.
train_fd_z = train_fd.preprocess_batch(parity_feature)
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
next(iter(train_fd_z.get_client(client_id).padded_batch(batch_size=4)))
{'z': array([[1, 1, 0, 1, 1, 1, 1, 0],
[0, 1, 0, 1, 0, 1, 0, 0],
[0, 1, 1, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1, 0]], dtype=int32),
'x': array([[ 1, 55, 67, 84, 67, 47, 7, 67],
[48, 16, 13, 32, 33, 14, 11, 78],
[76, 78, 33, 19, 16, 66, 47, 3],
[16, 27, 67, 23, 26, 47, 3, 27]], dtype=int32),
'y': array([[55, 67, 84, 67, 47, 7, 67, 48],
[16, 13, 32, 33, 14, 11, 78, 76],
[78, 33, 19, 16, 66, 47, 3, 16],
[27, 67, 23, 26, 47, 3, 27, 16]], dtype=int32),
'__mask__': array([ True, True, True, True])}
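The idea of a list of batch transformations, each returning a new dict, can be sketched without FedJAX. The apply_chain helper and the length_feature step below are hypothetical and only illustrate ordering and non-mutation; they are not part of the library.

```python
import numpy as np


def parity_feature(examples):
  return {'z': examples['y'] % 2, **examples}


def length_feature(examples):
  # Hypothetical second step: count the non-zero (non-padding) labels in x.
  return {'length': np.count_nonzero(examples['x'], axis=1), **examples}


def apply_chain(steps, examples):
  """Applies batch preprocessing steps in the order they were appended."""
  for step in steps:
    examples = step(examples)
  return examples


steps = [parity_feature]
# Appending builds a new list, leaving the original chain untouched.
steps_with_length = steps + [length_feature]

batch = {
    'x': np.array([[30, 74, 54, 65, 0, 0, 0, 0]], dtype=np.int32),
    'y': np.array([[74, 54, 65, 2, 0, 0, 0, 0]], dtype=np.int32),
}
out = apply_chain(steps_with_length, batch)
print(sorted(out))
print(sorted(batch))
```

Since every step returns a new dict, the input batch keeps only its original features, which is why the same raw examples can safely be shared between the original and the extended chain.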
Preprocessing at the client level
Sometimes we also need to do some preprocessing for the entire client dataset. For example, in the Shakespeare dataset, the raw features are just text strings, so we need to turn them into a NumPy array of chunks of integer labels just to be able to meaningfully batch them at all.
In most circumstances, adding a preprocessing step at the client level is unnecessary, and should be avoided, because the new preprocessing step added by preprocess_client is inserted into the middle of a chain of steps, just before all the batch level preprocessing steps registered by preprocess_batch. If not done carefully, a custom client level preprocessing can easily break the preprocessing chain.
Nevertheless, here’s an example of client level processing for certain rare cases.
# Load unpreprocessed data.
raw_fd = fedjax.datasets.shakespeare.load_split('train')
raw_fd.get_client(client_id).all_examples()
Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'
{'snippets': array([b'Re-enter POSTHUMUS, and seconds the Britons; they rescue\nCYMBELINE, and exeunt. Then re-enter LUCIUS and IACHIMO,\n with IMOGEN\n'],
dtype=object)}
The actual client level preprocessing in the Shakespeare dataset is a bit involved, so let’s do something simpler: we shall join all the snippets, and then split and pad the integer byte values into bounded length sequences.
# We don't actually need client_id, but `FederatedData` supplies it so that
# different processing based on clients can be done.
def truncate_and_cast(client_id, examples, max_length=10):
  labels = list(b''.join(examples['snippets']))
  num_sequences = (len(labels) + max_length - 1) // max_length
  padded = np.zeros((num_sequences, max_length), dtype=np.int32)
  for i in range(num_sequences):
    chars = labels[i * max_length:(i + 1) * max_length]
    padded[i, :len(chars)] = chars
  return {'snippets': padded}
partial_fd = raw_fd.preprocess_client(truncate_and_cast)
partial_fd.get_client(client_id).all_examples()
{'snippets': array([[ 82, 101, 45, 101, 110, 116, 101, 114, 32, 80],
[ 79, 83, 84, 72, 85, 77, 85, 83, 44, 32],
[ 97, 110, 100, 32, 115, 101, 99, 111, 110, 100],
[115, 32, 116, 104, 101, 32, 66, 114, 105, 116],
[111, 110, 115, 59, 32, 116, 104, 101, 121, 32],
[114, 101, 115, 99, 117, 101, 10, 67, 89, 77],
[ 66, 69, 76, 73, 78, 69, 44, 32, 97, 110],
[100, 32, 101, 120, 101, 117, 110, 116, 46, 32],
[ 84, 104, 101, 110, 32, 114, 101, 45, 101, 110],
[116, 101, 114, 32, 76, 85, 67, 73, 85, 83],
[ 32, 97, 110, 100, 32, 73, 65, 67, 72, 73],
[ 77, 79, 44, 10, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 119, 105, 116, 104, 32],
[ 73, 77, 79, 71, 69, 78, 10, 0, 0, 0]], dtype=int32)}
Now, we can add another batch level preprocessor to produce x and y labels.
def snippets_to_xy(examples):
  snippets = examples['snippets']
  return {'x': snippets[:, :-1], 'y': snippets[:, 1:]}
partial_fd.preprocess_batch(snippets_to_xy).get_client(client_id).all_examples()
{'x': array([[ 82, 101, 45, 101, 110, 116, 101, 114, 32],
[ 79, 83, 84, 72, 85, 77, 85, 83, 44],
[ 97, 110, 100, 32, 115, 101, 99, 111, 110],
[115, 32, 116, 104, 101, 32, 66, 114, 105],
[111, 110, 115, 59, 32, 116, 104, 101, 121],
[114, 101, 115, 99, 117, 101, 10, 67, 89],
[ 66, 69, 76, 73, 78, 69, 44, 32, 97],
[100, 32, 101, 120, 101, 117, 110, 116, 46],
[ 84, 104, 101, 110, 32, 114, 101, 45, 101],
[116, 101, 114, 32, 76, 85, 67, 73, 85],
[ 32, 97, 110, 100, 32, 73, 65, 67, 72],
[ 77, 79, 44, 10, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 119, 105, 116, 104],
[ 73, 77, 79, 71, 69, 78, 10, 0, 0]], dtype=int32),
'y': array([[101, 45, 101, 110, 116, 101, 114, 32, 80],
[ 83, 84, 72, 85, 77, 85, 83, 44, 32],
[110, 100, 32, 115, 101, 99, 111, 110, 100],
[ 32, 116, 104, 101, 32, 66, 114, 105, 116],
[110, 115, 59, 32, 116, 104, 101, 121, 32],
[101, 115, 99, 117, 101, 10, 67, 89, 77],
[ 69, 76, 73, 78, 69, 44, 32, 97, 110],
[ 32, 101, 120, 101, 117, 110, 116, 46, 32],
[104, 101, 110, 32, 114, 101, 45, 101, 110],
[101, 114, 32, 76, 85, 67, 73, 85, 83],
[ 97, 110, 100, 32, 73, 65, 67, 72, 73],
[ 79, 44, 10, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 119, 105, 116, 104, 32],
[ 77, 79, 71, 69, 78, 10, 0, 0, 0]], dtype=int32)}
In memory federated datasets
In many scenarios, it is desirable to create a small custom federated dataset from a collection of NumPy arrays for quick experimentation. FedJAX provides fedjax.InMemoryFederatedData to create small custom datasets. fedjax.InMemoryFederatedData takes a dictionary of NumPy examples keyed by client id and creates a fedjax.FederatedData that is compatible with the rest of the library. We illustrate this below with a simple example.
# Obtain MNIST dataset from tensorflow and convert to numpy format.
import tensorflow_datasets as tfds
(ds_train, ds_test) = tfds.load('mnist',
                                split=['train', 'test'],
                                shuffle_files=True,
                                as_supervised=True,
                                with_info=False)
features, labels = list(ds_train.batch(60000).as_numpy_iterator())[0]
print('features shape', features.shape)
print('labels shape', labels.shape)
# Randomly split dataset into 100 clients and load them to a dictionary.
indices = np.random.randint(100, size=60000)
client_id_to_dataset_mapping = {}
for i in range(100):
  client_id_to_dataset_mapping[i] = {'x': features[indices==i, :, :, :],
                                     'y': labels[indices==i]}
# Create fedjax.InMemoryFederatedData.
iid_mnist_federated_data = fedjax.InMemoryFederatedData(
    client_id_to_dataset_mapping)
print('number of clients in iid_mnist_data',
      iid_mnist_federated_data.num_clients())
features shape (60000, 28, 28, 1)
labels shape (60000,)
number of clients in iid_mnist_data 100
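The partitioning step can be tried without downloading MNIST. The sketch below uses synthetic arrays with MNIST-like shapes and a seeded NumPy generator so the split is reproducible; the sizes and the seed are arbitrary assumptions. It builds only the client-to-examples dictionary and does not call FedJAX.

```python
import numpy as np

# Synthetic stand-in for MNIST so the sketch runs without any download.
rng = np.random.default_rng(0)
features = rng.random((600, 28, 28, 1), dtype=np.float32)
labels = rng.integers(10, size=600)

# Assign each example to one of 10 clients uniformly at random.
num_clients = 10
assignment = rng.integers(num_clients, size=len(labels))
client_id_to_dataset_mapping = {
    i: {'x': features[assignment == i], 'y': labels[assignment == i]}
    for i in range(num_clients)
}

# Every example lands in exactly one client.
sizes = {i: len(d['y']) for i, d in client_id_to_dataset_mapping.items()}
print(sizes)
print(sum(sizes.values()))
```

A dictionary of this form, mapping each client id to a dict of equally long NumPy columns, is what the cell above passes to fedjax.InMemoryFederatedData.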
Recap
In this tutorial, we have covered the following:
Using fedjax.FederatedData.
Different ways of batching client datasets.
Different ways of processing client datasets.
Creating small custom federated datasets.