FedJAX documentation
FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX prioritizes ease-of-use and is intended to be useful for anyone with knowledge of NumPy.
Federated datasets
This tutorial introduces datasets in FedJAX and how to work with them. By completing this tutorial, we’ll learn about the best practices for working with datasets.
NOTE: For datasets, we operate over NumPy arrays NOT JAX arrays.
# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
# !pip install tensorflow_datasets
import functools
import itertools
import fedjax
import numpy as np
What are datasets in federated learning?
In the context of federated learning (FL), data is decentralized across clients, with each client having their own local set of examples. In light of this, we refer to two levels of organization for datasets:
Federated dataset: A collection of clients, each with their own local datasets and metadata
Client dataset: The set of local examples for a particular client
We can think of federated data as a mapping from client ids to client datasets and client datasets as a list of examples.
federated_data = {
'client0': ['a', 'b', 'c'],
'client1': ['d', 'e'],
}
Federated datasets
FedJAX comes packaged with multiple federated datasets, and we will look at the Shakespeare dataset as an example. The Shakespeare dataset is created from The Complete Works of William Shakespeare by treating each character in the plays as a “client” and their dialogue lines as the examples.
FedJAX organizes federated datasets as Python modules. Calling load_data() from a dataset module loads all predefined splits together as fedjax.FederatedData objects, the interface for accessing all federated datasets. In the case of the Shakespeare dataset, load_data() returns two splits: train and test.
# We cap max sentence length to 8.
train_fd, test_fd = fedjax.datasets.shakespeare.load_data(sequence_length=8)
Downloading 'https://storage.googleapis.com/gresearch/fedjax/shakespeare/shakespeare_train.sqlite' to '/tmp/.cache/fedjax/shakespeare_train.sqlite'
100%, elapsed: 0s
Downloading 'https://storage.googleapis.com/gresearch/fedjax/shakespeare/shakespeare_test.sqlite' to '/tmp/.cache/fedjax/shakespeare_test.sqlite'
100%, elapsed: 0s
fedjax.FederatedData provides methods for accessing metadata about the federated dataset, like the total number of clients, the client ids, and the number of examples for each client.
As seen in the output, there are 715 total clients in the Shakespeare dataset. Each client has a unique client ID that can be used to query metadata about that client such as the number of examples that client has.
print('num_clients =', train_fd.num_clients())
# train_fd.client_ids() is a generator of client ids.
# itertools has efficient and convenient functions for working with generators.
for client_id in itertools.islice(train_fd.client_ids(), 3):
print('client_id =', client_id)
print('# examples =', train_fd.client_size(client_id))
num_clients = 715
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 5
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 13
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 15
Notice that the client ids start with a random set of bits. This ensures that one can easily slice a FederatedData to obtain a random subset of clients. While the client ids returned above look sorted, there is no such guarantee in general.
# Slicing is based on the lexicographic order of client ids.
train_fd_0 = train_fd.slice(start=b'0', stop=b'1')
print('num_clients whose id starts with 0 =', train_fd_0.num_clients())
num_clients whose id starts with 0 = 47
Client datasets
We can query the dataset for a client from a federated dataset using their client ID and fedjax.FederatedData.get_client(). The output of fedjax.FederatedData.get_client() is a fedjax.ClientDataset. A fedjax.ClientDataset object:
Stores all the examples for a given client and any preprocessing that should be applied on batches.
Provides methods for batching, shuffling, and iterating over preprocessed examples.
In other words, ClientDataset = examples + preprocessor.
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
client_dataset = train_fd.get_client(client_id)
print(client_dataset)
<fedjax.core.client_datasets.ClientDataset object at 0x7ff730f5afd0>
FedJAX assumes that an individual client dataset is small and can easily fit in memory. This assumption is also reflected in many of FedJAX’s design decisions. The examples in a client dataset can be viewed as a table, where the rows are the individual examples, and the columns are the features (labels are viewed as a feature in this context).
FedJAX uses a column based representation when loading a dataset into memory.
Each column is a NumPy array x of rank at least 1, where x[i, ...] is the value of this feature for the i-th example.
The complete set of examples is a dict-like object, from str feature names to the corresponding column values.
Traditionally, a row based representation is used for representing the entire dataset, and a column based representation is used for a single batch.
In the context of federated learning, an individual client dataset is small enough to easily fit into memory so the same representation is used for the entire dataset and a batch.
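As a minimal illustration (with made-up numbers), here is the same set of three examples in a row based layout versus the column based layout described above:
# Hypothetical toy examples: row based is a list of per-example dicts,
# column based is a dict of NumPy arrays, one array per feature.
row_based = [{'x': [1, 2], 'y': 0}, {'x': [3, 4], 'y': 1}, {'x': [5, 6], 'y': 0}]
column_based = {
    'x': np.array([[1, 2], [3, 4], [5, 6]]),  # column_based['x'][i] is example i's feature
    'y': np.array([0, 1, 0]),
}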
The preprocessed examples of a single client can be viewed by calling all_examples().
client_dataset.all_examples()
{'x': array([[ 1, 55, 67, 84, 67, 47, 7, 67],
[48, 16, 13, 32, 33, 14, 11, 78],
[76, 78, 33, 19, 16, 66, 47, 3],
[16, 27, 67, 23, 26, 47, 3, 27],
[16, 7, 4, 67, 16, 51, 48, 68],
[ 7, 26, 47, 27, 42, 16, 7, 4],
[67, 72, 16, 48, 67, 27, 23, 71],
[67, 65, 29, 79, 76, 51, 74, 12],
[75, 54, 74, 19, 16, 66, 47, 3],
[16, 67, 8, 67, 71, 47, 7, 61],
[16, 14, 4, 67, 47, 16, 48, 67],
[84, 67, 47, 7, 67, 48, 16, 12],
[78, 29, 75, 78, 33, 16, 66, 47],
[ 3, 16, 75, 73, 29, 11, 75, 76],
[32, 19, 65, 16, 16, 16, 16, 16],
[16, 16, 16, 16, 16, 16, 16, 16],
[16, 16, 16, 16, 16, 16, 16, 16],
[28, 68, 7, 4, 16, 75, 76, 32],
[30, 74, 54, 65, 0, 0, 0, 0]], dtype=int32),
'y': array([[55, 67, 84, 67, 47, 7, 67, 48],
[16, 13, 32, 33, 14, 11, 78, 76],
[78, 33, 19, 16, 66, 47, 3, 16],
[27, 67, 23, 26, 47, 3, 27, 16],
[ 7, 4, 67, 16, 51, 48, 68, 7],
[26, 47, 27, 42, 16, 7, 4, 67],
[72, 16, 48, 67, 27, 23, 71, 67],
[65, 29, 79, 76, 51, 74, 12, 75],
[54, 74, 19, 16, 66, 47, 3, 16],
[67, 8, 67, 71, 47, 7, 61, 16],
[14, 4, 67, 47, 16, 48, 67, 84],
[67, 47, 7, 67, 48, 16, 12, 78],
[29, 75, 78, 33, 16, 66, 47, 3],
[16, 75, 73, 29, 11, 75, 76, 32],
[19, 65, 16, 16, 16, 16, 16, 16],
[16, 16, 16, 16, 16, 16, 16, 16],
[16, 16, 16, 16, 16, 16, 16, 28],
[68, 7, 4, 16, 75, 76, 32, 30],
[74, 54, 65, 2, 0, 0, 0, 0]], dtype=int32)}
For Shakespeare, we are training a character-level language model, where the task is next character prediction, so the features are:
x is a list of right-shifted sentences, e.g. sentence[:-1]
y is a list of left-shifted sentences, e.g. sentence[1:]
This way, y[i][j] corresponds to the next character after x[i][j].
examples = client_dataset.all_examples()
print('x', examples['x'][0])
print('y', examples['y'][0])
x [ 1 55 67 84 67 47 7 67]
y [55 67 84 67 47 7 67 48]
However, you probably noticed that x and y are arrays of integers, not text. This is because fedjax.datasets.shakespeare.load_data() does some minimal preprocessing, such as a simple character lookup that maps characters to integer IDs. Later, we’ll go over how this preprocessing was applied and how to add your own custom preprocessing.
For comparison, here’s the unprocessed version of the same client dataset.
# Unlike load_data(), load_split() always loads a single unprocessed split.
raw_fd = fedjax.datasets.shakespeare.load_split('train')
raw_fd.get_client(client_id).all_examples()
Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'
{'snippets': array([b'Re-enter POSTHUMUS, and seconds the Britons; they rescue\nCYMBELINE, and exeunt. Then re-enter LUCIUS and IACHIMO,\n with IMOGEN\n'],
dtype=object)}
Accessing client datasets from fedjax.FederatedData
fedjax.FederatedData.get_client() works well for querying data for a single client when exploring the dataset. However, we often want to query multiple client datasets at the same time. In most FL algorithms, tens to hundreds of clients participate in each training round. If we are not careful, our code can spend a lot of time loading data, leaving the accelerators (GPU or TPU) idle.
In light of this, fedjax.FederatedData offers more efficient methods for querying multiple client datasets.
We’ll go through each access method from the MOST efficient to the LEAST efficient.
clients() and shuffled_clients()
Fastest, sequential-read-friendly access. As we stated earlier, the client ids are prefixed with random bits. Hence, even sequential reads will go over clients in a pseudo-random order.
# clients() and shuffled_clients() are sequential read friendly.
clients = train_fd.clients()
shuffled_clients = train_fd.shuffled_clients(buffer_size=100, seed=0)
print('clients =', clients)
print('shuffled_clients =', shuffled_clients)
clients = <generator object SQLiteFederatedData.clients at 0x7ff730f3d950>
shuffled_clients = <generator object SQLiteFederatedData.shuffled_clients at 0x7ff7310a8950>
They are generators, so we iterate over them to get the individual client datasets as tuples of (client_id, client_dataset).
clients() returns clients in an unspecified deterministic order. It is useful for going over the entire federated dataset for evaluation.
# We use itertools.islice to select first three clients.
for client_id, client_dataset in itertools.islice(clients, 3):
print('client_id =', client_id)
print('# examples =', len(client_dataset))
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 53
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 234
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 79
shuffled_clients() provides a stream of infinitely repeating shuffled client datasets, using buffered shuffling. It is suitable for training rounds where a nearly random shuffling is good enough.
print('shuffled_clients()')
for client_id, client_dataset in itertools.islice(shuffled_clients, 3):
print('client_id =', client_id)
print('# examples =', len(client_dataset))
shuffled_clients()
client_id = b'0a18c2501d441fef:THE_TRAGEDY_OF_KING_LEAR_FLUTE'
# examples = 115
client_id = b'136c5586b7271525:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_GLOUCESTER'
# examples = 3804
client_id = b'0d642a9b4bb27187:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_MESSENGER'
# examples = 775
get_clients()
Slower than clients() since it requires random reads, but it uses prefetching to hide the latency of random read access. This also returns a generator of tuples of (client_id, client_dataset), in the order of the input client_ids.
client_ids = [
b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN',
b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN',
b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
]
for client_id, client_dataset in train_fd.get_clients(client_ids):
print('client_id =', client_id)
print('# examples =', len(client_dataset))
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 483
client_id = b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN'
# examples = 5
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
# examples = 19
get_client()
Slowest way of accessing client datasets. We usually reserve this method only for interactive exploration of a small number of clients.
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
print('client_id =', client_id)
print('# examples =', len(train_fd.get_client(client_id)))
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 483
Batching client datasets
Next we’ll go over different methods of iterating over a fedjax.ClientDataset as batched examples. All the following methods can be invoked in 2 ways:
Using a hyperparams object: This is the recommended way in library code, e.g. batch_fn(hparams).
Using keyword arguments: The keyword arguments are used to construct a new hyperparams object, or override an existing one, e.g. batch_fn(batch_size=2), or batch_fn(hparams, batch_size=2) to override batch_size.
For the most part, we’ll use method 2 for this colab, but method 1 is more suitable for writing library code.
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
client_dataset = train_fd.get_client(client_id)
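As a quick sketch of method 1, we can build a hyperparams object (here fedjax.ShuffleRepeatBatchHParams, which also appears later in this guide) and pass it in directly; keyword arguments can then override individual fields of that object:
# Method 1: construct a hyperparams object once and reuse it across calls.
hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=8)
print('# batches =', len(list(client_dataset.shuffle_repeat_batch(hparams))))
# Keyword arguments override fields of an existing hyperparams object.
print('# batches =', len(list(client_dataset.shuffle_repeat_batch(hparams, batch_size=4))))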
batch() for illustrations
Produces preprocessed batches in a fixed sequential order. The final batch may contain fewer than batch_size examples. If used directly, that may result in a large number of JIT recompilations. Therefore we should use padded_batch() or shuffle_repeat_batch() instead in most scenarios.
Note that here we are not talking about padding within an example, often done in processing sequence examples (e.g. the 0 labels below), but rather padding with “empty” examples in a batch.
batches = list(client_dataset.batch(batch_size=8))
batches[-1]
{'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
[28, 68, 7, 4, 16, 75, 76, 32],
[30, 74, 54, 65, 0, 0, 0, 0]], dtype=int32),
'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
[68, 7, 4, 16, 75, 76, 32, 30],
[74, 54, 65, 2, 0, 0, 0, 0]], dtype=int32)}
padded_batch() for evaluation
Produces preprocessed padded batches in a fixed sequential order for evaluation. When the number of examples in the dataset is not a multiple of batch_size, the final batch may be smaller than batch_size. This may lead to a large number of JIT recompilations. This can be circumvented by padding the final batch to a small number of fixed sizes controlled by num_batch_size_buckets.
# use list() to consume generator and store in memory.
padded_batches = list(client_dataset.padded_batch(batch_size=8, num_batch_size_buckets=3))
print('# batches =', len(padded_batches))
padded_batches[0]
# batches = 3
{'x': array([[ 1, 55, 67, 84, 67, 47, 7, 67],
[48, 16, 13, 32, 33, 14, 11, 78],
[76, 78, 33, 19, 16, 66, 47, 3],
[16, 27, 67, 23, 26, 47, 3, 27],
[16, 7, 4, 67, 16, 51, 48, 68],
[ 7, 26, 47, 27, 42, 16, 7, 4],
[67, 72, 16, 48, 67, 27, 23, 71],
[67, 65, 29, 79, 76, 51, 74, 12]], dtype=int32),
'y': array([[55, 67, 84, 67, 47, 7, 67, 48],
[16, 13, 32, 33, 14, 11, 78, 76],
[78, 33, 19, 16, 66, 47, 3, 16],
[27, 67, 23, 26, 47, 3, 27, 16],
[ 7, 4, 67, 16, 51, 48, 68, 7],
[26, 47, 27, 42, 16, 7, 4, 67],
[72, 16, 48, 67, 27, 23, 71, 67],
[65, 29, 79, 76, 51, 74, 12, 75]], dtype=int32),
'__mask__': array([ True, True, True, True, True, True, True, True])}
All batches contain an extra bool feature keyed by '__mask__'. batch['__mask__'][i] tells us whether the i-th example in this batch is an actual example (batch['__mask__'][i] == True) or a padding example (batch['__mask__'][i] == False).
We repeatedly halve the batch size up to num_batch_size_buckets - 1 times, until we find the smallest one that is also >= the size of the final batch. Therefore if batch_size < 2^num_batch_size_buckets, fewer bucket sizes will actually be used. This can be seen in the final batch below, which has only 4 examples even though the original batch size was 8.
padded_batches[-1]
{'__mask__': array([ True, True, True, False]),
'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
[28, 68, 7, 4, 16, 75, 76, 32],
[30, 74, 54, 65, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32),
'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
[68, 7, 4, 16, 75, 76, 32, 30],
[74, 54, 65, 2, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)}
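As a small sketch of how the mask is typically used, we can restrict computations to the real examples only; for instance, summing the mask over all padded batches recovers the client dataset size:
# Sum the boolean mask over all padded batches to count real examples only.
num_real = sum(np.sum(b['__mask__']) for b in client_dataset.padded_batch(batch_size=8))
print('real examples =', num_real, '; len(client_dataset) =', len(client_dataset))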
shuffle_repeat_batch() for training
Produces preprocessed batches in a shuffled and repeated order for training. Shuffling is done without replacement, therefore for a dataset of N examples, the first ceil(N/batch_size) batches are guaranteed to cover the entire dataset. Unlike batch() or padded_batch(), batches from shuffle_repeat_batch() always contain exactly batch_size examples. Also unlike TensorFlow, that holds even when drop_remainder=False.
By default the iteration stops after the first epoch.
print('# batches')
len(list(client_dataset.shuffle_repeat_batch(batch_size=8)))
# batches
3
The number of batches produced from the iteration can be controlled by the (num_epochs, num_steps, drop_remainder) combination:
If both num_epochs and num_steps are None, the shuffle-repeat process continues forever.
infinite_bs = client_dataset.shuffle_repeat_batch(
batch_size=8, num_epochs=None, num_steps=None)
for i, b in zip(range(6), infinite_bs):
print(i)
0
1
2
3
4
5
If num_epochs is set and num_steps is None, as few batches as needed to go over the dataset this many passes are produced. Further,
If drop_remainder is False (the default), the final batch is filled with additionally sampled examples to contain batch_size examples.
If drop_remainder is True, the final batch is dropped if it contains fewer than batch_size examples. This may result in examples being skipped when num_epochs=1.
print('# batches w/ drop_remainder=False')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None))))
print('# batches w/ drop_remainder=True')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None, drop_remainder=True))))
# batches w/ drop_remainder=False
3
# batches w/ drop_remainder=True
2
If num_steps is set and num_epochs is None, exactly this many batches are produced. drop_remainder has no effect in this case.
print('# batches w/ num_steps set and drop_remainder=True')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=None, num_steps=3, drop_remainder=True))))
# batches w/ num_steps set and drop_remainder=True
3
If both num_epochs and num_steps are set, the smaller number of batches implied by the two conditions is produced.
print('# batches w/ num_epochs and num_steps set')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=6))))
# batches w/ num_epochs and num_steps set
3
If a reproducible iteration order is desired, a fixed seed can be used. When seed is None, repeated iteration over the same object may produce batches in a different order.
# Random shuffling.
print(list(client_dataset.shuffle_repeat_batch(batch_size=2, seed=None))[0])
# Fixed shuffling.
print(list(client_dataset.shuffle_repeat_batch(batch_size=2, seed=0))[0])
{'x': array([[78, 29, 75, 78, 33, 16, 66, 47],
[ 7, 26, 47, 27, 42, 16, 7, 4]], dtype=int32), 'y': array([[29, 75, 78, 33, 16, 66, 47, 3],
[26, 47, 27, 42, 16, 7, 4, 67]], dtype=int32)}
{'x': array([[16, 14, 4, 67, 47, 16, 48, 67],
[48, 16, 13, 32, 33, 14, 11, 78]], dtype=int32), 'y': array([[14, 4, 67, 47, 16, 48, 67, 84],
[16, 13, 32, 33, 14, 11, 78, 76]], dtype=int32)}
Preprocessing
Going from the raw features to features in a batch of examples ready for use in training/evaluation often requires a few steps of preprocessing. Sometimes, we also want to add new preprocessing to an existing FederatedData.
Before going into the details of preprocessing, please note once again that all dataset related operations should be implemented in standard NumPy, not in JAX.
Preprocessing a batch of examples
The easiest way to add an additional preprocessing step is by appending a function that transforms a batch of examples to a FederatedData’s list of preprocessing transformations on batches.
Below, we add a new z feature that stores the parity of y for our Shakespeare examples.
# A preprocessing function should return a new dict of examples instead of
# modifying its input.
def parity_feature(examples):
return {'z': examples['y'] % 2, **examples}
# preprocess_batch returns a new FederatedData object that has one more
# preprocessing step at the very end than the original.
train_fd_z = train_fd.preprocess_batch(parity_feature)
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
next(iter(train_fd_z.get_client(client_id).padded_batch(batch_size=4)))
{'z': array([[1, 1, 0, 1, 1, 1, 1, 0],
[0, 1, 0, 1, 0, 1, 0, 0],
[0, 1, 1, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1, 0]], dtype=int32),
'x': array([[ 1, 55, 67, 84, 67, 47, 7, 67],
[48, 16, 13, 32, 33, 14, 11, 78],
[76, 78, 33, 19, 16, 66, 47, 3],
[16, 27, 67, 23, 26, 47, 3, 27]], dtype=int32),
'y': array([[55, 67, 84, 67, 47, 7, 67, 48],
[16, 13, 32, 33, 14, 11, 78, 76],
[78, 33, 19, 16, 66, 47, 3, 16],
[27, 67, 23, 26, 47, 3, 27, 16]], dtype=int32),
'__mask__': array([ True, True, True, True])}
Preprocessing at the client level
Sometimes we also need to do some preprocessing for the entire client dataset. For example, in the Shakespeare dataset, the raw features are just text strings, so we need to turn them into a NumPy array of chunks of integer labels just to be able to meaningfully batch them at all.
In most circumstances, adding a preprocessing step at the client level is unnecessary, and should be avoided, because the new preprocessing step added by preprocess_client is inserted into the middle of a chain of steps, just before all the batch level preprocessing steps registered by preprocess_batch. If not done carefully, a custom client level preprocessing step can easily break the preprocessing chain.
Nevertheless, here’s an example of client level processing for certain rare cases.
# Load unpreprocessed data.
raw_fd = fedjax.datasets.shakespeare.load_split('train')
raw_fd.get_client(client_id).all_examples()
Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'
{'snippets': array([b'Re-enter POSTHUMUS, and seconds the Britons; they rescue\nCYMBELINE, and exeunt. Then re-enter LUCIUS and IACHIMO,\n with IMOGEN\n'],
dtype=object)}
The actual client level preprocessing in the Shakespeare dataset is a bit involved, so let’s do something simpler: we shall join all the snippets, and then split and pad the integer byte values into bounded length sequences.
# We don't actually need client_id, but `FederatedData` supplies it so that
# different processing based on clients can be done.
def truncate_and_cast(client_id, examples, max_length=10):
labels = list(b''.join(examples['snippets']))
num_sequences = (len(labels) + max_length - 1) // max_length
padded = np.zeros((num_sequences, max_length), dtype=np.int32)
for i in range(num_sequences):
chars = labels[i * max_length:(i + 1) * max_length]
padded[i, :len(chars)] = chars
return {'snippets': padded}
partial_fd = raw_fd.preprocess_client(truncate_and_cast)
partial_fd.get_client(client_id).all_examples()
{'snippets': array([[ 82, 101, 45, 101, 110, 116, 101, 114, 32, 80],
[ 79, 83, 84, 72, 85, 77, 85, 83, 44, 32],
[ 97, 110, 100, 32, 115, 101, 99, 111, 110, 100],
[115, 32, 116, 104, 101, 32, 66, 114, 105, 116],
[111, 110, 115, 59, 32, 116, 104, 101, 121, 32],
[114, 101, 115, 99, 117, 101, 10, 67, 89, 77],
[ 66, 69, 76, 73, 78, 69, 44, 32, 97, 110],
[100, 32, 101, 120, 101, 117, 110, 116, 46, 32],
[ 84, 104, 101, 110, 32, 114, 101, 45, 101, 110],
[116, 101, 114, 32, 76, 85, 67, 73, 85, 83],
[ 32, 97, 110, 100, 32, 73, 65, 67, 72, 73],
[ 77, 79, 44, 10, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 119, 105, 116, 104, 32],
[ 73, 77, 79, 71, 69, 78, 10, 0, 0, 0]], dtype=int32)}
Now, we can add another batch level preprocessor to produce the x and y features.
def snippets_to_xy(examples):
snippets = examples['snippets']
return {'x': snippets[:, :-1], 'y': snippets[:, 1:]}
partial_fd.preprocess_batch(snippets_to_xy).get_client(client_id).all_examples()
{'x': array([[ 82, 101, 45, 101, 110, 116, 101, 114, 32],
[ 79, 83, 84, 72, 85, 77, 85, 83, 44],
[ 97, 110, 100, 32, 115, 101, 99, 111, 110],
[115, 32, 116, 104, 101, 32, 66, 114, 105],
[111, 110, 115, 59, 32, 116, 104, 101, 121],
[114, 101, 115, 99, 117, 101, 10, 67, 89],
[ 66, 69, 76, 73, 78, 69, 44, 32, 97],
[100, 32, 101, 120, 101, 117, 110, 116, 46],
[ 84, 104, 101, 110, 32, 114, 101, 45, 101],
[116, 101, 114, 32, 76, 85, 67, 73, 85],
[ 32, 97, 110, 100, 32, 73, 65, 67, 72],
[ 77, 79, 44, 10, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 119, 105, 116, 104],
[ 73, 77, 79, 71, 69, 78, 10, 0, 0]], dtype=int32),
'y': array([[101, 45, 101, 110, 116, 101, 114, 32, 80],
[ 83, 84, 72, 85, 77, 85, 83, 44, 32],
[110, 100, 32, 115, 101, 99, 111, 110, 100],
[ 32, 116, 104, 101, 32, 66, 114, 105, 116],
[110, 115, 59, 32, 116, 104, 101, 121, 32],
[101, 115, 99, 117, 101, 10, 67, 89, 77],
[ 69, 76, 73, 78, 69, 44, 32, 97, 110],
[ 32, 101, 120, 101, 117, 110, 116, 46, 32],
[104, 101, 110, 32, 114, 101, 45, 101, 110],
[101, 114, 32, 76, 85, 67, 73, 85, 83],
[ 97, 110, 100, 32, 73, 65, 67, 72, 73],
[ 79, 44, 10, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 32, 32, 32, 32, 32],
[ 32, 32, 32, 32, 119, 105, 116, 104, 32],
[ 77, 79, 71, 69, 78, 10, 0, 0, 0]], dtype=int32)}
In memory federated datasets
In many scenarios, it is desirable to create a small custom federated dataset from a collection of NumPy arrays for quick experimentation. FedJAX provides fedjax.InMemoryFederatedData to create small custom datasets. fedjax.InMemoryFederatedData takes a dictionary of NumPy examples keyed by client id and creates a fedjax.FederatedData that is compatible with the rest of the library. We illustrate this below with a simple example.
# Obtain MNIST dataset from tensorflow and convert to numpy format.
import tensorflow_datasets as tfds
(ds_train, ds_test) = tfds.load('mnist',
split=['train', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=False)
features, labels = list(ds_train.batch(60000).as_numpy_iterator())[0]
print('features shape', features.shape)
print('labels shape', labels.shape)
# Randomly split dataset into 100 clients and load them to a dictionary.
indices = np.random.randint(100, size=60000)
client_id_to_dataset_mapping = {}
for i in range(100):
client_id_to_dataset_mapping[i] = {'x': features[indices==i, :, : , :],
'y': labels[indices==i]}
# Create fedjax.InMemoryFederatedData.
iid_mnist_federated_data = fedjax.InMemoryFederatedData(
client_id_to_dataset_mapping)
print('number of clients in iid_mnist_data',
iid_mnist_federated_data.num_clients())
features shape (60000, 28, 28, 1)
labels shape (60000,)
number of clients in iid_mnist_data 100
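As a quick sanity check, the resulting object supports the same access patterns we used earlier in this tutorial, e.g. iterating over clients and batching their examples (a small sketch):
# Iterate over the first client of the in-memory federated dataset and batch it.
for client_id, client_dataset in itertools.islice(iid_mnist_federated_data.clients(), 1):
  batch = next(iter(client_dataset.padded_batch(batch_size=32)))
  print(client_id, {k: v.shape for k, v in batch.items()})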
Recap
In this tutorial, we have covered the following:
Using fedjax.FederatedData.
Different ways of batching client datasets.
Different ways of processing client datasets.
Creating small custom federated datasets.
Working with models in FedJAX
In this chapter, we will learn about fedjax.Model. This notebook assumes you have already finished the “Datasets” chapter. We first give an overview of centralized training and evaluation with fedjax.Model and then describe how to add new neural architectures and specify additional evaluation metrics.
# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
import itertools
import jax
import jax.numpy as jnp
from jax.example_libraries import stax
import fedjax
Centralized training & evaluation with fedjax.Model
Most federated learning algorithms are built upon common components from standard centralized learning. fedjax.Model holds these common components. In centralized learning, we are mostly concerned with two tasks:
Training: We want to optimize our model parameters on the training dataset.
Evaluation: We want to know the values of evaluation metrics (e.g. accuracy) of the current model parameters on a test dataset.
Let’s first see how we can carry out these two tasks on the EMNIST dataset with fedjax.Model.
# Load train/test splits of the EMNIST dataset.
train, test = fedjax.datasets.emnist.load_data()
# As a start, let's simply use a logistic regression model.
model = fedjax.models.emnist.create_logistic_model()
Random initialization, the JAX way
To start training, we need some randomly initialized parameters. In JAX, pseudo random number generation works slightly differently. For now, it is sufficient to know that we call jax.random.PRNGKey() to seed the random number generator. JAX has a detailed introduction on this topic, if you are interested.
To create the initial model parameters, we simply call fedjax.Model.init() with a PRNGKey.
params_rng = jax.random.PRNGKey(0)
params = model.init(params_rng)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Here are our initial model parameters. With the same PRNGKey, we will always get the same random initialization. There are 2 parameters in our model: the weights w and the bias b. They are organized into a FlatMapping, but in general any PyTree can be used to store model parameters.
params
FlatMapping({
'linear': FlatMapping({
'b': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0.], dtype=float32),
'w': DeviceArray([[-0.04067196, 0.02348138, -0.0214883 , ..., 0.01055492,
-0.06988288, -0.02952586],
[-0.03985253, -0.03804361, 0.01401524, ..., 0.02281437,
-0.01771905, 0.06676884],
[ 0.00098182, -0.00844628, 0.01303554, ..., -0.05299249,
0.01777634, -0.0006488 ],
...,
[-0.05691862, 0.05192501, 0.01588603, ..., 0.0157204 ,
-0.01854135, 0.00297953],
[ 0.01680706, 0.05579231, 0.0459589 , ..., 0.01990358,
-0.01944044, -0.01710149],
[-0.00880739, 0.04229043, 0.00998938, ..., -0.00633441,
-0.04824542, 0.01395545]], dtype=float32),
}),
})
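Because the parameters are just a PyTree, generic JAX tree utilities apply directly; as a small sketch, here is one way to count the total number of parameters:
# Count the total number of scalar parameters across all leaves of the PyTree.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print('total parameter count =', num_params)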
Evaluating model parameters
Before we start training, let’s first see how our initial parameters fare on the train and test sets. Unsurprisingly, they do not do very well. We evaluate using fedjax.evaluate_model(), which takes the model, parameters, and batched datasets. As noted in the dataset tutorial, we batch using fedjax.padded_batch_federated_data() for efficiency. fedjax.padded_batch_federated_data() is very similar to fedjax.ClientDataset.padded_batch() but operates over the entire federated dataset.
# We select first 16 batches using itertools.islice.
batched_test_data = list(itertools.islice(
fedjax.padded_batch_federated_data(test, batch_size=128), 16))
batched_train_data = list(itertools.islice(
fedjax.padded_batch_federated_data(train, batch_size=128), 16))
print('eval_test', fedjax.evaluate_model(model, params, batched_test_data))
print('eval_train', fedjax.evaluate_model(model, params, batched_train_data))
eval_test {'accuracy': DeviceArray(0.01757812, dtype=float32), 'loss': DeviceArray(4.1253214, dtype=float32)}
eval_train {'accuracy': DeviceArray(0.02490234, dtype=float32), 'loss': DeviceArray(4.116228, dtype=float32)}
How does our model know what evaluation metrics to report? It is simply specified in the eval_metrics field. We will discuss evaluation metrics in more detail later.
model.eval_metrics
{'accuracy': Accuracy(target_key='y', pred_key=None),
'loss': CrossEntropyLoss(target_key='y', pred_key=None)}
Since fedjax.evaluate_model() simply takes a stream of batches, we can also use it to evaluate multiple clients.
for client_id, dataset in itertools.islice(test.clients(), 4):
print(
client_id,
fedjax.evaluate_model(model, params,
dataset.padded_batch(batch_size=128)))
b'002d084c082b8586:f0185_23' {'accuracy': DeviceArray(0.05, dtype=float32), 'loss': DeviceArray(4.1247168, dtype=float32)}
b'005fdad281234bc0:f0151_02' {'accuracy': DeviceArray(0.09375, dtype=float32), 'loss': DeviceArray(4.093891, dtype=float32)}
b'014c177da5b15a39:f1565_04' {'accuracy': DeviceArray(0., dtype=float32), 'loss': DeviceArray(4.127692, dtype=float32)}
b'0156df0c34a25944:f3772_10' {'accuracy': DeviceArray(0.05263158, dtype=float32), 'loss': DeviceArray(4.1521378, dtype=float32)}
The training objective
To train our model, we need two things: the objective function to minimize and an optimizer.
fedjax.Model contains two functions that can be used to arrive at the training objective:
apply_for_train(params, batch_example, rng) takes the current model parameters, a batch of examples, and a PRNGKey, and returns some output.
train_loss(batch_example, train_output) translates the output of apply_for_train() into a vector of per-example loss values.
In our example model, apply_for_train() produces a score for each class and train_loss() is simply the cross entropy loss. apply_for_train() in this case does not make use of a PRNGKey, so we can pass None instead for convenience. A different apply_for_train() might actually make use of the PRNGKey, for tasks such as dropout.
# train_batches is an infinite stream of shuffled batches of examples.
def train_batches():
return fedjax.shuffle_repeat_batch_federated_data(
train,
batch_size=8,
client_buffer_size=16,
example_buffer_size=1024,
seed=0)
# We obtain the first batch by using the `next` function.
example = next(train_batches())
output = model.apply_for_train(params, example, None)
per_example_loss = model.train_loss(example, output)
output.shape, per_example_loss
((8, 62), DeviceArray([4.0337796, 4.046219 , 3.9447758, 3.933005 , 4.116893 ,
4.209843 , 4.060939 , 4.19899 ], dtype=float32))
Note that output contains the per-example predictions and has shape (8, 62), where 8 is the batch size and 62 is the number of classes. Alternatively, we can use model_per_example_loss() to get a function that gives us the same result. model_per_example_loss() is a convenience function that does exactly what we just did.
per_example_loss_fn = fedjax.model_per_example_loss(model)
per_example_loss_fn(params, example, None)
DeviceArray([4.0337796, 4.046219 , 3.9447758, 3.933005 , 4.116893 ,
4.209843 , 4.060939 , 4.19899 ], dtype=float32)
The training objective is a scalar, so why does train_loss() return a vector of per-example loss values? First of all, the training objective in most cases is just the average of the per-example loss values, so arriving at the final training objective isn’t hard. Moreover, in certain algorithms, we not only use the train loss over a single batch of examples for a stochastic training step, but also need to estimate the average train loss over an entire (client) dataset. Having the per-example loss values is instrumental in obtaining the correct estimate when the batch sizes may vary.
def train_objective(params, example):
return jnp.mean(per_example_loss_fn(params, example, None))
train_objective(params, example)
DeviceArray(4.0680556, dtype=float32)
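For instance, here is a minimal sketch (using the padded batches and the '__mask__' feature from the datasets chapter) of estimating the average train loss over an entire client dataset; the helper name dataset_average_loss is ours, not part of FedJAX:
def dataset_average_loss(params, client_dataset):
  # Accumulate per-example losses over real (non-padding) examples only.
  total_loss, total_count = 0., 0.
  for batch in client_dataset.padded_batch(batch_size=128):
    mask = batch['__mask__']
    total_loss += jnp.sum(per_example_loss_fn(params, batch, None) * mask)
    total_count += jnp.sum(mask)
  return total_loss / total_count
# Estimate the average loss on the first client of the EMNIST train split.
first_client_dataset = next(train.clients())[1]
print(dataset_average_loss(params, first_client_dataset))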
Optimizers
With the training objective at hand, we just need an optimizer to find some good model parameters that minimize it.
There are many optimizer implementations in JAX out there, but FedJAX doesn’t force one choice over any other. Instead, FedJAX provides a simple fedjax.optimizers.Optimizer interface so a new optimizer implementation can be wrapped. For convenience, FedJAX provides some common optimizers wrapped from optax.
optimizer = fedjax.optimizers.adam(1e-3)
An optimizer is simply a pair of functions:
init(params) returns the initial optimizer state, such as initial values for accumulators of gradients.
apply(grads, opt_state, params) applies the gradients to update the current optimizer state and model parameters.
Instead of modifying opt_state or params, apply() returns a new pair of optimizer state and model parameters. In JAX, it is common to express computations in this stateless/mutation-free style, often referred to as functional programming, or pure functions. The pureness of functions is crucial to many features in JAX, so it is always good practice to write functions that do not modify their inputs. You have probably also noticed that all the functions of fedjax.Model we have seen so far do not modify the model object itself (for example, init() returns model parameters instead of setting some attribute of model; apply_for_train() takes model parameters as an input argument, instead of getting them from model). FedJAX does this to keep all functions pure.
However, in the top-level training loop, it is fine to mutate state since we are not in a function that may be transformed by JAX. Let’s run our first training step, which results in a slight decrease in the objective on the same batch of examples.
To obtain the gradients, we use jax.grad(), which returns the gradient function. More details about jax.grad() can be found in the JAX documentation.
opt_state = optimizer.init(params)
grads = jax.grad(train_objective)(params, example)
opt_state, params = optimizer.apply(grads, opt_state, params)
train_objective(params, example)
DeviceArray(4.0080366, dtype=float32)
Instead of using jax.grad() directly, FedJAX also provides a convenient fedjax.model_grad() which computes the gradient of a model with respect to the averaged fedjax.model_per_example_loss().
model_grads = fedjax.model_grad(model)(params, example, None)
opt_state, params = optimizer.apply(model_grads, opt_state, params)
train_objective(params, example)
DeviceArray(3.9482572, dtype=float32)
Let’s wrap everything into a single JIT compiled function and train a few more steps, and evaluate again.
@jax.jit
def train_step(example, opt_state, params):
grads = jax.grad(train_objective)(params, example)
return optimizer.apply(grads, opt_state, params)
for example in itertools.islice(train_batches(), 5000):
opt_state, params = train_step(example, opt_state, params)
print('eval_test', fedjax.evaluate_model(model, params, batched_test_data))
print('eval_train', fedjax.evaluate_model(model, params, batched_train_data))
eval_test {'accuracy': DeviceArray(0.6152344, dtype=float32), 'loss': DeviceArray(1.5562292, dtype=float32)}
eval_train {'accuracy': DeviceArray(0.59765625, dtype=float32), 'loss': DeviceArray(1.6278805, dtype=float32)}
Building a custom model
fedjax.Model was designed with customization in mind. We have already seen how to switch to a different training loss. In this section, we will discuss how the rest of a fedjax.Model can be customized.
Training loss
Because train_loss() is separate from apply_for_train(), it is easy to switch to a different loss function.
def hinge_loss(example, output):
label = example['y']
num_classes = output.shape[-1]
mask = jax.nn.one_hot(label, num_classes)
label_score = jnp.sum(output * mask, axis=-1)
best_score = jnp.max(output + 1 - mask, axis=-1)
return best_score - label_score
hinge_model = model.replace(train_loss=hinge_loss)
fedjax.model_per_example_loss(hinge_model)(params, example, None)
DeviceArray([4.306656 , 0. , 0. , 0.4375435 , 0.96986485,
0. , 0.3052401 , 1.3918507 ], dtype=float32)
Evaluation metrics
We have already seen that the eval_metrics field of a fedjax.Model tells the model what metrics to evaluate. eval_metrics is a mapping from metric names to fedjax.metrics.Metric objects. A fedjax.metrics.Metric object tells us how to calculate a metric’s value from multiple batches of examples. Like fedjax.Model, a fedjax.metrics.Metric is stateless.
To customize the metrics to evaluate on, or what names to give to each, simply specify a different mapping.
only_accuracy = model.replace(
eval_metrics={'accuracy': fedjax.metrics.Accuracy()})
fedjax.evaluate_model(only_accuracy, params, batched_test_data)
{'accuracy': DeviceArray(0.6152344, dtype=float32)}
There are already some concrete Metrics in fedjax.metrics. It is also easy to implement a new one. You can read more about how to implement a Metric in its own introduction.
The bit of fedjax.Model that is directly relevant to evaluation is apply_for_eval(). The relation between apply_for_eval() and an evaluation metric is similar to that between apply_for_train() and train_loss(): apply_for_eval(params, example) takes the model parameters and a batch of examples (notice there is no randomness in evaluation, so we don’t need a PRNGKey), and produces some prediction that evaluation metrics can consume. In our example, the outputs from apply_for_eval() and apply_for_train() are identical, but they don’t have to be.
jnp.all(
model.apply_for_train(params, example, None) == model.apply_for_eval(
params, example))
DeviceArray(True, dtype=bool)
What apply_for_eval() needs to produce really just depends on what evaluation fedjax.metrics.Metrics will be used. In our case, we are using fedjax.metrics.Accuracy and fedjax.metrics.CrossEntropyLoss. They are similar in their requirements on the inputs:
They both need to know the true label from the example, using a target_key that defaults to "y".
They both need to know the predicted scores from apply_for_eval(), customizable as pred_key. If pred_key is None, apply_for_eval() should return just a vector of per-class scores; otherwise pred_key can be a string key, and apply_for_eval() should return a mapping (e.g. dict) that maps the key to a vector of per-class scores.
fedjax.metrics.Accuracy()
Accuracy(target_key='y', pred_key=None)
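As a hedged sketch of the pred_key option (the 'logits' key below is our own choice, not a FedJAX convention), we can wrap apply_for_eval() to return a mapping and point Accuracy at the right entry:
# Wrap apply_for_eval to return a dict and tell the metric which key to read.
dict_output_model = model.replace(
    apply_for_eval=lambda params, batch: {'logits': model.apply_for_eval(params, batch)},
    eval_metrics={'accuracy': fedjax.metrics.Accuracy(pred_key='logits')})
fedjax.evaluate_model(dict_output_model, params, batched_test_data)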
Neural network architectures
We have now covered all five parts of a fedjax.Model, namely init(), apply_for_train(), apply_for_eval(), train_loss(), and eval_metrics. train_loss() and eval_metrics are easy to customize since they are mostly agnostic to the actual neural network architecture of the model. init(), apply_for_train(), and apply_for_eval(), on the other hand, are closely related.
In principle, as long as these three functions meet the interface we have seen so far, they can be used to build a custom model. Let’s try to build a multi-layer perceptron model with cross entropy loss from scratch.
def cross_entropy_loss(example, output):
label = example['y']
num_classes = output.shape[-1]
mask = jax.nn.one_hot(label, num_classes)
return -jnp.sum(jax.nn.log_softmax(output) * mask, axis=-1)
def mlp_model(num_input_units, num_units, num_classes):
def mlp_init(rng):
w0_rng, w1_rng = jax.random.split(rng)
w0 = jax.random.uniform(w0_rng, [num_input_units, num_units])
b0 = jnp.zeros([num_units])
w1 = jax.random.uniform(w1_rng, [num_units, num_classes])
b1 = jnp.zeros([num_classes])
return w0, b0, w1, b1
def mlp_apply(params, batch, rng=None):
w0, b0, w1, b1 = params
x = batch['x']
batch_size = x.shape[0]
h = jax.nn.relu(x.reshape([batch_size, -1]) @ w0 + b0)
return h @ w1 + b1
return fedjax.Model(
init=mlp_init,
apply_for_train=mlp_apply,
apply_for_eval=mlp_apply,
train_loss=cross_entropy_loss,
eval_metrics={'accuracy': fedjax.metrics.Accuracy()})
# There are 28*28 input pixels, and 62 classes in EMNIST.
mlp = mlp_model(28 * 28, 128, 62)
@jax.jit
def mlp_train_step(example, opt_state, params):
@jax.grad
def grad_fn(params, example):
return jnp.mean(fedjax.model_per_example_loss(mlp)(params, example, None))
grads = grad_fn(params, example)
return optimizer.apply(grads, opt_state, params)
params = mlp.init(jax.random.PRNGKey(0))
opt_state = optimizer.init(params)
print('eval_test before training:',
fedjax.evaluate_model(mlp, params, batched_test_data))
for example in itertools.islice(train_batches(), 5000):
opt_state, params = mlp_train_step(example, opt_state, params)
print('eval_test after training:',
fedjax.evaluate_model(mlp, params, batched_test_data))
eval_test before training: {'accuracy': DeviceArray(0.05078125, dtype=float32)}
eval_test after training: {'accuracy': DeviceArray(0.4951172, dtype=float32)}
While writing custom neural network architectures from scratch is possible, most of the time it is much more convenient to use a neural network library such as Haiku or jax.example_libraries.stax. The two functions fedjax.create_model_from_haiku and fedjax.create_model_from_stax can convert a neural network expressed in the respective framework into a fedjax.Model. Let’s build a convolutional network using jax.example_libraries.stax this time.
def stax_cnn_model(input_shape, num_classes):
stax_init, stax_apply = stax.serial(
stax.Conv(
out_chan=64, filter_shape=(3, 3), strides=(1, 1), padding='SAME'),
stax.Relu,
stax.Flatten,
stax.Dense(256),
stax.Relu,
stax.Dense(num_classes),
)
return fedjax.create_model_from_stax(
stax_init=stax_init,
stax_apply=stax_apply,
sample_shape=input_shape,
train_loss=cross_entropy_loss,
eval_metrics={'accuracy': fedjax.metrics.Accuracy()})
stax_cnn = stax_cnn_model([-1, 28, 28, 1], 62)
@jax.jit
def stax_cnn_train_step(example, opt_state, params):
@jax.grad
def grad_fn(params, example):
return jnp.mean(
fedjax.model_per_example_loss(stax_cnn)(params, example, None))
grads = grad_fn(params, example)
return optimizer.apply(grads, opt_state, params)
params = stax_cnn.init(jax.random.PRNGKey(0))
opt_state = optimizer.init(params)
print('eval_test before training:',
fedjax.evaluate_model(stax_cnn, params, batched_test_data))
for example in itertools.islice(train_batches(), 1000):
opt_state, params = stax_cnn_train_step(example, opt_state, params)
print('eval_test after training:',
fedjax.evaluate_model(stax_cnn, params, batched_test_data))
eval_test before training: {'accuracy': DeviceArray(0.03076172, dtype=float32)}
eval_test after training: {'accuracy': DeviceArray(0.72558594, dtype=float32)}
Recap
In this chapter, we have covered the following:
Components of fedjax.Model: init(), apply_for_train(), apply_for_eval(), train_loss(), and eval_metrics.
Optimizers in fedjax.optimizers.
Standard centralized learning with a fedjax.Model.
Specifying evaluation metrics in eval_metrics.
Building a custom fedjax.Model.
Federated learning algorithms
This tutorial introduces algorithms for federated learning in FedJAX. By completing this tutorial, we’ll learn how to write clear and efficient algorithms that follow best practices. This tutorial assumes that we have finished the tutorials on datasets and models.
In order to keep the code pseudo-code-like, we avoid using JAX primitives directly while writing algorithms, with the notable exceptions of the jax.random and jax.tree_util libraries. Since the lower-level functions described in the model tutorial, such as fedjax.optimizers and fedjax.model_grad, are all JIT compiled already, the algorithms will still be efficient.
# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
import jax
import jax.numpy as jnp
import numpy as np
import fedjax
# We only use TensorFlow for datasets, so we restrict it to CPU only to avoid
# issues with certain ops not being available on GPU/TPU.
fedjax.training.set_tf_cpu_only()
Introduction
A federated algorithm trains a machine learning model over decentralized data distributed over several clients. At a high level, the server first randomly initializes the model parameters and other learning components. Then at each round the following happens:
Client selection: The server selects a few clients at each round, typically at random.
The server transmits the model parameters and other necessary components to the selected clients.
Client update: The clients update the model parameters using a subroutine, which typically involves a few epochs of SGD on their local examples.
The clients transmit the updates to the server.
Server aggregation: The server combines the clients’ updates to produce new model parameters.
Pseudo-code for a common federated learning algorithm can be found in Algorithm 1 of Kairouz et al. (2020). Since FedJAX focuses on federated simulation and there is no actual transmission between clients and the server, we only focus on steps 1, 3, and 5, and ignore steps 2 and 4. Before we describe each of these modules, we will first describe how to use algorithms that are already implemented in FedJAX.
Federated algorithm overview
We implement federated learning algorithms using the fedjax.FederatedAlgorithm interface. The fedjax.FederatedAlgorithm interface has two functions, init and apply. Broadly, our implementation has three parts:
ServerState: This contains all the information available at the server at any given round. It includes model parameters and can also include other parameters that are used during optimization. At every round, a subset of ServerState is passed to the clients for federated learning. ServerState is also used in checkpointing and evaluation. Hence it is crucial that all the parameters that are modified during the course of federated learning are stored as part of the ServerState. Do not store mutable parameters as part of the fedjax.FederatedAlgorithm.
init: Initializes the server state.
apply: Takes the ServerState and a set of client_ids, corresponding datasets, and random keys, and returns a new ServerState along with any information we need from the clients in the form of client_diagnostics.
We demonstrate fedjax.FederatedAlgorithm using Federated Averaging (FedAvg) and the EMNIST dataset. We first initialize the model, datasets, and the federated algorithm.
train, test = fedjax.datasets.emnist.load_data(only_digits=False)
model = fedjax.models.emnist.create_conv_model(only_digits=False)
rng = jax.random.PRNGKey(0)
init_params = model.init(rng)
# Federated algorithm requires a gradient function, client optimizer,
# server optimizers, and hyperparameters for batching at the client level.
grad_fn = fedjax.model_grad(model)
client_optimizer = fedjax.optimizers.sgd(0.1)
server_optimizer = fedjax.optimizers.sgd(1.0)
batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=10)
fed_alg = fedjax.algorithms.fed_avg.federated_averaging(grad_fn,
client_optimizer,
server_optimizer,
batch_hparams)
Note that, similar to the rest of the library, we only pass on the necessary functions and parameters to the federated algorithm. Hence, to initialize the federated algorithm, we only passed the grad_fn and did not pass the entire model. With this, we now initialize the server state.
init_server_state = fed_alg.init(init_params)
To run the federated algorithm, we pass the server state and client data to the apply function. To this end, we pass client data as a tuple of client id, client data, and a random key. Adding client ids and random keys has multiple advantages. First, client ids allow us to track client diagnostics and are helpful in debugging. Passing random keys ensures deterministic execution and allows repeatability. Furthermore, as we discuss later, it helps us with fast implementations. We first format the data in this necessary format and then run one round of federated learning.
# Select 5 client_ids and their data
client_ids = list(train.client_ids())[:5]
clients_ids_and_data = list(train.get_clients(client_ids))
client_inputs = []
for i in range(5):
rng, use_rng = jax.random.split(rng)
client_id, client_data = clients_ids_and_data[i]
client_inputs.append((client_id, client_data, use_rng))
updated_server_state, client_diagnostics = fed_alg.apply(init_server_state,
client_inputs)
# Prints the l2 norm of gradients as part of client_diagnostics.
print(client_diagnostics)
{b'002d084c082b8586:f0185_23': {'delta_l2_norm': DeviceArray(1.9278834, dtype=float32)}, b'005fdad281234bc0:f0151_02': {'delta_l2_norm': DeviceArray(1.8239512, dtype=float32)}, b'014c177da5b15a39:f1565_04': {'delta_l2_norm': DeviceArray(1.6514685, dtype=float32)}, b'0156df0c34a25944:f3772_10': {'delta_l2_norm': DeviceArray(1.5863262, dtype=float32)}, b'01725f8a648ceeb6:f3408_47': {'delta_l2_norm': DeviceArray(1.613201, dtype=float32)}}
As we see above, the client diagnostics provide the delta_l2_norm of the gradients for each client, which can potentially be used for debugging purposes.
Writing federated algorithms
With this background on how to use existing implementations, we are now going to describe how to write your own federated algorithms in FedJAX. As discussed above, this involves three steps:
Client selection
Client update
Server aggregation
Client selection
At each round of federated learning, typically clients are sampled uniformly at random. This can be done using numpy as follows.
all_client_ids = list(train.client_ids())
print("Total number of client ids: ", len(all_client_ids))
sampled_client_ids = np.random.choice(all_client_ids, size=2, replace=False)
print("Sampled client ids: ", sampled_client_ids)
Total number of client ids: 3400
Sampled client ids: [b'3abf53413107a36f:f3901_45' b'7b15a2a43e2c5097:f0452_37']
However, the above code is not desirable due to the following reasons:
For reproducibility, it is desirable to have a fixed seed just for sampling clients.
Across rounds, different clients need to be sampled.
For I/O efficiency reasons, it might be better to do an approximately uniform sampling, where clients whose data is stored together are sampled together.
Federated algorithms typically require additional randomness for batching, or dropout that needs to be sent to clients.
To incorporate these features, FedJAX provides a few client samplers:
fedjax.client_samplers.UniformShuffledClientSampler
fedjax.client_samplers.UniformGetClientSampler
fedjax.client_samplers.UniformShuffledClientSampler is preferred for efficiency reasons, but if we need to sample clients truly randomly, fedjax.client_samplers.UniformGetClientSampler can be used.
Both of them have a sample function that returns a list of client_ids, client_data, and client_rng.
efficient_sampler = fedjax.client_samplers.UniformShuffledClientSampler(
train.shuffled_clients(buffer_size=100), num_clients=2)
print("Sampling from the efficient sampler.")
for round in range(3):
sampled_clients_with_data = efficient_sampler.sample()
for client_id, client_data, client_rng in sampled_clients_with_data:
print(round, client_id)
perfect_uniform_sampler = fedjax.client_samplers.UniformGetClientSampler(
train, num_clients=2, seed=1)
print("Sampling from the perfect uniform sampler.")
for round in range(3):
sampled_clients_with_data = perfect_uniform_sampler.sample()
for client_id, client_data, client_rng in sampled_clients_with_data:
print(round, client_id)
Sampling from the efficient sampler.
0 b'1ec29e39b10521aa:f3848_28'
0 b'0156df0c34a25944:f3772_10'
1 b'0e9e55d97351b8a6:f1424_45'
1 b'1cb4c2b15c501e91:f0936_34'
2 b'1ef43f1933de69ae:f2086_25'
2 b'0a544c86e9731fc1:f0629_39'
Sampling from the perfect uniform sampler.
0 b'0eb0b8eb6ab0fcd6:f0174_06'
0 b'2b766db84ee57ba6:f2380_76'
1 b'b268b64fea71f3a7:f0766_01'
1 b'786d98ea95e5a59b:f1729_47'
2 b'35b6c7951d56d353:f3574_30'
2 b'98355b0cd97cdbb8:f3990_02'
Client update
After selecting the clients, the next step is running a model update step on the clients. Typically this is done by running a few epochs of SGD. We only pass the parts of the algorithm that are necessary for the client update.
The client update typically requires a set of parameters from the server (init_params in this example), the client dataset, and a source of randomness (rng). The randomness can be used for dropout or other model update steps. Finally, instead of passing the entire model to the client update, since our code only depends on the gradient function, we pass grad_fn to client_update.
def client_update(init_params, client_dataset, client_rng, grad_fn):
opt_state = client_optimizer.init(init_params)
params = init_params
for batch in client_dataset.shuffle_repeat_batch(batch_size=10):
client_rng, use_rng = jax.random.split(client_rng)
grads = grad_fn(params, batch, use_rng)
opt_state, params = client_optimizer.apply(grads, opt_state, params)
delta_params = jax.tree_util.tree_multimap(lambda a, b: a - b,
init_params, params)
return delta_params, len(client_dataset)
client_sampler = fedjax.client_samplers.UniformGetClientSampler(
train, num_clients=2, seed=1)
sampled_clients_with_data = client_sampler.sample()
for client_id, client_data, client_rng in sampled_clients_with_data:
delta_params, num_samples = client_update(init_params,client_data,
client_rng, grad_fn)
print(client_id, num_samples, delta_params.keys())
b'0eb0b8eb6ab0fcd6:f0174_06' 348 KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])
b'2b766db84ee57ba6:f2380_76' 152 KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])
Server aggregation
The outputs of the clients are typically aggregated by computing the weighted mean of the updates, where the weight is the number of client examples. This can easily be done using the fedjax.tree_util.tree_mean function.
sampled_clients_with_data = client_sampler.sample()
client_updates = []
for client_id, client_data, client_rng in sampled_clients_with_data:
delta_params, num_samples = client_update(init_params, client_data,
client_rng, grad_fn)
client_updates.append((delta_params, num_samples))
updated_output = fedjax.tree_util.tree_mean(client_updates)
print(updated_output.keys())
KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])
Combining the above steps gives the FedAvg algorithm, which can be found in the example FedJAX implementation of FedAvg.
Efficient implementation
The above implementation would be efficient enough for running on single machines. However, JAX provides primitives such as jax.pmap
and jax.vmap
for efficient parallelization across multiple accelerators. FedJAX provides support for them in federated learning by distributing client computation across several accelerators.
To take advantage of the faster implementation, we need to implement client_update
in a specific format. It has three functions:
client_init
client_step
client_final
client_init
This function takes the inputs from the server and outputs a client_step_state
which will be passed in between client steps. It is desirable for the client_step_state
to be a dictionary. In this example, it just copies the parameters, optimizer_state and the current state of client randomness.
We can think of the inputs from the server as “shared inputs” that are shared across all clients and the client_step_state
as client-specific inputs that are separate per client.
def client_init(server_params, client_rng):
opt_state = client_optimizer.init(server_params)
client_step_state = {
'params': server_params,
'opt_state': opt_state,
'rng': client_rng,
}
return client_step_state
client_step
client_step
takes the current client_step_state
and a batch of examples and updates the client_step_state
. In this example, we run one step of SGD using the batch of examples and update client_step_state
to reflect the new parameters, optimization state, and randomness.
def client_step(client_step_state, batch):
rng, use_rng = jax.random.split(client_step_state['rng'])
grads = grad_fn(client_step_state['params'], batch, use_rng)
opt_state, params = client_optimizer.apply(grads,
client_step_state['opt_state'],
client_step_state['params'])
next_client_step_state = {
'params': params,
'opt_state': opt_state,
'rng': rng,
}
return next_client_step_state
client_final
client_final
modifies the final client_step_state
and returns the desired parameters. In this example, we compute the difference between the initial parameters and the final updated parameters in the client_final
function.
def client_final(server_params, client_step_state):
delta_params = jax.tree_util.tree_map(lambda a, b: a - b,
server_params,
client_step_state['params'])
return delta_params
fedjax.for_each_client
Once we have these three functions, we can combine them to create a client_update function using the fedjax.for_each_client
function. fedjax.for_each_client
returns a function that can be used to run client updates. The sample usage is below.
for_each_client_update = fedjax.for_each_client(client_init,
client_step,
client_final)
client_sampler = fedjax.client_samplers.UniformGetClientSampler(
train, num_clients=2, seed=1)
sampled_clients_with_data = client_sampler.sample()
batched_clients_data = [
(cid, cds.shuffle_repeat_batch(batch_size=10), crng)
for cid, cds, crng in sampled_clients_with_data
]
for client_id, delta_params in for_each_client_update(init_params,
batched_clients_data):
print(client_id, delta_params.keys())
b'0eb0b8eb6ab0fcd6:f0174_06' KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])
b'2b766db84ee57ba6:f2380_76' KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])
Note that for_each_client_update
requires the client data to be already batched. This is necessary for performance gains while using multiple accelerators. Furthermore, the batch size needs to be the same across all clients.
By default fedjax.for_each_client selects the standard JIT backend. To enable parallelism on TPUs, or to aid debugging, we can switch the backend with fedjax.set_for_each_client_backend(backend), where backend is either 'pmap' or 'debug', respectively.
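For example (only switch to 'pmap' when multiple accelerators are actually available):
# Distribute client computation across accelerators with jax.pmap.
fedjax.set_for_each_client_backend('pmap')
# Or run client steps in pure Python, which is easier to debug.
fedjax.set_for_each_client_backend('debug')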
The for-each-client function can also emit additional per-step results, which is useful for debugging. This requires changing the client_step
function.
def client_step_with_log(client_step_state, batch):
rng, use_rng = jax.random.split(client_step_state['rng'])
grads = grad_fn(client_step_state['params'], batch, use_rng)
opt_state, params = client_optimizer.apply(grads,
client_step_state['opt_state'],
client_step_state['params'])
next_client_step_state = {
'params': params,
'opt_state': opt_state,
'rng': rng,
}
grad_norm = fedjax.tree_util.tree_l2_norm(grads)
return next_client_step_state, grad_norm
for_each_client_update = fedjax.for_each_client(
client_init, client_step_with_log, client_final, with_step_result=True)
for client_id, delta_params, grad_norms in for_each_client_update(
init_params, batched_clients_data):
print(client_id, list(delta_params.keys()))
print(client_id, np.array(grad_norms))
b'0eb0b8eb6ab0fcd6:f0174_06' ['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1']
b'0eb0b8eb6ab0fcd6:f0174_06' [4.599997 3.9414525 4.8312078 4.450683 5.971922 2.4694602 2.6183944
2.1996734 2.7145145 2.9750984 3.0633514 2.4050198 2.612233 2.672571
3.1303792 3.236007 3.3968801 2.986587 2.5775976 2.8625555 3.1062818
4.4250994 2.7431202 3.2192783 2.7670481 3.6075711 3.7296255 5.190155
3.4366677 4.5394745 3.2277424 3.1362765 2.8626535 3.7905648 3.5686817]
b'2b766db84ee57ba6:f2380_76' ['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1']
b'2b766db84ee57ba6:f2380_76' [4.3958793 4.4705005 4.84529 4.4623365 5.3809257 5.4818144 3.0325508
4.1576953 8.666909 2.067883 2.1935403 2.5095372 2.2202325 2.8493588
3.2463503 3.6446893]
Recap
In this tutorial, we have covered the following:
Using existing algorithms in fedjax.algorithms.
Writing new algorithms using fedjax.FederatedAlgorithm.
Efficient implementation using fedjax.for_each_client in the presence of accelerators.
Contributing
Everyone can contribute to FedJAX, and we value everyone’s contributions. There are several ways to contribute, including:
Improving or expanding FedJAX’s documentation
Contributing to FedJAX’s code-base
The FedJAX project follows Google’s Open Source Community Guidelines.
Ways to contribute
We welcome pull requests, in particular for those issues marked with contributions welcome or good first issue.
For other proposals, we ask that you first open a GitHub Issue or Discussion to seek feedback on your planned contribution.
Contributing code using pull requests
We do all of our development using git, so basic knowledge is assumed.
Follow these steps to contribute code:
Fork the FedJAX repository by clicking the Fork button on the repository page. This creates a copy of the FedJAX repository in your own account.
Install a supported Python version listed in https://github.com/google/fedjax/blob/main/setup.py.
Install your fork from source with pip. This allows you to modify the code and immediately test it out:
git clone https://github.com/YOUR_USERNAME/fedjax
cd fedjax
pip install -r requirements-test.txt  # Installs all testing requirements.
pip install -e .  # Installs FedJAX from the current directory in editable mode.
Add the FedJAX repo as an upstream remote, so you can use it to sync your changes.
git remote add upstream http://www.github.com/google/fedjax
Create a branch where you will develop from:
git checkout -b name-of-change
Then implement your changes using your favorite editor. If you are adding a new algorithm, please add unit tests and an associated binary in the experiments folder. The binary should use the EMNIST dataset with a convolutional neural model example and have reasonable default hyperparameters that ideally reproduce results from a published paper. We strongly recommend using fedjax.for_each_client in your algorithm implementations for computational efficiency.
Make sure the tests pass by running the following command from the top of the repository:
pytest -n auto -q \
-k "not SubsetFederatedDataTest and not SQLiteFederatedDataTest and not ForEachClientPmapTest and not DownloadsTest and not CheckpointTest and not LoggingTest" \
fedjax --ignore=fedjax/legacy/
-q reduces verbosity, -k selects/deselects specific tests, and --ignore=fedjax/legacy/ skips the entire fedjax.legacy module. If there are errors or failures, you can run those specific tests using the commands below to see more focused details.
pytest -n auto tests/
If you know the specific test file that covers your changes, you can limit the tests to that; for example:
# Run all tests in algorithms
pytest -n auto fedjax/algorithms
# Run only fedjax/core/metrics_test.py
pytest -n auto fedjax/core/metrics_test.py
Once you are satisfied with your change, create a commit as follows (see how to write a commit message):
git add file1.py file2.py ...
git commit -m "Your commit message"
Then sync your code with the main repo:
git fetch upstream
git rebase upstream/main
Finally, push your commit on your development branch and create a remote branch in your fork that you can use to create a pull request from:
git push --set-upstream origin name-of-change
Create a pull request from the FedJAX repository and send it for review. Check the FedJAX pull request checklist for considerations when preparing your PR, and consult GitHub Help if you need more information on using pull requests.
FedJAX pull request checklist
As you prepare a FedJAX pull request, here are a few things to keep in mind:
Google contributor license agreement
Contributions to this project must be accompanied by a Google Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to https://cla.developers.google.com/ to see your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you’ve already submitted one (even if it was for a different project), you probably don’t need to do it again. If you’re not certain whether you’ve signed a CLA, you can open your PR and our friendly CI bot will check for you.
Single-change commits and pull requests
A git commit ought to be a self-contained, single change with a descriptive message. This helps with review and with identifying or reverting changes if issues are uncovered later on.
Pull requests typically comprise a single git commit. (In some cases, for
instance for large refactors or internal rewrites, they may contain several.)
In preparing a pull request for review, you may need to squash together
multiple commits. We ask that you do this prior to sending the PR for review if
possible. The git rebase -i
command might be useful to this end.
Linting and Type-checking
Please follow the Google Python style guide and check code quality with pylint, as described at https://google.github.io/styleguide/pyguide.html.
Full GitHub test suite
Your PR will automatically be run through a full test suite on GitHub CI, which covers a range of Python versions, dependency versions, and configuration options. It’s normal for these tests to turn up failures that you didn’t catch locally; to fix the issues you can push new commits to your branch.
Restricted test suite
Once your PR has been reviewed, a FedJAX maintainer will mark it as Pull Ready. This
will trigger a larger set of tests, including tests on GPU and TPU backends that are
not available via standard GitHub CI. Detailed results of these tests are not publicly
viewable, but the FedJAX maintainer assigned to your PR will communicate with you regarding
any failures these might uncover; it's not uncommon, for example, that numerical tests
need different tolerances on TPU than on CPU.
Building from source
First, clone the FedJAX source code:
git clone https://github.com/google/fedjax
cd fedjax
Then install the fedjax
Python package:
pip install -e .
To upgrade to the latest version from GitHub, inside of the repository root, just run
git pull
You shouldn’t have to reinstall fedjax
because pip install -e
sets up
symbolic links from site-packages into the repository.
Running the tests
We created a simple run_tests.sh
script for running tests. See its comments for examples. Before creating a pull
request, we recommend running all the FedJAX tests (i.e. running run_tests.sh
with no arguments) to verify the correctness of a change.
Updating the docs
Install the requirements
pip install -r docs/requirements.txt
Then run
sphinx-autobuild -b html --watch . --open-browser docs docs/_build/html
sphinx-autobuild
will watch for file changes and auto build the HTML for you,
so all you’ll have to do is refresh the page. If you don’t want to use the auto
builder, you can just use:
sphinx-build -b html docs docs/_build/html
and then navigate to docs/_build/html/index.html
in your browser.
How to write code documentation
Our documentation is written in reStructuredText for Sphinx. This is a meta-language that is compiled into online documentation. For more details, see Sphinx's documentation.
We also rely heavily on sphinx.ext.autodoc
to convert docstrings in source to rst for Sphinx,
so it would be best to be familiar with its directives.
As a result, our docstrings adhere to a specific syntax that has to be kept in mind. Below we provide some guidelines.
How to use “code font”
When writing code font in a docstring, please use double backticks.
# This returns a ``str`` object.
How to create cross-references/links
It is possible to create cross-references to other classes, functions, and
methods. In the following, obj_type is either class, func, or meth.
# First method:
# <obj_type>:`path_to_obj`
# Second method:
# :<obj_type>:`description <path_to_obj>`
You can use the second method if the path_to_obj
is very long.
# Create a reference to class fedjax.experimental.model.Model.
# :class:`fedjax.experimental.model.Model`
# Create a reference to local function my_func.
# :func:`my_func`
# Create a reference "Module.apply()" to method fedjax.experimental.model.Model.apply_for_train.
# :meth:`Model.apply_for_train <fedjax.experimental.model.Model.apply_for_train>`
To create a hyperlink, use the following syntax:
# Note the double underscore at the end:
# `Link to Google <http://www.google.com>`__
You can also cross reference jax
documentation directly since we’ve added it via
sphinx.ext.intersphinx
in docs/conf.py
# :func:`jax.jit`
# Links to https://jax.readthedocs.io/en/latest/jax.html#jax.jit
How to write math
We’re using sphinx.ext.mathjax
.
We then use the math directive for block notation and the math role for inline notation.
# Blocked notation
# .. math::
# x + y
# Inline notation :math:`x + y`
Examples
Take a look at docs/fedjax.metrics.rst
to get a good idea of what to do.
Update notebooks
It is easiest to edit the notebooks in Jupyter or in Colab.
To edit notebooks in the Colab interface,
open http://colab.research.google.com and Upload ipynb
from your local repo.
Update it as needed, Run all cells
then Download ipynb
to your local repo.
You may want to test that it executes properly, using sphinx-build
as explained above.
We recommend making changes this way to avoid introducing format errors into the .ipynb
files.
In the future, we may build and re-execute the notebooks as part of the
Read the docs build.
However, for now, we exclude all notebooks from the build due to long durations (downloading dataset files, expensive model training, etc.).
See exclude_patterns
in conf.py
fedjax core
FedJAX API.
Subpackages
fedjax.metrics
A small library for working with evaluation metrics such as accuracy.
Stats
Stat: Stat keeps some statistic, along with operations over them.
MeanStat: Statistic for weighted mean calculation.
SumStat: Statistic for summing values.
Metrics
Metric: Metric is the conceptual metric (like accuracy).
CrossEntropyLoss: Metric for cross entropy loss.
Accuracy: Metric for accuracy.
TopKAccuracy: Metric for top k accuracy.
SequenceTokenCrossEntropyLoss: Metric for token cross entropy loss for a sequence example.
SequenceCrossEntropyLoss: Metric for total cross entropy loss for a sequence example.
SequenceTokenAccuracy: Metric for token accuracy for a sequence example.
SequenceTokenTopKAccuracy: Metric for token top k accuracy for a sequence example.
SequenceTokenCount: Metric for count of non masked tokens for a sequence example.
SequenceCount: Metric for count of non masked sequences.
SequenceTruncationRate: Metric for truncation rate for a sequence example.
SequenceTokenOOVRate: Metric for out-of-vocabulary (OOV) rate for a sequence example.
SequenceLength: Metric for length for a sequence example.
PerDomainMetric: Turns a base metric into one that groups results by domain.
ConfusionMatrix: Metric for making a Confusion Matrix.
Miscellaneous
unreduced_cross_entropy_loss: Returns unreduced cross entropy loss.
evaluate_batch: Evaluates a batch using a metric.
Quick Overview
To evaluate model predictions, use a Metric
object such as Accuracy
.
We recommend fedjax.core.models.evaluate_model()
in most scenarios,
which runs model prediction and evaluation on batches of N examples at a time
for greater computational efficiency:
# Mock out Model.
model = fedjax.Model(
init=lambda _: None, # Unused.
apply_for_train=lambda *_: None, # Unused.
apply_for_eval=lambda _, batch: batch.get('pred'),
train_loss=lambda *_: None, # Unused.
eval_metrics={'accuracy': metrics.Accuracy()})
params = None # Unused.
batches = [{'y': np.array([1, 0]),
'pred': np.array([[1.2, 0.4], [2.3, 0.1]])},
{'y': np.array([1, 1]),
'pred': np.array([[0.3, 3.2], [2.1, 4.3]])}]
results = fedjax.evaluate_model(model, params, batches)
print(results)
# {'accuracy': 0.75}
A Metric
object has 2 methods:
zero(): Initial value for accumulating the statistic for this metric.
evaluate_example(): Returns the statistic from evaluating a single example, given the training example and the model prediction.
Most Metric subclasses follow the following convention for convenience:
example is a dict-like object from str to jnp.ndarray.
prediction is either a single jnp.ndarray, or a dict-like object from str to jnp.ndarray.
Conceptually, we can also use a simple for loop to evaluate a collection of examples and model predictions:
# By default, the `Accuracy` metric treats `example['y']` as the true label,
# and `prediction` as a single `jnp.ndarray` of class scores.
metric = Accuracy()
stat = metric.zero()
# We are iterating over individual examples, not batches.
for example, prediction in [({'y': jnp.array(1)}, jnp.array([0., 1.])),
({'y': jnp.array(0)}, jnp.array([1., 0.])),
({'y': jnp.array(1)}, jnp.array([1., 0.])),
({'y': jnp.array(0)}, jnp.array([2., 0.]))]:
stat = stat.merge(metric.evaluate_example(example, prediction))
print(stat.result())
# 0.75
In practice, for greater computational efficiency, we run model prediction not
on a single example, but a batch of N examples at a time.
fedjax.core.models.evaluate_model()
provides a simple way to do so.
Under the hood, it calls evaluate_batch():
metric = Accuracy()
stat = metric.zero()
# We are iterating over batches.
for batch_example, batch_prediction in [
({'y': jnp.array([1, 0])}, jnp.array([[0., 1.], [1., 0.]])),
({'y': jnp.array([1, 0])}, jnp.array([[1., 0.], [2., 0.]]))]:
stat = stat.merge(evaluate_batch(metric, batch_example, batch_prediction))
print(stat.result())
# 0.75
Under the hood
For most users, it is sufficient to know how to use existing Metric
subclasses such as Accuracy
with
fedjax.core.models.evaluate_model()
.
This section is intended for those who would like to write new metrics.
From algebraic structures to Metric
and Stat
There are 2 abstractions in this library, Metric and Stat.
Before going into details of these classes, let’s first consider a few abstract
properties related to evaluation metrics, using accuracy as an example.
When evaluating accuracy on a dataset, we wish to know the proportion of examples that are correctly predicted by the model. Because a dataset might be too large to fit into memory, we need to divide the work by partitioning the dataset into smaller subsets, evaluate each separately, and finally somehow combine the results. Assuming the subsets can be of different sizes, although the accuracy value is a single number, we cannot just average the accuracy values from each partition to arrive at the overall accuracy. Instead, we need 2 numbers from each subset:
The number of examples in this subset,
How many of them are correctly predicted.
We call these two numbers from each subset a statistic. The domain (the set of possible values) of the statistic in the case of accuracy is the set of pairs of non-negative numbers (number of examples, number of correct predictions).
With the numbers of examples and correct predictions from 2 disjoint subsets, we add the numbers up to get the number of examples and correct predictions for the union of the 2 subsets. We call this operation from 2 statistics into 1 a \(merge\) operation.
Let \(f(S)\) be the function that gives us the statistic from a subset of examples. It is easy to see for two disjoint subsets \(A\) and \(B\) , \(merge(f(A), f(B))\) should be equal to \(f(A ∪ B)\) . If no such \(merge\) exists, we cannot evaluate the dataset by partitioning the work. This requirement alone implies the domain of a statistic, and the \(merge\) operation forms a specific algebraic structure (a commutative monoid).
- \(I := f(empty set)\) is the one and only identity element w.r.t. \(merge\) (i.e. \(merge(I, x) == merge(x, I) == x\)).
- \(merge()\) is commutative and associative.
Further, we can see \(f(S)\) can be defined just knowing two types of values:
\(f(empty set)\), i.e. \(I\) ;
\(f({x})\) for any single example \(x\) .
For any other subset \(S\) , we can derive the value of \(f(S)\) using
these values and \(merge\) . Metric
is simply the \(f(S)\) function
above, defined in 2 corresponding parts:
Metric.zero() is \(f(empty set)\).
Metric.evaluate_example() is \(f({x})\) for a single example.
On the other hand, Stat
stores a single statistic, a merge()
method for combining two, and a result()
method for producing the
final metric value.
To implement Accuracy
as a subclass of Metric
, we first need to
know what Stat
to use. In this case, the statistic domain and merge is
implemented by a MeanStat
. A MeanStat
holds two values:
accum is the weighted sum of values, i.e. the number of correct predictions in the case of accuracy.
weight is the sum of weights, i.e. the number of examples in the case of accuracy.
merge()
adds up the respective accum
and
weight
from two MeanStat
objects.
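As a small illustration with the accuracy statistics described above:
stat_a = MeanStat.new(accum=3, weight=4)  # Subset A: 3 out of 4 correct.
stat_b = MeanStat.new(accum=3, weight=6)  # Subset B: 3 out of 6 correct.
print(stat_a.merge(stat_b).result())      # (3 + 3) / (4 + 6) = 0.6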
Sometimes, a new Stat
subclass is necessary. In that case, it is very
important to make sure the implementation has a clear definition of the domain,
and the merge()
operation adheres to the properties regarding identity
element, commutativity, and associativity (e.g. if we unknowingly allow pairs of
\((x, 0)\) for \(x != 0\) into the domain of a MeanStat
,
\(merge((x, 0), (a, b))\) will produce a statistic that leads to incorrect
final metric values, i.e. \((a+x)/b\) , instead of \(a/b\) ).
Batching Stats
In most cases, the final value of an evaluation is simply a scalar, and the
corresponding statistic is also a tuple of a few scalar values. However, for the
same reason why jax.vmap()
is a lot more efficient than a for loop, it is a
lot more efficient to store multiple Stat
values as a Stat
of arrays,
instead of a list of Stat
objects. Thus instead of a list
[MeanStat(1, 2), MeanStat(3, 4), MeanStat(5, 6)]
(call these “rank 0” Stat
s), we can batch the 3 statitstics as a single
MeanStat(jnp.array([1, 3, 5]), jnp.array([2, 4, 6]))
.
A Stat
object holding a single statistic is a “rank 0” Stat
.
A Stat
object holding a vector of statistics is a “rank 1” Stat
.
Similarly, a Stat
object may also hold a matrix, a 3D array, etc, of statistics.
These are higher rank Stat
s. In most cases, the rank 1 implementation of
merge()
and result() automatically generalizes to higher ranks as
elementwise operations.
In the end, we want just 1 final metric value instead of a length 3 vector, or a
2x2 matrix, of metric values. To finally go back to a single statistic (a rank 0 Stat), we
need to merge()
statistics stored in these arrays. Each Stat
subclass
provides a reduce()
method to do just that. The combination of jax.vmap()
over Metric.evaluate_example()
, and Stat.reduce()
, is how we get an
efficient evaluate_batch()
function (of course, the real evaluate_batch()
is jax.jit()
'd so that the same jax.vmap()
transformation etc. does not
need to happen over and over):
metric = Accuracy()
stat = metric.zero()
# We are iterating over batches.
for batch_example, batch_prediction in [
({'y': jnp.array([1, 0])}, jnp.array([[0., 1.], [1., 0.]])),
({'y': jnp.array([1, 0])}, jnp.array([[1., 0.], [2., 0.]]))]:
# Get a batch of statistics as a single Stat object.
batch_stat = jax.vmap(metric.evaluate_example)(batch_example,
batch_prediction)
# Merge the reduced single statistic onto the accumulator.
stat = stat.merge(batch_stat.reduce())
print(stat.result())
# 0.75
Being able to batch Stat
s also allows us to do other interesting things, for
example,
- evaluate_batch() accepts an optional per-example mask so it can work on padded batches.
- We can define a PerDomainMetric metric for any base metric so that we can get accuracy where examples are partitioned by a domain id.
Creating a new Metric
Most likely, a new metric will just return a MeanStat
or a SumStat
.
If that's the case, simply follow the guidelines in Metric's class docstring.
If a new Stat is necessary, follow the guidelines in Stat's docstring.
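For instance, here is a minimal sketch of a custom metric built on MeanStat; the SquaredError class and its behavior are illustrative assumptions, not part of the fedjax API:
import jax.numpy as jnp
from fedjax import metrics

class SquaredError(metrics.Metric):
  """Hypothetical metric: squared error between example['y'] and a scalar prediction."""

  def zero(self):
    # Identity element of the MeanStat domain.
    return metrics.MeanStat.new(0., 0.)

  def evaluate_example(self, example, prediction):
    squared_error = (example['y'] - prediction) ** 2
    return metrics.MeanStat.new(squared_error, 1.)

metric = SquaredError()
stat = metric.zero()
stat = stat.merge(metric.evaluate_example({'y': jnp.array(1.0)}, jnp.array(1.5)))
print(stat.result())  # 0.25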
- class fedjax.metrics.Stat
Stat keeps some statistic, along with operations over them.
Most users will only need to interact with a Stat object via result().
For those who need to create new metrics, please first read the Under the hood section of the module docstring.
Most Stat's domain (the set of possible statistic values) has constraints; it is thus usually a good practice to offer and use factory methods to construct new Stat objects instead of directly assigning the fields.
To work with various jax constructs, a concrete Stat should be a PyTree. This is easily achieved with fedjax.dataclass.
A Stat may hold either a single statistic (a rank 0 Stat), or an array of statistics (a higher rank Stat).
result() and merge() only need to work on a rank 0 Stat.
reduce() only needs to work on a higher rank Stat.
- abstract merge(other)
Merges two Stat objects into a new Stat with merged statistics.
- abstract reduce(axis=0)
Reduces a higher rank statistic along a given axis.
See the class docstring for details.
- Parameters:
axis (
Optional
[int
]) – An integer axis index, orNone
.- Return type:
- Returns:
A new Stat object of the same type.
- class fedjax.metrics.MeanStat(accum, weight)
Bases:
Stat
Statistic for weighted mean calculation.
Prefer using the
MeanStat.new()
factory method instead of directly assigning to fields.
Example:
stat_0 = MeanStat.new(accum=1, weight=2)
stat_1 = MeanStat.new(accum=2, weight=3)
merged_stat = stat_0.merge(stat_1)
print(merged_stat)
# MeanStat(accum=3, weight=5) => 0.6
stat = MeanStat.new(jnp.array([1, 2, 4]), jnp.array([1, 1, 0]))
reduced_stat = stat.reduce()
print(reduced_stat)
# MeanStat(accum=3, weight=2) => 1.5
- classmethod new(accum, weight)
Creates a sanitized MeanStat.
The domain of a weighted mean statistic is:
\[\{(0, 0)\} ∪ \{(a, b) | a >= 0, b > 0\}\]
new() sanitizes values outside the domain into the identity (zeros).
- Parameters:
accum – A value convertible to
jnp.ndarray
.weight – A value convertible to
jnp.ndarray
.
- Return type:
- Returns:
The sanitized MeanStat.
- class fedjax.metrics.SumStat(accum)
Bases:
Stat
Statistic for summing values.
Example:
stat_0 = SumStat.new(accum=1) stat_1 = SumStat.new(accum=2) merged_stat = stat_0.merge(stat_1) print(merged_stat) # SumStat(accum=3) => 3 stat = SumStat.new(jnp.array([1, 2, 1])) reduced_stat = stat.reduce() print(reduced_stat) # SumStat(accum=4) => 4
- class fedjax.metrics.Metric
Metric is the conceptual metric (like accuracy).
It defines two methods:
evaluate_example() evaluates a single example, and returns a Stat object.
zero() returns the identity value for what evaluate_example() returns.
Given a
Metric
objectm
, letu = m.zero()
v = m.evaluate_example(...)
We require that
type(u) == type(v).
u.merge(v) == v.merge(u) == v.
Components of u have the same shape as the counterparts in v.
- abstract evaluate_example(example, prediction)
Evaluates a single example.
e.g. for accuracy:
MeanStat.new(num_correct, num_total)
- class fedjax.metrics.CrossEntropyLoss(target_key='y', pred_key=None)
Bases:
Metric
Metric for cross entropy loss.
Example:
example = {'y': jnp.array(1)} prediction = jnp.array([1.2, 0.4]) metric = CrossEntropyLoss() print(metric.evaluate_example(example, prediction)) # MeanStat(accum=1.1711007, weight=1) => 1.1711007
- target_key
Key name in
example
for target.- Type:
str
- pred_key
Key name in
prediction
for unnormalized model output pred.- Type:
Optional[str]
- class fedjax.metrics.Accuracy(target_key='y', pred_key=None)
Bases:
Metric
Metric for accuracy.
Example:
example = {'y': jnp.array(2)} prediction = jnp.array([0, 0, 1]) metric = Accuracy() print(metric.evaluate_example(example, prediction)) # MeanStat(accum=1, weight=1) => 1
- target_key
Key name in
example
for target.- Type:
str
- pred_key
Key name in
prediction
for unnormalized model output pred.- Type:
Optional[str]
- class fedjax.metrics.TopKAccuracy(k, target_key='y', pred_key=None)
Bases:
Metric
Metric for top k accuracy.
This metric computes the number of times where the correct class is among the top k classes predicted.
Example: top 3 accuracy
Dog => [Dog, Cat, Bird, Mouse, Penguin] ✓
Cat => [Bird, Mouse, Cat, Penguin, Dog] ✓
Dog => [Dog, Cat, Bird, Penguin, Mouse] ✓
Bird => [Bird, Cat, Mouse, Penguin, Dog] ✓
Cat => [Cat, Bird, Mouse, Dog, Penguin] ✓
Cat => [Cat, Mouse, Dog, Penguin, Bird] ✓
Mouse => [Penguin, Cat, Dog, Mouse, Bird] x
Penguin => [Dog, Mouse, Cat, Penguin, Bird] x
6 correct predictions in top 3 predicted classes / 8 total examples = .75 top 3 accuracy
Top k accuracy, also known as top n accuracy, is a useful metric when it comes to recommendations. One example would be the word recommendations on a virtual keyboard where three suggested words are displayed.
For k=1, we strongly recommend using
Accuracy
to avoid an unnecessary argsort. k < 1 will return 0, and k >= num_classes will return 1.
If two or more classes have the same prediction, the classes will be considered in order of lowest to highest indices.
Example:
example = {'y': jnp.array(2)} prediction = jnp.array([0, 0.5, 0.2]) metric = TopKAccuracy(k=2) print(metric.evaluate_example(example, prediction)) # MeanStat(accum=1, weight=1) => 1
- k
Number of top elements to look at for computing accuracy.
- Type:
int
- target_key
Key name in
example
for target.- Type:
str
- pred_key
Key name in
prediction
for unnormalized model output pred.- Type:
Optional[str]
- class fedjax.metrics.SequenceTokenCrossEntropyLoss(target_key='y', pred_key=None, masked_target_values=(0,), per_position=False)
Bases:
Metric
Metric for token cross entropy loss for a sequence example.
Example:
example = {'y': jnp.array([1, 0, 1])} prediction = jnp.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]]) metric = SequenceTokenCrossEntropyLoss() print(metric.evaluate_example(example, prediction)) # MeanStat(accum=1.2246635, weight=2) => 0.61233175 per_position_metric = SequenceTokenCrossEntropyLoss(per_position=True) print(per_position_metric.evaluate_example(example, prediction)) # MeanStat(accum=[1.1711007, 0., 0.05356275], weight=[1., 0., 1.]) => [1.1711007, 0., 0.05356275]
- target_key
Key name in
example
for target.- Type:
str
- pred_key
Key name in
prediction
for unnormalized model output pred.- Type:
Optional[str]
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type:
Tuple[int, …]
- per_position
Whether to keep output statistic per position or sum across positions for the entire sequence.
- Type:
bool
- class fedjax.metrics.SequenceCrossEntropyLoss(target_key='y', pred_key=None, masked_target_values=(0,))
Bases:
Metric
Metric for total cross entropy loss for a sequence example.
Example:
example = {'y': jnp.array([1, 0, 1])} prediction = jnp.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]]) metric = SequenceCrossEntropyLoss() print(metric.evaluate_example(example, prediction)) # MeanStat(accum=1.2246635, weight=1) => 1.2246635
- target_key
Key name in
example
for target.- Type:
str
- pred_key
Key name in
prediction
for unnormalized model output pred.- Type:
Optional[str]
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type:
Tuple[int, …]
- class fedjax.metrics.SequenceTokenAccuracy(target_key='y', pred_key=None, masked_target_values=(0,), logits_mask=None, per_position=False)
Bases:
Metric
Metric for token accuracy for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 2, 1, 3, 0])} # prediction = [1, 0, 2, 1, 3, 0]. prediction = jnp.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]]) logits_mask = (0., 0., 0., jnp.NINF) metric = SequenceTokenAccuracy(logits_mask=logits_mask) print(metric.evaluate_example(example, prediction)) # MeanStat(accum=3, weight=5) => 0.6 per_position_metric = SequenceTokenAccuracy(logits_mask=logits_mask, per_position=True) print(per_position_metric.evaluate_example(example, prediction)) # MeanStat(accum=[1., 0., 1., 1., 0., 0.], weight=[1., 1., 1., 1., 1., 0.]) => [1., 0., 1., 1., 0., 0.]
- target_key
Key name in
example
for target.- Type:
str
- pred_key
Key name in
prediction
for unnormalized model output pred.- Type:
Optional[str]
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type:
Tuple[int, …]
- logits_mask
Mask of shape [num_classes] to be applied for preds. This is typically used to discount predictions for out-of-vocabulary tokens.
- Type:
Optional[Tuple[float, …]]
- per_position
Whether to keep output statistic per position or sum across positions for the entire sequence.
- Type:
bool
- class fedjax.metrics.SequenceTokenTopKAccuracy(k, target_key='y', pred_key=None, masked_target_values=(0,), logits_mask=None, per_position=False)
Bases:
Metric
Metric for token top k accuracy for a sequence example.
For more information on the top k accuracy metric, refer to the
TopKAccuracy
docstring.
Example:
example = {'y': jnp.array([1, 2, 2, 1, 3, 0])} prediction = jnp.array([[0, 1, 0.5, 0], [1, 0.5, 0, 0], [0.8, 0, 0.7, 0], [0.5, 1, 0, 0], [0, 0.5, 0, 1], [0.5, 0, 0.9, 0]]) logits_mask = (0., 0., 0., jnp.NINF) metric = SequenceTokenTopKAccuracy(k=2, logits_mask=logits_mask) print(metric.evaluate_example(example, prediction)) # MeanStat(accum=3, weight=5) => 0.6 per_position_metric = SequenceTokenTopKAccuracy(k=2, logits_mask=logits_mask, per_position=True) print(per_position_metric.evaluate_example(example, prediction)) # MeanStat(accum=[1., 0., 1., 1., 0., 0.], weight=[1., 1., 1., 1., 1., 0.]) => [1., 0., 1., 1., 0., 0.]
- k
Number of top elements to look at for computing accuracy.
- Type:
int
- target_key
Key name in
example
for target.- Type:
str
- pred_key
Key name in
prediction
for unnormalized model output pred.- Type:
Optional[str]
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type:
Tuple[int, …]
- logits_mask
Mask of shape [num_classes] to be applied for preds. This is typically used to discount predictions for out-of-vocabulary tokens.
- Type:
Optional[Tuple[float, …]]
- per_position
Whether to keep output statistic per position or sum across positions for the entire sequence.
- Type:
bool
- class fedjax.metrics.SequenceTokenCount(target_key='y', masked_target_values=(0,))
Bases:
Metric
Metric for count of non masked tokens for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])} prediction = jnp.array([]) # Unused. metric = SequenceTokenCount(masked_target_values=(0, 2)) print(metric.evaluate_example(example, prediction)) # SumStat(accum=3) => 3
- target_key
Key name in
example
for target.- Type:
str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type:
Tuple[int, …]
- class fedjax.metrics.SequenceCount(target_key='y', masked_target_values=(0,))
Bases:
Metric
Metric for count of non masked sequences.
Example:
example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])} empty_example = {'y': jnp.array([0, 0, 0, 0, 0, 0, 0])} prediction = jnp.array([]) # Unused. metric = metrics.SequenceCount(masked_target_values=(0, 2)) print(metric.evaluate_example(example, prediction)) # SumStat(accum=1) print(metric.evaluate_example(empty_example, prediction)) # SumStat(accum=0)
- target_key
Key name in
example
for target.- Type:
str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type:
Tuple[int, …]
- class fedjax.metrics.SequenceTruncationRate(eos_target_value, target_key='y', masked_target_values=(0,))
Bases:
Metric
Metric for truncation rate for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 2, 3, 3, 3, 4])} truncated_example = {'y': jnp.array([1, 2, 2, 3, 3, 3, 3])} prediction = jnp.array([]) # Unused. metric = SequenceTruncationRate(eos_target_value=4) print(metric.evaluate_example(example, prediction)) # MeanStat(accum=0, weight=1) => 0 print(metric.evaluate_example(truncated_example, prediction)) # MeanStat(accum=1, weight=1) => 1
- eos_target_value
Target value denoting end of sequence. Truncated sequences will not have this value.
- Type:
int
- target_key
Key name in
example
for target.- Type:
str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type:
Tuple[int, …]
- class fedjax.metrics.SequenceTokenOOVRate(oov_target_values, target_key='y', masked_target_values=(0,), per_position=False)
Bases:
Metric
Metric for out-of-vocabulary (OOV) rate for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])} prediction = jnp.array([]) # Unused. metric = SequenceTokenOOVRate(oov_target_values=(2,)) print(metric.evaluate_example(example, prediction)) # MeanStat(accum=2, weight=5) => 0.4 per_position_metric = SequenceTokenOOVRate(oov_target_values=(2,), per_position=True) print(per_position_metric.evaluate_example(example, prediction)) # MeanStat(accum=[0., 1., 1., 0., 0., 0., 0.], weight=[1., 1., 1., 1., 1., 0., 0.]) => [0. 1. 1. 0. 0. 0. 0.]
- oov_target_values
Target values denoting out-of-vocabulary values.
- Type:
Tuple[int, …]
- target_key
Key name in
example
for target.- Type:
str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type:
Tuple[int, …]
- per_position
Whether to keep output statistic per position or sum across positions for the entire sequence.
- Type:
bool
- class fedjax.metrics.SequenceLength(target_key='y', masked_target_values=(0,))
Bases:
Metric
Metric for length for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 3, 4, 0, 0])} prediction = jnp.array([]) # Unused. metric = SequenceLength() print(metric.evaluate_example(example, prediction)) # MeanStat(accum=4, weight=1) => 4
- target_key
Key name in
example
for target.- Type:
str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type:
Tuple[int, …]
- class fedjax.metrics.PerDomainMetric(base, num_domains, domain_id_key='domain_id')
Bases:
Metric
Turns a base metric into one that groups results by domain.
This is useful in algorithms such as AgnosticFedAvg.
example is expected to contain a feature named domain_id_key, which stores the integer domain id in [0, num_domains). PerDomainMetric accumulates base's Stat within each domain. If the base Metric returns a Stat whose result is of shape X, then the Stat returned by PerDomainMetric will produce a result of shape (num_domains,) + X. See Batching Stats for the higher rank Stat mechanism enabling this.
Example:
per_domain_accuracy = PerDomainMetric(metrics.Accuracy(), num_domains=3)
batch_example = {
    'domain_id': jnp.array([0, 0, 1, 2]),
    'y': jnp.array([0, 1, 0, 1])
}
batch_prediction = jnp.array([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])
print(
    evaluate_batch(per_domain_accuracy, batch_example,
                   batch_prediction).result())
# [0.5 0. 1. ]
- class fedjax.metrics.ConfusionMatrix(num_classes, target_key='y', pred_key=None)
Bases:
Metric
Metric for making a Confusion Matrix.
A confusion matrix is an nxn matrix often used to describe the performance of a classification model on a set of test data for which the true values are known. The model’s predictions are represented through the columns, and the known data values through the rows. This allows one to view in which areas the model is doing well, as well as where there is room for improvement. For each row in the confusion matrix, if there are a lot of numbers outside of the main diagonal, the model is not doing so well in respect to when it is supposed to output that row’s relative output class.
Theoretical Example:
             Predicted P   Predicted N
Actual P     TP            FN
Actual N     FP            TN
This is for a binary classification model but the same concept applies to any model with n outputs. Notice that the TPs and TNs will always lie in the main diagonal of the matrix.
Example:
example = {'y': jnp.array(2)} prediction = jnp.array([0., 1., 0.]) metric = ConfusionMatrix(num_classes=3) print(metric.evaluate_example(example, prediction)) # SumStat(accum=DeviceArray([[0., 0., 0.], # [0., 0., 0.], # [0., 1., 0.]], dtype=float32)) => [[0. 0. 0.] # [0. 0. 0.] # [0. 1. 0.]]
- target_key
Key name in
example
for target.- Type:
str
- pred_key
Key name in
prediction
for unnormalized model output pred.- Type:
Optional[str]
- num_classes
Number of output classes of the model. Used to generate a matrix of shape [num_classes, num_classes].
- Type:
int
- fedjax.metrics.unreduced_cross_entropy_loss(targets, preds, is_sparse_targets=True)
Returns unreduced cross entropy loss.
- Return type:
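A small usage sketch (the printed values are approximate):
import jax.numpy as jnp
from fedjax import metrics

targets = jnp.array([1, 0])
preds = jnp.array([[1.2, 0.4], [2.3, 0.1]])
# One loss value per example, approximately [1.171, 0.105].
print(metrics.unreduced_cross_entropy_loss(targets, preds))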
fedjax.optimizers
Lightweight library for working with optimizers.
Optimizer: Wraps different optimizer libraries in a common interface.
create_optimizer_from_optax: Creates optimizer from optax gradient transformation chain.
ignore_grads_haiku: Modifies optimizer to ignore gradients for non_trainable_names.
adagrad: The Adagrad optimizer.
adam: The classic Adam optimizer.
rmsprop: A flexible RMSProp optimizer.
sgd: A canonical Stochastic Gradient Descent optimizer.
- class fedjax.optimizers.Optimizer(init, apply)
Wraps different optimizer libraries in a common interface.
Works with optax.
The expected usage of Optimizer is as follows:
# One step of SGD.
params = {'w': jnp.array([1, 1, 1])}
grads = {'w': jnp.array([2, 3, 4])}
optimizer = fedjax.optimizers.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
opt_state, params = optimizer.apply(grads, opt_state, params)
print(params)
# {'w': DeviceArray([0.8, 0.7, 0.6], dtype=float32)}
- init
Initializes (possibly empty) PyTree of statistics (optimizer state) given the input model parameters.
- Type:
Callable[[Any], Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]
- apply
Transforms and applies the input gradients to update the optimizer state and model parameters.
- Type:
Callable[[Any, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], Any], Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], Any]]
- fedjax.optimizers.create_optimizer_from_optax(opt)
Creates optimizer from optax gradient transformation chain.
- Return type:
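A minimal sketch, assuming optax is installed:
import optax
import fedjax

# Wrap an optax gradient transformation chain as a fedjax Optimizer.
optimizer = fedjax.optimizers.create_optimizer_from_optax(
    optax.chain(optax.clip_by_global_norm(1.0),
                optax.sgd(learning_rate=0.1)))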
- fedjax.optimizers.ignore_grads_haiku(optimizer, non_trainable_names)
Modifies
optimizer
to ignore gradients fornon_trainable_names
.Non-trainable parameters will have their values set to
None
when passed as input into the Optimizer to prevent any updates.
NOTE: This will only work with models implemented in haiku.
- Parameters:
optimizer (
Optimizer
) – Base Optimizer.non_trainable_names (
List
[Tuple
[str
,str
]]) – List of tuples of haiku module names and names of given entries in the module data bundle (e.g. parameter name). This list of names will be used to select the non-trainable parameters.
- Return type:
- Returns:
Optimizer that will ignore gradients for the non-trainable parameters.
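A hypothetical sketch; the module name 'linear' and the entry names 'w' and 'b' are placeholders for names in your haiku model's parameter bundle:
base_optimizer = fedjax.optimizers.sgd(learning_rate=0.1)
# Freeze the parameters of a haiku module named 'linear'.
optimizer = fedjax.optimizers.ignore_grads_haiku(
    base_optimizer, non_trainable_names=[('linear', 'w'), ('linear', 'b')])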
- fedjax.optimizers.adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-06)
The Adagrad optimizer.
Adagrad is an algorithm for gradient based optimisation that anneals the learning rate for each parameter during the course of training.
WARNING: Adagrad’s main limit is the monotonic accumulation of squared gradients in the denominator: since all terms are >0, the sum keeps growing during training and the learning rate eventually becomes vanishingly small.
References
[Duchi et al, 2011](https://jmlr.org/papers/v12/duchi11a.html)
- Parameters:
learning_rate (
Union
[float
,Callable
[[Union
[Array
,ndarray
,bool_
,number
,float
,int
]],Union
[Array
,ndarray
,bool_
,number
,float
,int
]]]) – This is a fixed global scaling factor.initial_accumulator_value (
float
) – Initialisation for the accumulator.eps (
float
) – A small constant applied to denominator inside of the square root (as in RMSProp) to avoid dividing by zero when rescaling.
- Return type:
- Returns:
The corresponding Optimizer.
- fedjax.optimizers.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0)
The classic Adam optimizer.
Adam is an SGD variant with learning rate adaptation. The learning_rate used for each weight is computed from estimates of first- and second-order moments of the gradients (using suitable exponential moving averages).
References
[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
- Parameters:
learning_rate (
Union
[float
,Callable
[[Union
[Array
,ndarray
,bool_
,number
,float
,int
]],Union
[Array
,ndarray
,bool_
,number
,float
,int
]]]) – This is a fixed global scaling factor.b1 (
float
) – The exponential decay rate to track the first moment of past gradients.b2 (
float
) – The exponential decay rate to track the second moment of past gradients.eps (
float
) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.eps_root (
float
) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.
- Return type:
- Returns:
The corresponding Optimizer.
- fedjax.optimizers.rmsprop(learning_rate, decay=0.9, eps=1e-08, initial_scale=0.0, centered=False, momentum=None, nesterov=False)
A flexible RMSProp optimizer.
RMSProp is an SGD variant with learning rate adaptation. The learning_rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy to configure RMSProp optimizer that can be used to switch between several of these variants.
References
[Tieleman and Hinton, 2012](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) [Graves, 2013](https://arxiv.org/abs/1308.0850)
- Parameters:
learning_rate (
Union
[float
,Callable
[[Union
[Array
,ndarray
,bool_
,number
,float
,int
]],Union
[Array
,ndarray
,bool_
,number
,float
,int
]]]) – This is a fixed global scaling factor.decay (
float
) – The decay used to track the magnitude of previous gradients.eps (
float
) – A small numerical constant to avoid dividing by zero when rescaling.initial_scale (
float
) – Initialisation of accumulators tracking the magnitude of previous updates. PyTorch uses 0, TF1 uses 1. When reproducing results from a paper, verify the value used by the authors.centered (
bool
) – Whether the second moment or the variance of the past gradients is used to rescale the latest gradients.momentum (
Optional
[float
]) – The decay rate used by the momentum term, when it is set to None, then momentum is not used at all.nesterov (
bool
) – Whether nesterov momentum is used.
- Return type:
- Returns:
The corresponding Optimizer.
- fedjax.optimizers.sgd(learning_rate, momentum=None, nesterov=False)
A canonical Stochastic Gradient Descent optimizer.
This implements stochastic gradient descent. It also includes support for momentum, and nesterov acceleration, as these are standard practice when using stochastic gradient descent to train deep neural networks.
References
[Sutskever et al, 2013](http://proceedings.mlr.press/v28/sutskever13.pdf)
- Parameters:
learning_rate (
Union
[float
,Callable
[[Union
[Array
,ndarray
,bool_
,number
,float
,int
]],Union
[Array
,ndarray
,bool_
,number
,float
,int
]]]) – This is a fixed global scaling factor.momentum (
Optional
[float
]) – The decay rate used by the momentum term, when it is set to None, then momentum is not used at all.nesterov (
bool
) – Whether nesterov momentum is used.
- Return type:
- Returns:
The corresponding Optimizer.
fedjax.tree_util
Utilities for working with tree-like container data structures.
In JAX, the term pytree refers to a tree-like structure built out of container-like Python objects. For more details, see https://jax.readthedocs.io/en/latest/pytrees.html.
- fedjax.tree_util.tree_mean(pytrees_and_weights)
Returns (weighted) mean of input trees and weights.
- Parameters:
pytrees_and_weights (
Iterable
[Tuple
[Any
,float
]]) – Iterable of tuples of pytrees and associated weights.- Return type:
Any
- Returns:
(Weighted) mean of input trees and weights.
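A small illustration:
import jax.numpy as jnp
import fedjax

pytree_a = {'w': jnp.array([1., 2.])}
pytree_b = {'w': jnp.array([3., 4.])}
# Weighted mean with weights 1 and 3: {'w': [2.5, 3.5]}.
print(fedjax.tree_util.tree_mean([(pytree_a, 1.), (pytree_b, 3.)]))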
- fedjax.tree_util.tree_weight(pytree, weight)
Weights tree leaves by weight.
- Return type:
Any
- fedjax.tree_util.tree_sum(pytrees)
Sums multiple trees together.
- Return type:
Any
- fedjax.tree_util.tree_add(left, right)
Adds two trees together.
- Return type:
Any
- fedjax.tree_util.tree_zeros_like(pytree)
Creates a tree with zeros with same structure as the input.
- Return type:
Any
- fedjax.tree_util.tree_inverse_weight(pytree, weight)
Weights tree leaves by
1 / weight
.- Return type:
Any
- fedjax.tree_util.tree_size(pytree)
Returns total size of all tree leaves.
- Return type:
int
- fedjax.tree_util.tree_clip_by_global_norm(pytree, max_norm)
Clips a pytree of arrays using their global norm.
References
[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)
- Parameters:
pytree (
Any
) – A pytree to be potentially clipped.max_norm (
float
) – The maximum global norm for a pytree.
- Return type:
Any
- Returns:
A potentially clipped pytree.
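A small illustration:
import jax.numpy as jnp
import fedjax

grads = {'w': jnp.array([3., 4.])}  # Global norm is 5.
# Scaled down to global norm 1.0: {'w': [0.6, 0.8]}.
print(fedjax.tree_util.tree_clip_by_global_norm(grads, max_norm=1.0))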
Federated algorithm
FederatedAlgorithm: Container for all federated algorithms.
Federated data
FederatedData: FederatedData interface for providing access to a federated dataset.
SubsetFederatedData: A simple wrapper over a concrete FederatedData for restricting to a subset of client ids.
SQLiteFederatedData: Federated dataset backed by SQLite.
InMemoryFederatedData: A simple wrapper over a concrete fedjax.FederatedData for small in memory datasets.
FederatedDataBuilder: FederatedDataBuilder interface.
SQLiteFederatedDataBuilder: Builds SQLite files from a python dictionary containing an arbitrary mapping of client IDs to NumPy examples.
ClientPreprocessor: A chain of preprocessing functions on all examples of a client dataset.
shuffle_repeat_batch_federated_data: Shuffle-repeat-batch all client datasets in a federated dataset for training a centralized baseline.
padded_batch_federated_data: Padded batch all client datasets, useful for evaluation on the entire federated dataset.
Client dataset
ClientDataset: In memory client dataset backed by numpy ndarrays.
BatchPreprocessor: A chain of preprocessing functions on batched examples.
buffered_shuffle_batch_client_datasets: Shuffles and batches examples from multiple client datasets.
padded_batch_client_datasets: Batches examples from multiple client datasets.
For each client
for_each_client: Creates a function which maps over clients.
for_each_client_backend: A context manager for switching to a given ForEachClientBackend in the current thread.
set_for_each_client_backend: Sets the for_each_client backend for the current thread.
Model
Model: Container class for models.
create_model_from_haiku: Creates Model after applying defaults and haiku specific preprocessing.
create_model_from_stax: Creates Model after applying defaults and stax specific preprocessing.
evaluate_model: Evaluates model for multiple batches and returns final results.
model_grad: A standard gradient function derived from a model and an optional regularizer.
model_per_example_loss: Convenience function for constructing a per-example loss function from a model.
evaluate_average_loss: Evaluates the average per example loss over multiple batches.
ModelEvaluator: Evaluates model for each client dataset, either using global params, or per client params.
AverageLossEvaluator: Evaluates average loss for each client dataset, either using global params, or per client params.
grad: A standard gradient function derived from per-example loss and an optional regularizer.
- class fedjax.FederatedAlgorithm(init, apply)[source]
Container for all federated algorithms.
FederatedAlgorithm defines the required methods that need to be implemented by all custom federated algorithms. Defining federated algorithms in this structure will allow implementations to work seamlessly with the convenience methods in the
fedjax.training
API, like checkpointing.
Example toy implementation:
# Federated algorithm that just counts total number of points across clients
# across rounds.
def count_federated_algorithm():

  def init(init_count):
    return {'count': init_count}

  def apply(state, clients):
    count = 0
    client_diagnostics = {}
    # Count sizes across clients.
    for client_id, client_dataset, _ in clients:
      # Summation across clients in one round.
      client_count = len(client_dataset)
      count += client_count
      client_diagnostics[client_id] = client_count
    # Summation across rounds.
    state = {'count': state['count'] + count}
    return state, client_diagnostics

  return FederatedAlgorithm(init, apply)

rng = jax.random.PRNGKey(0)  # Unused.
all_clients = [
    [
        (b'cid0', ClientDataset({'x': jnp.array([1, 2, 1, 2, 3, 4])}), rng),
        (b'cid1', ClientDataset({'x': jnp.array([1, 2, 3, 4, 5])}), rng),
        (b'cid2', ClientDataset({'x': jnp.array([1, 1, 2])}), rng),
    ],
    [
        (b'cid3', ClientDataset({'x': jnp.array([1, 2, 3, 4])}), rng),
        (b'cid4', ClientDataset({'x': jnp.array([1, 1, 2, 1, 2, 3])}), rng),
        (b'cid5', ClientDataset({'x': jnp.array([1, 2, 3, 4, 5, 6, 7])}), rng),
    ],
]
algorithm = count_federated_algorithm()
state = algorithm.init(0)
for round_num in range(2):
  state, client_diagnostics = algorithm.apply(state, all_clients[round_num])
  print(round_num, state)
  print(round_num, client_diagnostics)
# 0 {'count': 14}
# 0 {b'cid0': 6, b'cid1': 5, b'cid2': 3}
# 1 {'count': 31}
# 1 {b'cid3': 4, b'cid4': 6, b'cid5': 7}
- init
Initializes the
ServerState
. Typically, the input to this method will be the initial model Params
. This should only be run once at the beginning of training.- Type:
Callable[[…], Any]
- apply
Completes one round of federated training given an input
ServerState
and a sequence of tuples of client identifier, client dataset, and client rng. The output will be a new, updated ServerState
and accumulated per step results keyed by client identifier (e.g. train metrics).- Type:
Callable[[Any, Sequence[Tuple[bytes, fedjax.core.client_datasets.ClientDataset, jax.Array]]], Tuple[Any, Mapping[bytes, Any]]]
- class fedjax.FederatedData[source]
FederatedData interface for providing access to a federated dataset.
A FederatedData object serves as a mapping from client ids to client datasets and client metadata.
Access methods with better I/O efficiency
For large federated datasets, it is not feasible to load all client datasets into memory at once (whereas loading a single client dataset is assumed to be feasible). Different implementations exist for different on disk storage formats. Since sequential read is much faster than random read for most storage technologies, FederatedData provides two types of methods for accessing client datasets:
clients() and shuffled_clients() are sequential read friendly, and thus recommended whenever appropriate.
get_clients() requires random read, but prefetching is possible. This should be preferred over get_client().
get_client() is usually the slowest way of accessing client datasets, and is mostly intended for interactive exploration of a small number of clients.
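A minimal sketch, assuming fd is a fedjax.FederatedData instance and some_client_id is one of its client ids:
import itertools

# Sequential-read friendly iteration over the first few clients.
for client_id, client_dataset in itertools.islice(fd.clients(), 3):
  print(client_id, len(client_dataset))

# Random access to a single client; convenient for exploration but slower.
client_dataset = fd.get_client(some_client_id)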
Preprocessing
ClientDataset objects produced by FederatedData can hold a BatchPreprocessor, customizable via preprocess_batch(). Additionally, another "client" level ClientPreprocessor, customizable via preprocess_client(), can be used to apply transformations on examples from the entire client dataset before a ClientDataset is constructed.
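A minimal sketch, assuming fd is a fedjax.FederatedData instance whose examples contain an integer feature 'y':
# Register a batch-level preprocessing function; it is applied to every
# batch produced by the resulting ClientDataset objects.
fd = fd.preprocess_batch(lambda batch: {**batch, 'binary_y': batch['y'] > 0})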
- abstract client_ids()[source]
Returns an iterator of client ids as bytes.
There is no requirement on the order of iteration.
- Return type:
Iterator
[bytes
]
- abstract client_size(client_id)[source]
Returns the number of examples in a client dataset.
- Return type:
int
- abstract client_sizes()[source]
Returns an iterator of all (client id, client size) pairs.
This is often more efficient than making multiple client_size() calls. There is no requirement on the order of iteration.
- Return type:
Iterator[Tuple[bytes, int]]
- abstract clients()[source]
Iterates over clients in a deterministic order.
Implementations can choose whatever order makes iteration efficient.
- Return type:
Iterator[Tuple[bytes, ClientDataset]]
- abstract get_client(client_id)[source]
Gets a single client dataset.
Prefer clients(), shuffled_clients(), or get_clients() when possible.
- Parameters:
client_id (bytes) – Client id to load.
- Return type:
ClientDataset
- Returns:
The corresponding ClientDataset.
- abstract get_clients(client_ids)[source]
Gets multiple clients in order with one call.
Clients are returned in the order of client_ids.
- Parameters:
client_ids (Iterable[bytes]) – Client ids to load.
- Return type:
Iterator[Tuple[bytes, ClientDataset]]
- Returns:
Iterator.
- abstract num_clients()[source]
Returns the number of clients.
If it is too expensive or otherwise impossible to obtain the result, an implementation may raise an exception.
- Return type:
int
- abstract preprocess_batch(fn)[source]
Registers a preprocessing function to be called after batching in ClientDatasets.
- Return type:
FederatedData
- abstract preprocess_client(fn)[source]
Registers a preprocessing function to be called on all examples of a client before passing them to construct a ClientDataset.
- Return type:
FederatedData
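As a concrete illustration of the two preprocessing hooks, here is a minimal sketch, assuming train_fd is a fedjax.FederatedData whose examples hold 'x' pixels and 'y' labels (the feature names are assumptions for this example):
# Client-level: fn(client_id, examples) -> examples, applied to all examples of
# a client before the ClientDataset is constructed.
train_fd = train_fd.preprocess_client(
    lambda client_id, examples: {**examples, 'y': examples['y'].astype('int32')})

# Batch-level: fn(batch) -> batch, applied after batching.
train_fd = train_fd.preprocess_batch(
    lambda batch: {**batch, 'x': batch['x'] / 255.0})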
- abstract shuffled_clients(buffer_size, seed=None)[source]
Iterates over clients with a repeated buffered shuffling.
Shuffling should use a buffer size of at least buffer_size clients. The iteration should repeat forever, with usually a different order in each pass.
- Parameters:
buffer_size (int) – Buffer size for shuffling.
seed (Optional[int]) – Optional random number generator seed.
- Return type:
Iterator[Tuple[bytes, ClientDataset]]
- Returns:
Iterator.
- abstract slice(start=None, stop=None)[source]
Returns a new FederatedData restricted to client ids in the given range.
The returned FederatedData includes clients whose ids are,
Greater than or equal to start when start is not None;
Less than stop when stop is not None.
- Parameters:
start (Optional[bytes]) – Start of client id range.
stop (Optional[bytes]) – Stop of client id range.
- Return type:
FederatedData
- Returns:
FederatedData.
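A minimal sketch of restricting to a client id range, assuming train_fd is a fedjax.FederatedData whose client ids begin with hex characters (an assumption for this example):
# Clients whose ids are >= b'0' and < b'8', i.e. roughly the first half of the
# hex-prefixed id space.
half_fd = train_fd.slice(start=b'0', stop=b'8')
print(half_fd.num_clients())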
- class fedjax.SubsetFederatedData(base, client_ids, validate=True)[source]
Bases:
FederatedData
A simple wrapper over a concrete FederatedData for restricting to a subset of client ids.
This is useful when we wish to create a smaller FederatedData out of arbitrary client ids, where slicing is not possible.
- __init__(base, client_ids, validate=True)[source]
Initializes the subset federated dataset.
- Parameters:
base (FederatedData) – Base concrete FederatedData.
client_ids (Iterable[bytes]) – Client ids to include in the subset. All client ids must be in base.client_ids(); otherwise the behavior of SubsetFederatedData is undefined when validate=False.
validate – Whether to validate client ids.
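A minimal sketch, assuming train_fd is a fedjax.FederatedData:
import itertools
import fedjax

# Keep an arbitrary handful of clients that a contiguous id range could not express.
chosen_ids = list(itertools.islice(train_fd.client_ids(), 5))
small_fd = fedjax.SubsetFederatedData(train_fd, chosen_ids)
print(small_fd.num_clients())  # 5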
- class fedjax.SQLiteFederatedData(connection, parse_examples, start=None, stop=None, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]
Bases:
FederatedData
Federated dataset backed by SQLite.
The SQLite database should contain a table named “federated_data” created with the following command:
CREATE TABLE federated_data (
  client_id BLOB NOT NULL PRIMARY KEY,
  data BLOB NOT NULL,
  num_examples INTEGER NOT NULL
);
where,
client_id is the bytes client id.
data is the serialized client dataset examples.
num_examples is the number of examples in the client dataset.
By default we use zlib compressed msgpack blobs for data (see decompress_and_deserialize()).
- __init__(connection, parse_examples, start=None, stop=None, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]
- static new(path, parse_examples=<function decompress_and_deserialize>)[source]
Opens a federated dataset stored as an SQLite3 database.
- Parameters:
path (str) – Path to the SQLite database file.
parse_examples (Callable[[bytes], Mapping[str, ndarray]]) – Function for deserializing client dataset examples.
- Return type:
SQLiteFederatedData
- Returns:
SQLite3DataSource.
- class fedjax.InMemoryFederatedData(client_to_data_mapping, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]
Bases:
FederatedData
A simple wrapper over a concrete fedjax.FederatedData for small in memory datasets.
This is useful when we wish to create a smaller FederatedData that fits in memory. Here is a simple example to create a fedjax.InMemoryFederatedData,
client_a_data = {
    'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    'y': np.array([7, 8])
}
client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
client_to_data_mapping = {'a': client_a_data, 'b': client_b_data}
fedjax.InMemoryFederatedData(client_to_data_mapping)
- Returns:
A fedjax.InMemoryDataset corresponding to client_to_data_mapping.
- __init__(client_to_data_mapping, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]
Initializes the in memory federated dataset.
Data of each client is a mapping from feature names to numpy arrays. For example, for emnist image classification, {‘x’: X, ‘y’: y}, where X is a matrix of shape (num_data_points, 28, 28) and y is a matrix of shape (num_data_points).
- Parameters:
client_to_data_mapping (Mapping[bytes, Mapping[str, ndarray]]) – A mapping from client_id to data of each client.
preprocess_client (ClientPreprocessor) – federated_data.ClientPreprocessor to preprocess each client data.
preprocess_batch (BatchPreprocessor) – client_datasets.BatchPreprocessor to preprocess batch of data.
- class fedjax.FederatedDataBuilder[source]
FederatedDataBuilder interface.
To be implemented as a context manager for building file formats from pairs of client IDs and client NumPy examples.
It is relevant to note that the add method below does not specify any raised exceptions. One could imagine some formats where add can fail in some way: out-of-order or duplicate inputs, remote files and network failures, individual entries too big for a format, etc. In order to address this we let implementations throw whatever they see relevant and fit to their particular use cases. The same is relevant when it comes to the __init__, __enter__, and __exit__ methods, where implementations are left with the responsibility of raising exceptions as they see fit to their particular use cases. For example, if an invalid file path is passed, or there were any issues finalizing the builder, etc.
Example of end behavior:
with FederatedDataBuilder(path) as builder:
  builder.add(b'k1', np.array([b'v1'], dtype=np.object))
  builder.add(b'k2', np.array([b'v2'], dtype=np.object))
- class fedjax.SQLiteFederatedDataBuilder(path)[source]
Bases:
FederatedDataBuilder
Builds SQLite files from a python dictionary containing an arbitrary mapping of client IDs to NumPy examples.
- class fedjax.ClientPreprocessor(fns=())[source]
A chain of preprocessing functions on all examples of a client dataset.
This is very similar to fedjax.BatchPreprocessor, with the main difference being that ClientPreprocessor also takes client_id as input.
See the discussion in fedjax.BatchPreprocessor regarding when to use which.
- fedjax.shuffle_repeat_batch_federated_data(fd, batch_size, client_buffer_size, example_buffer_size, seed=None)[source]
Shuffle-repeat-batch all client datasets in a federated dataset for training a centralized baseline.
Shuffling is done using two levels of buffered shuffling, first at the client level, then at the example level.
This produces an infinite stream of batches. itertools.islice() can be used to cap the number of batches, if so desired.
- Parameters:
fd (FederatedData) – Federated dataset.
batch_size (int) – Desired batch size.
client_buffer_size (int) – Buffer size for client level shuffling.
example_buffer_size (int) – Buffer size for example level shuffling.
seed (Optional[int]) – Optional RNG seed.
- Yields:
Batches of preprocessed examples.
- Return type:
Iterator[Mapping[str, ndarray]]
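For example, a centralized baseline training loop might look like the following sketch, assuming train_fd is a fedjax.FederatedData; the buffer sizes and batch size are illustrative only:
import itertools
import fedjax

batches = fedjax.shuffle_repeat_batch_federated_data(
    train_fd,
    batch_size=32,
    client_buffer_size=100,
    example_buffer_size=10000,
    seed=0)
# The stream is infinite, so cap it with itertools.islice().
for batch in itertools.islice(batches, 1000):
  ...  # One centralized training step per batch.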
- fedjax.padded_batch_federated_data(fd, hparams=None, **kwargs)[source]
Padded batch all client datasets, useful for evaluation on the entire federated dataset.
- Parameters:
fd (FederatedData) – Federated dataset.
hparams (Optional[PaddedBatchHParams]) – See fedjax.padded_batch_client_datasets().
**kwargs – See fedjax.padded_batch_client_datasets().
- Yields:
Batches of preprocessed examples.
- Return type:
Iterator[Mapping[str, ndarray]]
- class fedjax.RepeatableIterator(base)[source]
Repeats a base iterable after the end of the first pass of iteration.
Because this is a stateful object, it is not thread safe, and all the usual caveats about accessing the same iterator at different locations apply. For example, if we make two map calls to the same RepeatableIterator, we must make sure we do not interleave next() calls on them. The following is safe because we finish iterating on m1 before starting to iterate on m2:
it = RepeatableIterator(range(4))
m1 = map(lambda x: x + 1, it)
m2 = map(lambda x: x * x, it)
# We finish iterating on m1 before starting to iterate on m2.
print(list(m1), list(m2))
# [1, 2, 3, 4] [0, 1, 4, 9]
Whereas interleaved access leads to confusing results,
it = RepeatableIterator(range(4))
m1 = map(lambda x: x + 1, it)
m2 = map(lambda x: x * x, it)
print(next(m1), next(m2))
# 1 1
print(next(m1), next(m2))
# 3 9
print(next(m1), next(m2))
# StopIteration!
In the first pass of iteration, values fetched from the base iterator will be copied onto an internal buffer (except for a few builtin containers where copying is unnecessary). When each pass of iteration finishes (i.e. when __next__() raises StopIteration), the iterator resets itself to the start of the buffer, thus allowing a subsequent pass of repeated iteration.
In most cases, if repeated iterations are required, it is sufficient to simply copy values from an iterator into a list. However, sometimes an iterator produces values via potentially expensive I/O operations (e.g. loading client datasets); in such cases, RepeatableIterator can interleave I/O and JAX compute to decrease accelerator idle time.
Preprocessing and batching operations over client datasets.
Column based representation
The examples in a client dataset can be viewed as a table, where the rows are the individual examples, and the columns are the features (labels are viewed as a feature in this context).
We use a column based representation when loading a dataset into memory.
Each column is a NumPy array x of rank at least 1, where x[i, ...] is the value of this feature for the i-th example.
The complete set of examples is a dict-like object, from str feature names to the corresponding column values.
Traditionally, a row based representation is used for representing the entire dataset, and a column based representation is used for a single batch. In the context of federated learning, an individual client dataset is small enough to easily fit into memory so the same representation is used for the entire dataset and a batch.
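For instance, three examples with an image feature and a label feature are stored as a single dict of NumPy arrays (feature names here are illustrative):
import numpy as np

examples = {
    # examples['x'][i] is the i-th image.
    'x': np.zeros((3, 28, 28), dtype=np.float32),
    # examples['y'][i] is the i-th label.
    'y': np.array([0, 1, 2], dtype=np.int32),
}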
Preprocessor
Preprocessing on a batch of examples can be easily done via a chain of
functions. A Preprocessor
object holds the chain of functions, and applies
the transformation on a batch of examples.
ClientDataset: examples + preprocessor
A ClientDataset
is simply some examples in the column based
representation, accompanied by a Preprocessor.
Its batch()
method produces batches of examples in a
sequential order, suitable for evaluation.
Its shuffle_repeat_batch()
method adds shuffling and
repeating, making it suitable for training.
- class fedjax.ClientDataset(raw_examples, preprocessor=BatchPreprocessor(()))[source]
In memory client dataset backed by numpy ndarrays.
Custom preprocessing on batches can be added via a preprocessor. A ClientDataset is stored as the unpreprocessed raw_examples, along with its preprocessor.
To access batches, use one of the batching functions (e.g. shuffle_repeat_batch() for training, padded_batch() for evaluation).
To access a small number of preprocessed examples (e.g. for exploration), use slicing + all_examples().
This is only intended for efficient access to small datasets that fit in memory.
- all_examples()[source]
Returns the result of feeding all raw examples through the preprocessor.
This is mostly intended for interactive exploration of a small subset of a client dataset. For example, to see the first 4 examples in a client dataset,
dataset = ClientDataset(my_raw_examples, my_preprocessor)
dataset[:4].all_examples()
- Return type:
Mapping[str, ndarray]
- Returns:
Preprocessed examples from all the raw examples in this client dataset.
- batch(hparams=None, **kwargs)[source]
Produces preprocessed batches in a fixed sequential order.
The final batch may contain fewer than batch_size examples. If used directly, that may result in a large number of JIT recompilations. Therefore we recommend using padded_batch() instead in most scenarios.
This function can be invoked in 2 ways:
Using a hyperparams object. This is the recommended way in library code if you have to use batch (prefer padded_batch() if possible). Example:
def a_library_function(client_dataset, hparams):
  for batch in client_dataset.batch(hparams):
    ...
Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,
client_dataset.batch(batch_size=2)
# Overrides the default drop_remainder value.
client_dataset.batch(hparams, drop_remainder=True)
- Parameters:
hparams (Optional[BatchHParams]) – Batching hyperparameters.
**kwargs – Keyword arguments for constructing/overriding hparams.
- Return type:
Iterable[Mapping[str, ndarray]]
- Returns:
An iterable object that can be iterated over multiple times.
- padded_batch(hparams=None, **kwargs)[source]
Produces preprocessed padded batches in a fixed sequential order.
This function can be invoked in 2 ways:
Using a hyperparams object. This is the recommended way in library code. Example:
def a_library_function(client_dataset, hparams):
  for batch in client_dataset.padded_batch(hparams):
    ...
Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,
client_dataset.padded_batch(batch_size=2)
# Overrides the default num_batch_size_buckets value.
client_dataset.padded_batch(hparams, num_batch_size_buckets=2)
When the number of examples in the dataset is not a multiple of batch_size, the final batch may be smaller than batch_size. This may lead to a large number of JIT recompilations. This can be circumvented by padding the final batch to a small number of fixed sizes controlled by num_batch_size_buckets.
All batches contain an extra bool feature keyed by EXAMPLE_MASK_KEY. batch[EXAMPLE_MASK_KEY][i] tells us whether the i-th example in this batch is an actual example (batch[EXAMPLE_MASK_KEY][i] == True), or a padding example (batch[EXAMPLE_MASK_KEY][i] == False).
We repeatedly halve the batch size up to num_batch_size_buckets-1 times, until we find the smallest one that is also >= the size of the final batch. Therefore if batch_size < 2^num_batch_size_buckets, fewer bucket sizes will actually be used.
- Parameters:
hparams (Optional[PaddedBatchHParams]) – Batching hyperparameters.
**kwargs – Keyword arguments for constructing/overriding hparams.
- Return type:
Iterable[Mapping[str, ndarray]]
- Returns:
An iterable object that can be iterated over multiple times.
- shuffle_repeat_batch(hparams=None, **kwargs)[source]
Produces preprocessed batches in a shuffled and repeated order.
This function can be invoked in 2 ways:
Using a hyperparams object. This is the recommended way in library code. Example:
def a_library_function(client_dataset, hparams):
  for batch in client_dataset.shuffle_repeat_batch(hparams):
    ...
Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,
client_dataset.shuffle_repeat_batch(batch_size=2)
# Overrides the default num_epochs value.
client_dataset.shuffle_repeat_batch(hparams, num_epochs=2)
Shuffling is done without replacement, therefore for a dataset of N examples, the first ceil(N/batch_size) batches are guaranteed to cover the entire dataset.
By default the iteration stops after the first epoch. The number of batches produced from the iteration can be controlled by the (num_epochs, num_steps, drop_remainder) combination:
If both num_epochs and num_steps are None, the shuffle-repeat process continues forever.
If num_epochs is set and num_steps is None, as few batches as needed to go over the dataset this many passes are produced. Further,
If drop_remainder is False (the default), the final batch is filled with additionally sampled examples to contain batch_size examples.
If drop_remainder is True, the final batch is dropped if it contains fewer than batch_size examples. This may result in examples being skipped when num_epochs=1.
If num_steps is set and num_epochs is None, exactly this many batches are produced. drop_remainder has no effect in this case.
If both num_epochs and num_steps are set, the smaller number of batches between the two conditions is produced.
If reproducible iteration order is desired, a fixed seed can be used. When seed is None, repeated iteration over the same object may produce batches in a different order.
Unlike batch() or padded_batch(), batches from shuffle_repeat_batch() always contain exactly batch_size examples. Also unlike TensorFlow, that holds even when drop_remainder=False.
- Parameters:
hparams (Optional[ShuffleRepeatBatchHParams]) – Batching hyperparameters.
**kwargs – Keyword arguments for constructing/overriding hparams.
- Return type:
Iterable[Mapping[str, ndarray]]
- Returns:
An iterable object that can be iterated over multiple times.
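The following sketch illustrates the (num_epochs, num_steps) rules above on a small in-memory ClientDataset; the feature name and sizes are illustrative:
import numpy as np
from fedjax import ClientDataset

dataset = ClientDataset({'x': np.arange(5)})

# num_epochs=1 (the default), drop_remainder=False: one pass over 5 examples in
# batches of 2, with the final batch filled by resampling -> 3 batches.
print(len(list(dataset.shuffle_repeat_batch(batch_size=2))))  # 3

# num_steps only: exactly that many batches, regardless of dataset size.
print(len(list(dataset.shuffle_repeat_batch(
    batch_size=2, num_epochs=None, num_steps=10))))  # 10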
- class fedjax.BatchPreprocessor(fns=())[source]
A chain of preprocessing functions on batched examples.
BatchPreprocessor holds a chain of preprocessing functions, and applies them in order on batched examples. Each individual preprocessing function operates over multiple examples, instead of just 1 example. For example,
preprocessor = BatchPreprocessor([
    # Flattens `pixels`.
    lambda x: {**x, 'pixels': x['pixels'].reshape([-1, 28 * 28])},
    # Introduce `binary_label`.
    lambda x: {**x, 'binary_label': x['label'] % 2},
])
fake_emnist = {
    'pixels': np.random.uniform(size=(10, 28, 28)),
    'label': np.random.randint(10, size=(10,))
}
preprocessor(fake_emnist)
# Produces a dict of [10, 28*28] "pixels", [10,] "label" and "binary_label".
Given a BatchPreprocessor, a new BatchPreprocessor can be created with an additional preprocessing function appended to the chain,
# Continuing from the previous example.
new_preprocessor = preprocessor.append(
    lambda x: {**x, 'sum_pixels': np.sum(x['pixels'], axis=1)})
new_preprocessor(fake_emnist)
# Produces a dict of [10, 28*28] "pixels", [10,] "sum_pixels", "label" and
# "binary_label".
The main difference between this preprocessor and fedjax.ClientPreprocessor is that fedjax.ClientPreprocessor also takes client_id as input. Because of the identical representation between batched examples and all examples in a client dataset, certain preprocessing can be done with either BatchPreprocessor or ClientPreprocessor.
Examples of preprocessing possible at either the client dataset level, or the batch level
Such preprocessing is deterministic, and strictly per-example.
Casting a feature from int8 to float32.
Adding a new feature derived from existing features.
Remove a feature (although the better place to do so is at the dataset level).
A simple rule for deciding where to carry out the preprocessing in this case is the following,
Does this make batching cheaper (e.g. removing features)? If so, do it at the dataset level.
Otherwise, do it at the batch level.
Assuming preprocessing time is linear in the number of examples, preprocessing at the batch level has the benefit of evenly distributing host compute work, which may overlap better with asynchronous JAX compute work on GPU/TPU.
Examples of preprocessing only possible at the batch level
Data augmentation (e.g. random cropping).
Padding at the batch size dimension.
Examples of preprocessing only possible at the dataset level
Those that require knowing the client id.
Capping the number of examples.
Altering what it means to be an example: e.g. in certain language model setups, sentences are concatenated and then split into equal sized chunks.
- fedjax.buffered_shuffle_batch_client_datasets(datasets, batch_size, buffer_size, rng)[source]
Shuffles and batches examples from multiple client datasets.
This just makes 1 pass over the examples. To achieve repeated iterations, create an infinite shuffled stream of datasets first (e.g. using buffered_shuffle()).
- Parameters:
datasets (Iterable[ClientDataset]) – ClientDatasets to be batched. All ClientDatasets must have the same Preprocessor object attached.
batch_size (int) – Desired batch size.
buffer_size (int) – Number of examples to buffer during shuffling.
rng (RandomState) – Source of randomness.
- Yields:
Batches of examples. For a finite stream of datasets, the final batch might be smaller than batch_size.
- Raises:
ValueError – If any 2 client datasets have different Preprocessors.
ValueError – If any 2 client datasets have different features.
- Return type:
Iterator[Mapping[str, ndarray]]
- fedjax.padded_batch_client_datasets(datasets, hparams=None, **kwargs)[source]
Batches examples from multiple client datasets.
This is useful when we want to evaluate on the combined dataset consisting of multiple client datasets. Unlike batching each client dataset individually, we can reduce the number of batches smaller than batch_size.
This function can be invoked in 2 ways:
Using a hyperparams object. This is the recommended way in library code. Example:
def a_library_function(datasets, hparams):
  for batch in padded_batch_client_datasets(datasets, hparams):
    ...
Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,
padded_batch_client_datasets(datasets, hparams)
# Overrides the default num_batch_size_buckets value.
padded_batch_client_datasets(datasets, hparams, num_batch_size_buckets=2)
- Parameters:
datasets (Iterable[ClientDataset]) – ClientDatasets to be batched. All ClientDatasets must have the same Preprocessor object attached.
hparams (Optional[PaddedBatchHParams]) – Batching hyperparams like those in ClientDataset.padded_batch().
**kwargs – Keyword arguments for constructing/overriding hparams.
- Yields:
Batches of examples. The final batch might be padded. All batches contain a bool feature keyed by EXAMPLE_MASK_KEY.
- Raises:
ValueError – If any 2 client datasets have different Preprocessors.
ValueError – If any 2 client datasets have different features.
- Return type:
Iterator[Mapping[str, ndarray]]
- fedjax.for_each_client(client_init, client_step, client_final=<function <lambda>>, with_step_result=False)[source]
Creates a function which maps over clients.
For example, for_each_client could be used to define how to run client updates for each client in a federated training round. Another common use case of for_each_client is to run evaluation per client for a given set of model parameters.
The underlying backend for for_each_client is customizable. For example, if multiple devices are available (e.g. TPU), a jax.pmap() based backend can be used to parallelize across devices. It’s also possible to manually specify which backend to use (for debugging).
The expected usage of for_each_client is as follows:
# Map over clients and count how many points are greater than `limit` for
# each client. Each client also has a different `start` that is specified
# via client input.
def client_init(shared_input, client_input):
  client_step_state = {
      'limit': shared_input['limit'],
      'count': client_input['start']
  }
  return client_step_state

def client_step(client_step_state, batch):
  num = jnp.sum(batch['x'] > client_step_state['limit'])
  client_step_state = {
      'limit': client_step_state['limit'],
      'count': client_step_state['count'] + num
  }
  return client_step_state

def client_final(shared_input, client_step_state):
  del shared_input  # Unused.
  return client_step_state['count']

# Three clients with different data and starting counts.
# clients = [(client_id, client_batches, client_input)]
clients = [
    (b'cid0',
     [{'x': jnp.array([1, 2, 3, 4])}, {'x': jnp.array([1, 2, 3])}],
     {'start': jnp.array(2)}),
    (b'cid1',
     [{'x': jnp.array([1, 2])}, {'x': jnp.array([1, 2, 3, 4, 5])}],
     {'start': jnp.array(0)}),
    (b'cid2',
     [{'x': jnp.array([1])}],
     {'start': jnp.array(1)}),
]
shared_input = {'limit': jnp.array(2)}
func = fedjax.for_each_client(client_init, client_step, client_final)
print(list(func(shared_input, clients)))
# [(b'cid0', 5), (b'cid1', 3), (b'cid2', 1)]
Here’s the same example with per step results.
# We'll also keep track of the `num` per step in our step results.
def client_step_with_result(client_step_state, batch):
  num = jnp.sum(batch['x'] > client_step_state['limit'])
  client_step_state = {
      'limit': client_step_state['limit'],
      'count': client_step_state['count'] + num
  }
  client_step_result = {'num': num}
  return client_step_state, client_step_result

func = fedjax.for_each_client(
    client_init, client_step_with_result, client_final, with_step_result=True)
print(list(func(shared_input, clients)))
# [
#   (b'cid0', 5, [{'num': 2}, {'num': 1}]),
#   (b'cid1', 3, [{'num': 0}, {'num': 3}]),
#   (b'cid2', 1, [{'num': 0}]),
# ]
- Parameters:
client_init (Callable[[Any, Any], Any]) – Function that initializes the internal intermittent client step state from the shared input and per client input. The shared input contains information like the global model parameters that are shared across all clients. The per client input is per client information. The initialized internal client step state is fed as intermittent input and output from client_step and client_final. This client step state usually contains the model parameters and optimizer state for each client that are updated at each client_step. This will be run once for each client.
client_step (Union[Callable[[Any, Mapping[str, Array]], Tuple[Any, Any]], Callable[[Any, Mapping[str, Array]], Any]]) – Function that takes the client step state and a batch of examples as input and outputs a (possibly updated) client step state. Optionally, per step results can also be returned as the second element if with_step_result is True. Per step results are usually diagnostics like gradient norm. This will be run for each batch for each client.
client_final (Callable[[Any, Any], Any]) – Function that applies the final transformation on the internal client step state to the desired final client output. More meaningful transformations can be done here, like model update clipping. Defaults to just returning the client step state. This will be run once for each client.
with_step_result (bool) – Indicates whether client_step returns a pair where the first element is considered the client output and the second element is the client step result.
- Returns:
A for each client function that takes shared_input and the per client inputs as tuple (client_id, batched_client_data, client_rng) to map over and returns the outputs per client as specified in client_final along with optional per client per step results.
- fedjax.for_each_client_backend(backend)[source]
A context manager for switching to a given ForEachClientBackend in the current thread.
Example:
with for_each_client_backend('pmap'):
  # We will be using the pmap based for_each_client backend within this block.
  pass
# We will be using the default for_each_client backend from now on.
- Parameters:
backend (Union[ForEachClientBackend, str, None]) – See set_for_each_client_backend().
- Yields:
Nothing.
- fedjax.set_for_each_client_backend(backend)[source]
Sets the for_each_client backend for the current thread.
- Parameters:
backend (Union[ForEachClientBackend, str, None]) – One of the following,
None: uses the default backend for the current environment.
'debug': uses the debugging backend.
'jit': uses the JIT backend.
'pmap': uses the pmap-based backend.
A concrete ForEachClientBackend object.
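For example, the backend can be set for the whole thread or only within a scoped block; the backend names below follow the options listed above:
import fedjax

# Use the debugging backend for all subsequent for_each_client calls on this thread.
fedjax.set_for_each_client_backend('debug')

# Or switch temporarily using the context manager.
with fedjax.for_each_client_backend('jit'):
  pass  # for_each_client calls in this block use the JIT backend.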
- class fedjax.Model(init, apply_for_train, apply_for_eval, train_loss, eval_metrics)[source]
Container class for models.
Model exists to provide easy access to predefined neural network models. It is meant to contain all the information needed for standard centralized training and evaluation. Non-standard training methods can be built upon the information available in Model along with any additional information (e.g. interpolation can be implemented as a composition of two models along with an interpolation weight).
Works for Haiku and jax.example_libraries.stax.
The expected usage of Model is as follows:
# Training.
step_size = 0.1
rng = jax.random.PRNGKey(0)
params = model.init(rng)

def loss(params, batch, rng):
  preds = model.apply_for_train(params, batch, rng)
  return jnp.sum(model.train_loss(batch, preds))

grad_fn = jax.grad(loss)
for batch in batches:
  rng, use_rng = jax.random.split(rng)
  grads = grad_fn(params, batch, use_rng)
  params = jax.tree_util.tree_map(lambda a, b: a - step_size * b,
                                  params, grads)

# Evaluation.
print(fedjax.evaluate_model(model, params, batches))
# Example output:
# {'loss': 2.3, 'accuracy': 0.2}
The following is an example using Model compositionally as a building block to implement model interpolation:
def interpolate(model_1, model_2, init_weight):

  @jax.jit
  def init(rng):
    rng_1, rng_2 = jax.random.split(rng)
    params_1 = model_1.init(rng_1)
    params_2 = model_2.init(rng_2)
    return params_1, params_2, init_weight

  @jax.jit
  def apply_for_train(params, input, rng):
    rng_1, rng_2 = jax.random.split(rng)
    params_1, params_2, weight = params
    return (model_1.apply_for_train(params_1, input, rng_1) * weight +
            model_2.apply_for_train(params_2, input, rng_2) * (1 - weight))

  @jax.jit
  def apply_for_eval(params, input):
    params_1, params_2, weight = params
    return (model_1.apply_for_eval(params_1, input) * weight +
            model_2.apply_for_eval(params_2, input) * (1 - weight))

  return fedjax.Model(init,
                      apply_for_train,
                      apply_for_eval,
                      model_1.train_loss,
                      model_1.eval_metrics)

model = interpolate(model_1, model_2, init_weight=0.5)
- init
Initialization function that takes a seed PRNGKey and returns a PyTree of initialized parameters (i.e. model weights). These parameters will be passed as input into apply_for_train() and apply_for_eval(). Any trainable weights for a model that are modified in the training loop should be contained inside of these parameters.
Callable[[jax.Array], Any]
- apply_for_train
Function that takes the parameters PyTree, batch of examples, and PRNGKey as inputs and outputs the model predictions for training that are then passed into train_loss(). This considers strategies such as dropout.
- apply_for_eval
Function that usually takes the parameters PyTree and batch of examples as inputs and outputs the model predictions for evaluation that are then passed to eval_metrics. This is defined separately from apply_for_train() to avoid having to specify inputs like PRNGKey that are not used in evaluation.
- train_loss
Loss function for training that takes a batch of examples and the model output from apply_for_train() as input and outputs per example loss. This will typically be called inside a jax.grad() wrapped function to compute gradients.
- eval_metrics
Ordered mapping of evaluation metric names to Metric. These Metrics are defined for single examples and will be used in evaluate_model().
- Type:
Mapping[str, fedjax.core.metrics.Metric]
- fedjax.create_model_from_haiku(transformed_forward_pass, sample_batch, train_loss, eval_metrics=None, train_kwargs=None, eval_kwargs=None)[source]
Creates Model after applying defaults and haiku specific preprocessing.
- Parameters:
transformed_forward_pass (Transformed) – Transformed forward pass from hk.transform().
sample_batch (Mapping[str, Array]) – Example input used to determine model parameter shapes.
train_loss (Callable[[Mapping[str, Array], Array], Array]) – Loss function for training that outputs per example loss.
eval_metrics (Optional[Mapping[str, Metric]]) – Mapping of evaluation metric names to Metric. These metrics are defined for single examples and will be consumed in evaluate_model().
train_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for training.
eval_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for evaluation.
- Return type:
Model
- Returns:
Model
- fedjax.create_model_from_stax(stax_init, stax_apply, sample_shape, train_loss, eval_metrics=None, train_kwargs=None, eval_kwargs=None, input_key='x')[source]
Creates Model after applying defaults and stax specific preprocessing.
- Parameters:
stax_init (Callable[…, Any]) – Initialization function returned from stax.serial().
stax_apply (Callable[…, Array]) – Model forward pass function returned from stax.serial().
sample_shape (Tuple[int, …]) – The expected shape of the input to the model.
train_loss (Callable[[Mapping[str, Array], Array], Array]) – Loss function for training that outputs per example loss.
eval_metrics (Optional[Mapping[str, Metric]]) – Mapping of evaluation metric names to Metric. These metrics are defined for single examples and will be consumed in evaluate_model().
train_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for training.
eval_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for evaluation.
input_key (str) – Key name for the input in the batch mapping.
- Return type:
Model
- Returns:
Model
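A minimal sketch of building a Model from stax layers, assuming an MNIST-style image classifier; the fedjax.metrics helpers used for the loss and accuracy below are assumptions for this example:
import jax
import fedjax
from jax.example_libraries import stax

stax_init, stax_apply = stax.serial(
    stax.Flatten, stax.Dense(64), stax.Relu, stax.Dense(10))

def train_loss(batch, preds):
  # Per example loss; assumes labels are stored under the 'y' feature.
  return fedjax.metrics.unreduced_cross_entropy_loss(batch['y'], preds)

model = fedjax.create_model_from_stax(
    stax_init=stax_init,
    stax_apply=stax_apply,
    sample_shape=(-1, 28, 28, 1),
    train_loss=train_loss,
    eval_metrics={'accuracy': fedjax.metrics.Accuracy()})

params = model.init(jax.random.PRNGKey(0))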
- fedjax.evaluate_model(model, params, batches)[source]
Evaluates model for multiple batches and returns final results.
This is the recommended way to compute evaluation metrics for a given model.
- Parameters:
model (Model) – Model container.
params (Any) – Pytree of model parameters to be evaluated.
batches (Iterable[Mapping[str, Array]]) – Multiple batches to compute and aggregate evaluation metrics over. Each batch can optionally contain a feature keyed by client_datasets.MASK_KEY (see ClientDataset.padded_batch()).
- Return type:
Dict[str, Array]
- Returns:
A dictionary of evaluation Metric results.
- fedjax.model_grad(model, regularizer=None)[source]
A standard gradient function derived from a model and an optional regularizer.
The scalar loss function being differentiated is simply:
mean(model’s per-example loss) + regularizer term
The returned gradient function supports both unpadded batches, and padded batches with the mask feature keyed by client_datasets.EXAMPLE_MASK_KEY.
- fedjax.model_per_example_loss(model)[source]
Convenience function for constructing a per-example loss function from a model.
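A minimal sketch of both convenience functions, assuming model, params, and a batch already exist (e.g. from the Model example above):
import jax
import fedjax

grad_fn = fedjax.model_grad(model)              # (params, batch, rng) -> grads
loss_fn = fedjax.model_per_example_loss(model)  # (params, batch, rng) -> per example losses

rng = jax.random.PRNGKey(0)
grads = grad_fn(params, batch, rng)
per_example_losses = loss_fn(params, batch, rng)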
- fedjax.evaluate_average_loss(params, batches, rng, per_example_loss, regularizer=None)[source]
Evaluates the average per example loss over multiple batches.
- Parameters:
params (Any) – PyTree of model parameters to be evaluated.
batches (Iterable[Mapping[str, Array]]) – Multiple batches to compute and aggregate evaluation metrics over. Each batch can optionally contain a feature keyed by client_datasets.MASK_KEY (see ClientDataset.padded_batch()).
rng (Array) – Initial PRNGKey for making per_example_loss calls.
per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – Per example loss function.
regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer function.
- Return type:
- Returns:
The average per example loss, plus the regularizer term when specified.
- class fedjax.ModelEvaluator(model)[source]
Evaluates model for each client dataset, either using global params, or per client params.
To evaluate a Model on a single dataset, use evaluate_model() instead.
- evaluate_global_params(params, clients)[source]
Evaluates batches from each client using global params.
- class fedjax.AverageLossEvaluator(per_example_loss, regularizer=None)[source]
Evaluates average loss for each client dataset, either using global params, or per client params.
The average loss is defined as the average per example loss, plus the regularizer term when specified. To evaluate average loss on a single dataset, use evaluate_average_loss() instead.
- evaluate_global_params(params, clients)[source]
Evaluates batches from each client using global params.
- fedjax.grad(per_example_loss, regularizer=None)[source]
A standard gradient function derived from per-example loss and an optional regularizer.
The scalar loss function being differentiated is simply:
mean(per-example loss) + regularizer term
The returned gradient function supports both unpadded batches, and padded batches with the mask feature keyed by client_datasets.EXAMPLE_MASK_KEY.
- Parameters:
- Return type:
- Returns:
A function from (params, batch_example, rng) to gradients.
fedjax.aggregators
FedJAX aggregators.
- class fedjax.aggregators.Aggregator(init, apply)[source]
Interface for algorithms to aggregate.
This interface defines aggregator algorithms that are used at each round. Aggregator state contains any round specific parameters (e.g. number of bits) that will be passed from round to round. This state is initialized by init and passed as input into and returned as output from aggregate. We strongly recommend using fedjax.dataclass to define state as this provides immutability, type hinting, and works by default with JAX transformations.
The expected usage of Aggregator is as follows:
aggregator = mean_aggregator()
state = aggregator.init()
for i in range(num_rounds):
  clients_params_and_weights = compute_client_outputs(i)
  aggregated_params, state = aggregator.apply(clients_params_and_weights, state)
- init
Returns initial state of aggregator.
- Type:
Callable[[], Any]
- apply
Returns the new aggregator state and aggregated params.
- Type:
Callable[[Iterable[Tuple[bytes, Any, float]], Any], Tuple[Any, Any]]
Mean
Quantization
- fedjax.aggregators.uniform_stochastic_quantizer(num_levels, rng, encode_algorithm=None)[source]
Returns (weighted) mean of input uniformly quantized trees using the
uniform stochastic algorithm in https://arxiv.org/pdf/1611.00429.pdf.
- Parameters:
num_levels (int) – Number of levels of quantization.
rng (Array) – PRNGKey used for compression.
encode_algorithm (Optional[str]) – None or arithmetic.
- Return type:
Aggregator
- Returns:
Compression aggregator.
fedjax.algorithms
Federated learning algorithm implementations.
agnostic_fed_avg – AgnosticFedAvg implementation.
fed_avg – Federated averaging implementation using fedjax.core.
hyp_cluster – Federated hypothesis-based clustering implementation.
mime – Mime implementation.
mime_lite – Mime Lite implementation.
AgnosticFedAvg
AgnosticFedAvg implementation.
- Communication-Efficient Agnostic Federated Averaging
Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh https://arxiv.org/abs/2104.02748
- class fedjax.algorithms.agnostic_fed_avg.ServerState(params, opt_state, domain_weights, domain_window)[source]
State of server for AgnosticFedAvg passed between rounds.
- params
A pytree representing the server model parameters.
- Type:
Any
- opt_state
A pytree representing the server optimizer state.
- Type:
Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- domain_window
Sliding window keeping track of domain number of examples for the last window size rounds.
- Type:
List[jax.Array]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.agnostic_fed_avg.agnostic_federated_averaging(per_example_loss, client_optimizer, server_optimizer, client_batch_hparams, domain_batch_hparams, init_domain_weights, domain_learning_rate, domain_algorithm='eg', domain_window_size=1, init_domain_window=None, regularizer=None)[source]
Builds agnostic federated averaging.
Agnostic federated averaging requires input fedjax.core.client_datasets.ClientDataset examples to contain a feature named "domain_id", which stores the integer domain id in [0, num_domains). For example, for Stack Overflow, each example post can be either a question or an answer, so there are two possible domain ids (question = 0; answer = 1).
- Parameters:
per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the domain metrics computation and gradient descent training.
client_optimizer (Optimizer) – Optimizer for local client training.
server_optimizer (Optimizer) – Optimizer for server update.
client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for client dataset for training.
domain_batch_hparams (PaddedBatchHParams) – Hyperparameters for client dataset domain metrics calculation.
init_domain_weights (Sequence[float]) – Initial weights per domain that must sum to 1.
domain_learning_rate (float) – Learning rate for domain weight update.
domain_algorithm (str) – Algorithm used to update domain weights each round. One of ‘eg’, ‘none’.
domain_window_size (int) – Size of sliding window keeping track of number of examples per domain over multiple rounds.
init_domain_window (Optional[Sequence[float]]) – Initial values for domain window. Defaults to ones.
regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.
- Return type:
FederatedAlgorithm
- Returns:
FederatedAlgorithm.
- Raises:
ValueError – If init_domain_weights does not sum to 1 or if init_domain_weights and init_domain_window are unequal lengths.
FedAvg
Federated averaging implementation using fedjax.core.
This is the more performant implementation that matches what would be used in the fedjax.algorithms.fed_avg. The key difference between this and the basic version is the use of fedjax.core.for_each_client.
- Communication-Efficient Learning of Deep Networks from Decentralized Data
H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017. https://arxiv.org/abs/1602.05629
- Adaptive Federated Optimization
Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. ICLR 2021. https://arxiv.org/abs/2003.00295
- class fedjax.algorithms.fed_avg.ServerState(params, opt_state)[source]
State of server passed between rounds.
- params
A pytree representing the server model parameters.
- Type:
Any
- opt_state
A pytree representing the server optimizer state.
- Type:
Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.fed_avg.federated_averaging(grad_fn, client_optimizer, server_optimizer, client_batch_hparams)[source]
Builds federated averaging.
- Parameters:
grad_fn (Callable[[Any, Mapping[str, Array], Array], Any]) – A function from (params, batch_example, rng) to gradients. This can be created with fedjax.core.model.model_grad().
client_optimizer (Optimizer) – Optimizer for local client training.
server_optimizer (Optimizer) – Optimizer for server update.
client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for training.
- Return type:
FederatedAlgorithm
- Returns:
FederatedAlgorithm
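A minimal sketch of building and running the algorithm, assuming model and train_fd already exist; fedjax.optimizers.sgd and fedjax.client_samplers.UniformGetClientSampler are used here as assumptions about the surrounding API, and the hyperparameter values are illustrative only:
import jax
import fedjax
from fedjax.algorithms import fed_avg

algorithm = fed_avg.federated_averaging(
    grad_fn=fedjax.model_grad(model),
    client_optimizer=fedjax.optimizers.sgd(learning_rate=0.1),
    server_optimizer=fedjax.optimizers.sgd(learning_rate=1.0),
    client_batch_hparams=fedjax.ShuffleRepeatBatchHParams(batch_size=20))

state = algorithm.init(model.init(jax.random.PRNGKey(0)))
sampler = fedjax.client_samplers.UniformGetClientSampler(
    train_fd, num_clients=10, seed=0)
for round_num in range(5):
  # clients is a list of (client_id, client_dataset, client_rng) tuples.
  clients = sampler.sample()
  state, client_diagnostics = algorithm.apply(state, clients)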
HypCluster
Federated hypothesis-based clustering implementation.
- Three Approaches for Personalization with Applications to Federated Learning
Yishay Mansour, Mehryar Mohri, Jae Ro, Ananda Theertha Suresh https://arxiv.org/abs/2002.10619
- class fedjax.algorithms.hyp_cluster.ServerState(cluster_params, opt_states)[source]
State of server for HypCluster passed between rounds.
- cluster_params
A list of pytrees representing the server model parameters per cluster.
- Type:
List[Any]
- opt_states
A list of pytrees representing the server optimizer state per cluster.
- Type:
List[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.hyp_cluster.hyp_cluster(per_example_loss, client_optimizer, server_optimizer, maximization_batch_hparams, expectation_batch_hparams, regularizer=None)[source]
Federated hypothesis-based clustering algorithm.
- Parameters:
per_example_loss (
Callable
[[Any
,Mapping
[str
,Array
],Array
],Array
]) – A function from (params, batch, rng) to a vector of per example loss values.client_optimizer (
Optimizer
) – Client side optimizer.server_optimizer (
Optimizer
) – Server side optimizer.maximization_batch_hparams (
PaddedBatchHParams
) – Batching hyperparameters for the maximization step.expectation_batch_hparams (
ShuffleRepeatBatchHParams
) – Batching hyperparameters for the expectation step.regularizer (
Optional
[Callable
[[Any
],Array
]]) – Optional regularizer.
- Return type:
- Returns:
A FederatedAlgorithm. A notable difference from other common FederatedAlgorithms such as fed_avg is that init() takes a list of cluster_params, which can be obtained using random_init(), ModelKMeansInitializer, or kmeans_init() in this module.
- fedjax.algorithms.hyp_cluster.kmeans_init(num_clusters, init_params, clients, trainer, train_batch_hparams, evaluator, eval_batch_hparams, rng)[source]
Initializes cluster params for HypCluster using a k-means++ variant.
See
ModelKMeansInitializer
for a more convenient initializer from aModel
.Given a set of input clients, we train parameters for each client. The initial cluster parameters are chosen out of this set of client parameters. At a high level, the cluster parameters are selected by choosing the client parameters that are the “furthest away” from each other (greatest difference in loss).
- Parameters:
num_clusters (
int
) – Desired number of cluster centers.init_params (
Any
) – Initial params for training each client.clients (
Sequence
[Tuple
[bytes
,ClientDataset
,Array
]]) – Clients to train for generating candidate cluster center params.trainer (
ClientParamsTrainer
) – Client trainer for carry out client training.train_batch_hparams (
ShuffleRepeatBatchHParams
) – Batching hyperparameters for client training.evaluator (
AverageLossEvaluator
) – Average loss evaluator for evaluator clients on chosen cluster center params.eval_batch_hparams (
PaddedBatchHParams
) – Batching hyperparameters for client evaluation.rng (
Array
) – RNG used for choosing the initial cluster center.
- Return type:
List
[Any
]- Returns:
Cluster center params.
- Raises:
ValueError if some input arguments are invalid. –
- fedjax.algorithms.hyp_cluster.random_init(num_clusters, init, rng)[source]
Randomly initializes cluster params.
- Return type:
List
[Any
]
- class fedjax.algorithms.hyp_cluster.ModelKMeansInitializer(model, client_optimizer, regularizer=None)[source]
Initializes cluster params for HypCluster using a k-means++ variant.
This is a thin wrapper for initializing from a
Model
. Seekmeans_init()
for a more general version of the initializer, and details of the initialization algorithm.
- class fedjax.algorithms.hyp_cluster.HypClusterEvaluator(model, regularizer=None)[source]
Evaluates cluster params on a Model.
- __init__(model, regularizer=None)[source]
Initializes some reusable components.
Because we need to make multiple passes over each client during evaluation, the number of clients that can be evaluated at once is limited. Therefore multiple calls to
evaluate_clients()
are needed to evaluate a large federated dataset. We factor out some reusable components so that the same computation can be jit compiled.
Mime
Mime implementation.
- Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning
Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606
- class fedjax.algorithms.mime.ServerState(params, opt_state)[source]
State of server passed between rounds.
- params
A pytree representing the server model parameters.
- Type:
Any
- opt_state
A pytree representing the base optimizer state.
- Type:
Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.mime.mime(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None)[source]
Builds mime.
- Parameters:
per_example_loss (
Callable
[[Any
,Mapping
[str
,Array
],Array
],Array
]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.base_optimizer (
Optimizer
) – Base optimizer to mimic.client_batch_hparams (
ShuffleRepeatBatchHParams
) – Hyperparameters for batching client dataset for train.grads_batch_hparams (
PaddedBatchHParams
) – Hyperparameters for batching client dataset for server gradient computation.server_learning_rate (
float
) – Server learning rate.regularizer (
Optional
[Callable
[[Any
],Array
]]) – Optional regularizer that only depends on params.
- Return type:
- Returns:
FederatedAlgorithm
MimeLite
Mime Lite implementation.
- Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning
Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606
Reuses fedjax.algorithms.mime.ServerState
- fedjax.algorithms.mime_lite.mime_lite(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None, client_delta_clip_norm=None)[source]
Builds mime lite.
- Parameters:
per_example_loss (
Callable
[[Any
,Mapping
[str
,Array
],Array
],Array
]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.base_optimizer (
Optimizer
) – Base optimizer to mimic.client_batch_hparams (
ShuffleRepeatBatchHParams
) – Hyperparameters for batching client dataset for train.grads_batch_hparams (
PaddedBatchHParams
) – Hyperparameters for batching client dataset for server gradient computation.server_learning_rate (
float
) – Server learning rate.regularizer (
Optional
[Callable
[[Any
],Array
]]) – Optional regularizer that only depends on params.client_delta_clip_norm (
Optional
[float
]) – Maximum allowed global norm per client update. Defaults to no clipping.
- Return type:
- Returns:
FederatedAlgorithm
fedjax.datasets
fedjax datasets.
cifar100 – Federated cifar100.
emnist – Federated EMNIST.
shakespeare – Federated Shakespeare.
stackoverflow – Federated stackoverflow.
CIFAR-100
Federated cifar100.
- fedjax.datasets.cifar100.load_data(mode='sqlite', cache_dir=None)[source]
Loads partially preprocessed cifar100 splits.
Features:
x: [N, 32, 32, 3] uint8 pixels.
y: [N] int32 labels in the range [0, 100).
Additional preprocessing (e.g. centering and normalizing) depends on whether a split is used for training or eval. For example,:
import functools
from fedjax.datasets import cifar100

# Load partially preprocessed splits.
train, test = cifar100.load_data()
# Preprocessing for training.
train_for_train = train.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=True))
# Preprocessing for eval.
train_for_eval = train.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=False))
test = test.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=False))
Features after this preprocessing:
x: [N, 32, 32, 3] float32 preprocessed pixels.
y: [N] int32 labels in the range [0, 100).
Alternatively, you can apply the same preprocessing as TensorFlow Federated following tff.simulation.baselines.cifar100.create_image_classification_task. For example,:
from fedjax.datasets import cifar100

train, test = cifar100.load_data()
train = train.preprocess_batch(cifar100.preprocess_batch_tff)
test = test.preprocess_batch(cifar100.preprocess_batch_tff)
Features after this preprocessing:
x: [N, 24, 24, 3] float32 preprocessed pixels.
y: [N] int32 labels in the range [0, 100).
Note:
preprocess_batch and preprocess_batch_tff are just convenience wrappers around preprocess_image() and preprocess_image_tff(), respectively, for use with fedjax.FederatedData.preprocess_batch().
- Parameters:
mode (str) – ‘sqlite’.
cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.
- Return type:
Tuple[FederatedData, FederatedData]
- Returns:
A (train, test) tuple of federated data.
- fedjax.datasets.cifar100.load_split(split, mode='sqlite', cache_dir=None)[source]
Loads a cifar100 split.
Features:
image: [N, 32, 32, 3] uint8 pixels.
coarse_label: [N] int64 coarse labels in the range [0, 20).
label: [N] int64 labels in the range [0, 100).
- Parameters:
split (
str
) – Name of the split. One of SPLITS.mode (
str
) – ‘sqlite’.cache_dir (
Optional
[str
]) – Directory to cache files in ‘sqlite’ mode.
- Return type:
- Returns:
FederatedData.
- fedjax.datasets.cifar100.preprocess_image(image, is_train)[source]
Augments and preprocesses CIFAR-100 images by cropping, flipping, and normalizing.
Preprocessing procedure and values taken from pytorch-cifar.
- Parameters:
image (
ndarray
) – [N, 32, 32, 3] uint8 pixels.is_train (
bool
) – Whether we are preprocessing for training or eval.
- Return type:
ndarray
- Returns:
Processed [N, 32, 32, 3] float32 pixels.
EMNIST
Federated EMNIST.
- fedjax.datasets.emnist.domain_id(client_id)[source]
Returns domain id for client id.
Domain ids are based on the NIST data source, where examples were collected from two sources: Bethesda high school (HIGH_SCHOOL) and Census Bureau in Suitland (CENSUS). For more details, see the NIST documentation.
- Parameters:
client_id (bytes) – Client id of the format [16-byte hex hash]:f[4-digit integer]_[2-digit integer] or f[4-digit integer]_[2-digit integer].
- Return type:
int
- Returns:
Domain id that is 0 (HIGH_SCHOOL) or 1 (CENSUS).
- fedjax.datasets.emnist.load_data(only_digits=False, mode='sqlite', cache_dir=None)[source]
Loads processed EMNIST train and test splits.
Features:
x: [N, 28, 28, 1] float32 flipped image pixels.
y: [N] int32 classification label.
domain_id: [N] int32 domain id (see
domain_id()
).
- Parameters:
only_digits (
bool
) – Whether to only load the digits data.mode (
str
) – ‘sqlite’.cache_dir (
Optional
[str
]) – Directory to cache files in ‘sqlite’ mode.
- Return type:
Tuple
[FederatedData
,FederatedData
]- Returns:
Train and test splits as FederatedData.
- fedjax.datasets.emnist.load_split(split, only_digits=False, mode='sqlite', cache_dir=None)[source]
Loads an unprocessed federated emnist split.
Features:
pixels: [N, 28, 28] float32 image pixels.
label: [N] int32 classification label.
- Parameters:
split (
str
) – Name of the split. One of SPLITS.only_digits (
bool
) – Whether to only load the digits data.mode (
str
) – ‘sqlite’.cache_dir (
Optional
[str
]) – Directory to cache files in ‘sqlite’ mode.
- Return type:
- Returns:
FederatedData.
Shakespeare
Federated Shakespeare.
- fedjax.datasets.shakespeare.load_data(sequence_length=80, mode='sqlite', cache_dir=None)[source]
Loads preprocessed shakespeare splits.
Preprocessing is done using fedjax.FederatedData.preprocess_client() and preprocess_client().
Features (M below is possibly different from N in load_split):
x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)
y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)
- Parameters:
sequence_length (
int
) – The fixed sequence length after preprocessing.mode (
str
) – ‘sqlite’.cache_dir (
Optional
[str
]) – Directory to cache files in ‘sqlite’ mode.
- Return type:
Tuple
[FederatedData
,FederatedData
]- Returns:
A (train, test) tuple of federated data.
- fedjax.datasets.shakespeare.load_split(split, mode='sqlite', cache_dir=None)[source]
Loads a shakespeare split.
Features:
snippets: [N] bytes array of snippet text.
- Parameters:
split (
str
) – Name of the split. One of SPLITS.mode (
str
) – ‘sqlite’.cache_dir (
Optional
[str
]) – Directory to cache files in ‘sqlite’ mode.
- Return type:
- Returns:
FederatedData.
- fedjax.datasets.shakespeare.preprocess_client(client_id, examples, sequence_length)[source]
Turns snippets into sequences of integer labels.
Features (M below is possibly different from N in load_split):
x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)
y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)
All snippets in a client dataset are first joined into a single sequence (with BOS/EOS added), and then split into pairs of sequence_length chunks for language model training. For example, with sequence_length=3, [b’ABCD’, b’E’] becomes:
Input sequences:  [[BOS, A, B], [C, D, EOS], [BOS, E, PAD]]
Output sequences: [[A, B, C], [D, EOS, BOS], [E, EOS, PAD]]
Note: This is not equivalent to the TensorFlow Federated text generation tutorial (The processing logic there loses ~1/sequence_length portion of the tokens).
- Parameters:
  - client_id (bytes) – Not used.
  - examples (Mapping[str, ndarray]) – Unprocessed examples (e.g. from load_split()).
  - sequence_length (int) – The fixed sequence length after preprocessing.
- Return type: Mapping[str, ndarray]
- Returns: Processed examples.
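To make the chunking above concrete, here is a small sketch that runs preprocess_client directly on a hand-built examples mapping; in practice the examples come from load_split(), and the client id is ignored.
import fedjax
import numpy as np

# Hand-built examples with the 'snippets' feature documented for load_split.
examples = {'snippets': np.array([b'ABCD', b'E'], dtype=object)}
processed = fedjax.datasets.shakespeare.preprocess_client(
    b'unused_client_id', examples, sequence_length=3)
print(processed['x'])  # [M, 3] int32 input labels
print(processed['y'])  # [M, 3] int32 output labels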
Stack Overflow
Federated stackoverflow.
- fedjax.datasets.stackoverflow.load_data(mode='sqlite', cache_dir=None)[source]
Loads partially preprocessed stackoverflow splits.
Features:
domain_id: [N] int32 domain id derived from type (question = 0; answer = 1).
tokens: [N] bytes array. Space separated list of tokens.
To convert tokens into padded/truncated integer labels, use a StackoverflowTokenizer. For example:
from fedjax.core.datasets import stackoverflow

# Load partially preprocessed splits.
train, held_out, test = stackoverflow.load_data()

# Apply tokenizer during batching.
tokenizer = stackoverflow.StackoverflowTokenizer()
train_max_length, eval_max_length = 20, 30
train_for_train = train.preprocess_batch(
    tokenizer.as_preprocess_batch(train_max_length))
train_for_eval = train.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
held_out = held_out.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
test = test.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
Features after tokenization:
domain_id: Same as before.
x: [N, max_length] int32 array of padded/truncated input labels.
y: [N, max_length] int32 array of padded/truncated output labels.
- Parameters:
  - mode (str) – ‘sqlite’.
  - cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.
- Return type: Tuple[FederatedData, FederatedData, FederatedData]
- Returns: A (train, held_out, test) tuple of federated data.
- fedjax.datasets.stackoverflow.load_split(split, mode='sqlite', cache_dir=None)[source]
Loads a stackoverflow split.
All bytes arrays are stored with dtype=np.object.
Features:
creation_date: [N] bytes array. Textual timestamp, e.g. b’2018-02-28 19:06:18.34 UTC’.
title: [N] bytes array. The title of a post.
score: [N] int64 array. The score of a post.
tags: [N] bytes array. ‘|’ separated list of tags, e.g. b’mysql|join’.
tokens: [N] bytes array. Space separated list of tokens.
type: [N] bytes array. Either b’question’ or b’answer’.
- Parameters:
  - split (str) – Name of the split. One of SPLITS.
  - mode (str) – ‘sqlite’.
  - cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.
- Return type: FederatedData
- Returns: FederatedData.
- class fedjax.datasets.stackoverflow.StackoverflowTokenizer(vocab=None, default_vocab_size=10000, num_oov_buckets=1)[source]
Tokenizer for the tokens feature in stackoverflow.
See load_data() for examples.
- __init__(vocab=None, default_vocab_size=10000, num_oov_buckets=1)[source]
Initializes a tokenizer.
- Parameters:
  - vocab (Optional[List[str]]) – Optional vocabulary. If specified, default_vocab_size is ignored. If None, default_vocab_size is used to load the standard vocabulary. This vocabulary should NOT have special tokens PAD, EOS, BOS, and OOV. The special tokens are added and handled automatically by the tokenizer. The preprocessed examples will have vocabulary size len(vocab) + 3 + num_oov_buckets.
  - default_vocab_size (Optional[int]) – Number of words in the default vocabulary. This is only used when vocab is not specified. The preprocessed examples will have vocabulary size default_vocab_size + 3 + num_oov_buckets with 3 special labels: 0 (PAD), 1 (BOS), 2 (EOS), and num_oov_buckets OOV labels starting at default_vocab_size + 3.
  - num_oov_buckets (int) – Number of out of vocabulary buckets.
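For example, a tokenizer can be built from a small custom vocabulary (the three-word vocab below is purely illustrative); per the description above, the effective vocabulary size becomes len(vocab) + 3 + num_oov_buckets.
from fedjax.datasets import stackoverflow

# Illustrative three-word vocabulary; PAD/BOS/EOS/OOV are added automatically.
tokenizer = stackoverflow.StackoverflowTokenizer(
    vocab=['what', 'is', 'jax'], num_oov_buckets=1)
# Effective vocabulary size: 3 (words) + 3 (PAD, BOS, EOS) + 1 (OOV) = 7.
preprocess = tokenizer.as_preprocess_batch(20)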
fedjax.models
Model implementations for FedJAX experimental API.
- emnist – EMNIST models.
- shakespeare – Shakespeare recurrent models.
- stackoverflow – Stack Overflow recurrent models.
- toy_regression – Toy regression models.
EMNIST
EMNIST models.
- class fedjax.models.emnist.ConvDropoutModule(num_classes)[source]
Bases:
Module
Custom haiku module for CNN with dropout.
This must be defined as a custom hk.Module because only a single positional argument is allowed when using hk.Sequential.
- fedjax.models.emnist.create_conv_model(only_digits=False)[source]
Creates EMNIST CNN model with dropout with haiku.
Matches the model used in:
- Adaptive Federated Optimization
Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. https://arxiv.org/abs/2003.00295
- Parameters:
only_digits (
bool
) – Whether to use only digit classes [0-9] or include lower and upper case characters for a total of 62 classes.- Return type:
- Returns:
Model
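As a rough sketch (assuming the fedjax.Model interface, which provides an init function for parameter initialization from a PRNG key), the model can be created and initialized like so:
import jax
import fedjax

model = fedjax.models.emnist.create_conv_model(only_digits=False)
# Initialize model parameters from a PRNG key.
params = model.init(jax.random.PRNGKey(0))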
- fedjax.models.emnist.create_dense_model(only_digits=False, hidden_units=200)[source]
Creates EMNIST dense net with haiku.
- Return type: Model
Shakespeare
Shakespeare recurrent models.
- fedjax.models.shakespeare.create_lstm_model(vocab_size=86, embed_size=8, lstm_hidden_size=256, lstm_num_layers=2)[source]
Creates LSTM language model.
Character-level LSTM for Shakespeare language model. Defaults to the model used in:
- Communication-Efficient Learning of Deep Networks from Decentralized Data
H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017. https://arxiv.org/abs/1602.05629
- Parameters:
  - vocab_size (int) – The number of possible output characters. This does not include special tokens like PAD, BOS, EOS, or OOV.
  - embed_size (int) – Embedding size for each character.
  - lstm_hidden_size (int) – Hidden size for LSTM cells.
  - lstm_num_layers (int) – Number of LSTM layers.
- Return type: Model
- Returns: Model.
Stack Overflow
Stack Overflow recurrent models.
- fedjax.models.stackoverflow.create_lstm_model(vocab_size=10000, embed_size=96, lstm_hidden_size=670, lstm_num_layers=1, share_input_output_embeddings=False, expected_length=None)[source]
Creates LSTM language model.
Word-level language model for Stack Overflow. Defaults to the model used in:
- Adaptive Federated Optimization
Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. https://arxiv.org/abs/2003.00295
- Parameters:
  - vocab_size (int) – The number of possible output words. This does not include special tokens like PAD, BOS, EOS, or OOV.
  - embed_size (int) – Embedding size for each word.
  - lstm_hidden_size (int) – Hidden size for LSTM cells.
  - lstm_num_layers (int) – Number of LSTM layers.
  - share_input_output_embeddings (bool) – Whether to share the input embeddings with the output logits.
  - expected_length (Optional[float]) – Expected average sentence length used to scale the training loss down by 1. / expected_length. This constant term is used so that the total loss over all the words in a sentence can be scaled down to per-word cross entropy values by a constant factor instead of dividing by the number of words, which can vary across batches. Defaults to no scaling.
- Return type: Model
- Returns: Model.
Toy Regression
Toy regression models.
- fedjax.models.toy_regression.create_regression_model()[source]
Creates toy regression model.
Matches the model used in:
- Communication-Efficient Agnostic Federated Averaging
Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh https://arxiv.org/abs/2104.02748
- Return type: Model
- Returns: Model
fedjax.training
FedJAX training utilities.
Federated experiment
- fedjax.training.run_federated_experiment(algorithm, init_state, client_sampler, config, periodic_eval_fn_map=None, final_eval_fn_map=None)[source]
Runs the training loop of a federated algorithm experiment.
- Parameters:
  - algorithm (FederatedAlgorithm) – Federated algorithm to use.
  - init_state (Any) – Initial server state.
  - client_sampler (ClientSampler) – Sampler for training clients.
  - config (FederatedExperimentConfig) – FederatedExperimentConfig configurations.
  - periodic_eval_fn_map (Optional[Mapping[str, Any]]) – Mapping of name to evaluation functions that are run repeatedly over multiple federated training rounds. The frequency is defined in _FederatedExperimentConfig.eval_frequency.
  - final_eval_fn_map (Optional[Mapping[str, EvaluationFn]]) – Mapping of name to evaluation functions that are run at the very end of federated training. Typically, full test evaluation functions will be set here.
- Return type: Any
- Returns: Final state of the input federated algorithm after training.
- class fedjax.training.FederatedExperimentConfig(root_dir: str, num_rounds: int, checkpoint_frequency: int = 0, num_checkpoints_to_keep: int = 1, eval_frequency: int = 0)[source]
Common configurations of a federated experiment.
- Attributes:
  - root_dir: Root directory for experiment outputs (e.g. metrics).
  - num_rounds: Number of federated training rounds.
  - checkpoint_frequency: Checkpoint frequency in rounds. If <= 0, no checkpointing is done.
  - num_checkpoints_to_keep: Maximum number of checkpoints to keep.
  - eval_frequency: Evaluation frequency in rounds. If <= 0, no evaluation is done.
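Putting the two together, a minimal sketch of an experiment loop might look like the following; algorithm, init_state, and client_sampler are placeholders assumed to have been constructed elsewhere (e.g. from one of the fedjax.algorithms implementations and a client sampler over the training federated data).
from fedjax import training

# Placeholders: a FederatedAlgorithm, its initial server state, and a
# ClientSampler built elsewhere.
config = training.FederatedExperimentConfig(
    root_dir='/tmp/emnist_fed_avg',
    num_rounds=1500,
    checkpoint_frequency=100,
    num_checkpoints_to_keep=1,
    eval_frequency=100)
final_state = training.run_federated_experiment(
    algorithm=algorithm,
    init_state=init_state,
    client_sampler=client_sampler,
    config=config)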
- class fedjax.training.EvaluationFn[source]
Evaluation function that is only fed state at every call.
Typically used for full evaluation or evaluation on sampled clients from a test set.
- class fedjax.training.ModelFullEvaluationFn(fd, model, batch_hparams)[source]
Bases:
EvaluationFn
Evaluation on an entire federated dataset using the centralized model.
- class fedjax.training.ModelSampleClientsEvaluationFn(client_sampler, model, batch_hparams)[source]
Bases:
EvaluationFn
Evaluation on sampled clients using the centralized model.
The state to be evaluated must contain a params field.
- class fedjax.training.TrainClientsEvaluationFn[source]
Evaluation function that is fed training clients at every call.
Typically used for evaluation on the training clients used in a step.
- class fedjax.training.ModelTrainClientsEvaluationFn(model, batch_hparams)[source]
Bases:
TrainClientsEvaluationFn
Evaluation on training clients using the centralized model.
The state to be evaluated must contain a params field.
- fedjax.training.set_tf_cpu_only()[source]
Restricts TensorFlow device visibility to only CPU.
TensorFlow is only used for data loading, so we prevent it from allocating GPU/TPU memory.
- fedjax.training.load_latest_checkpoint(root_dir)[source]
Loads latest checkpoint and round number.
- Return type: Optional[Tuple[Any, int]]
- fedjax.training.save_checkpoint(root_dir, state, round_num=0, keep=1)[source]
Saves checkpoint and cleans up old checkpoints.
- class fedjax.training.Logger(root_dir=None)[source]
Class to encapsulate tf.summary.SummaryWriter logging logic.
- log(writer_name, metric_name, metric_value, round_num)[source]
Records metric using specified summary writer.
Logs at INFO verbosity. Also, if root_dir is set and metric_value is:
  - a scalar value, convertible to a float32 Tensor, writes a scalar summary
  - a vector, convertible to a float32 Tensor, writes a histogram summary
- Parameters:
  - writer_name (str) – Name of summary writer.
  - metric_name (str) – Name of metric to log.
  - metric_value (Any) – Value of metric to log.
  - round_num (int) – Round number to log.
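A brief sketch of how the logger might be used inside a training loop; the writer name, metric name, and value below are illustrative.
from fedjax import training

logger = training.Logger(root_dir='/tmp/experiment')
# Logs at INFO verbosity and writes a scalar summary under the 'train' writer.
logger.log('train', 'loss', 0.42, round_num=1)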
Tasks
Registry of standard tasks.
Each task is represented as a (train federated data, test federated data, model) tuple.
- training.ALL_TASKS = ('EMNIST_CONV', 'EMNIST_LOGISTIC', 'EMNIST_DENSE', 'SHAKESPEARE_CHARACTER', 'STACKOVERFLOW_WORD', 'CIFAR100_LOGISTIC')
- fedjax.training.get_task(name, mode='sqlite', cache_dir=None)[source]
Gets a standard task.
- Parameters:
  - name (str) – Name of the task to get. Must be one of fedjax.training.ALL_TASKS.
  - mode (str) – ‘sqlite’.
  - cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.
- Return type: Tuple[FederatedData, FederatedData, Model]
- Returns: (train federated data, test federated data, model) tuple.
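For example, the registry makes it a one-liner to pull a standard benchmark setup; the sketch below loads the 'EMNIST_CONV' task listed in ALL_TASKS.
import fedjax

# Loads the EMNIST federated splits and the matching CNN model.
train_fd, test_fd, model = fedjax.training.get_task('EMNIST_CONV')
print('num train clients =', train_fd.num_clients())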
Structured flags
Structured flags commonly used in experiment binaries.
Structured flags are often used to construct complex structures via multiple simple flags (e.g. an optimizer can be created by controlling learning rate and other hyper parameters).
- class fedjax.training.structured_flags.BatchHParamsFlags(name=None, default_batch_size=128)[source]
Constructs BatchHParams from flags.
- class fedjax.training.structured_flags.FederatedExperimentConfigFlags(name=None)[source]
Constructs FederatedExperimentConfig from flags.
- class fedjax.training.structured_flags.NamedFlags(name)[source]
A group of flags with an optional named prefix.
- class fedjax.training.structured_flags.OptimizerFlags(name=None, default_optimizer='sgd')[source]
Constructs a fedjax.Optimizer from flags.
- class fedjax.training.structured_flags.PaddedBatchHParamsFlags(name=None, default_batch_size=128)[source]
Constructs PaddedBatchHParams from flags.