FedJAX documentation

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX prioritizes ease-of-use and is intended to be useful for anyone with knowledge of NumPy.

Federated datasets


This tutorial introduces datasets in FedJAX and how to work with them. By completing this tutorial, we’ll learn about the best practices for working with datasets.

NOTE: For datasets, we operate over NumPy arrays NOT JAX arrays.

# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
# !pip install tensorflow_datasets
import functools
import itertools
import fedjax
import numpy as np

What are datasets in federated learning?

In the context of federated learning (FL), data is decentralized across clients, with each client having their own local set of examples. In light of this, we refer to two levels of organization for datasets:

  • Federated dataset: A collection of clients, each with their own local datasets and metadata

  • Client dataset: The set of local examples for a particular client

We can think of federated data as a mapping from client ids to client datasets and client datasets as a list of examples.

federated_data = {
  'client0': ['a', 'b', 'c'],
  'client1': ['d', 'e'],
}

Federated datasets

FedJAX comes packaged with multiple federated datasets, and we will look at the Shakespeare dataset as an example. The Shakespeare dataset is created from The Complete Works of Shakespeare, treating each character in each play as a “client” and their dialogue lines as the examples.

FedJAX organizes federated datasets as Python modules. load_data() from a dataset module loads all predefined splits together as fedjax.FederatedData objects, the interface for accessing all federated datasets. In the case of the Shakespeare dataset, load_data() returns two splits: train and test.

# We cap the max sentence length at 8.
train_fd, test_fd = fedjax.datasets.shakespeare.load_data(sequence_length=8)
Downloading 'https://storage.googleapis.com/gresearch/fedjax/shakespeare/shakespeare_train.sqlite' to '/tmp/.cache/fedjax/shakespeare_train.sqlite'
100%, elapsed: 0s
Downloading 'https://storage.googleapis.com/gresearch/fedjax/shakespeare/shakespeare_test.sqlite' to '/tmp/.cache/fedjax/shakespeare_test.sqlite'
100%, elapsed: 0s

fedjax.FederatedData provides methods for accessing metadata about the federated dataset, like the total number of clients, client ids, and number of examples for each client.

As seen in the output below, there are 715 total clients in the Shakespeare dataset. Each client has a unique client ID that can be used to query metadata about that client, such as the number of examples that client has.

print('num_clients =', train_fd.num_clients())

# train_fd.client_ids() is a generator of client ids.
# itertools has efficient and convenient functions for working with generators.
for client_id in itertools.islice(train_fd.client_ids(), 3):
  print('client_id =', client_id)
  print('# examples =', train_fd.client_size(client_id))
num_clients = 715
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 5
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 13
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 15

Notice that the client ids start with a random set of bits. This ensures that one can easily slice a FederatedData to obtain a random subset of clients. While the client ids returned above look sorted, there is no such guarantee in general.

# Slicing is based on the lexicographic order of client ids.
train_fd_0 = train_fd.slice(start=b'0', stop=b'1')
print('num_clients whose id starts with 0 =', train_fd_0.num_clients())
num_clients whose id starts with 0 = 47

Client datasets

We can query the dataset for a client from a federated dataset using their client ID and fedjax.FederatedData.get_client(). The output of fedjax.FederatedData.get_client() is a fedjax.ClientDataset. A fedjax.ClientDataset object

  • Stores all the examples for a given client and any preprocessing that should be applied on batches.

  • Provides methods for batching, shuffling, and iterating over preprocessed examples.

In other words, ClientDataset = examples + preprocessor.

client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
client_dataset = train_fd.get_client(client_id)
print(client_dataset)
<fedjax.core.client_datasets.ClientDataset object at 0x7ff730f5afd0>

FedJAX assumes that an individual client dataset is small and can easily fit in memory. This assumption is also reflected in many of FedJAX’s design decisions. The examples in a client dataset can be viewed as a table, where the rows are the individual examples, and the columns are the features (labels are viewed as a feature in this context).

FedJAX uses a column based representation when loading a dataset into memory.

  • Each column is a NumPy array x of rank at least 1, where x[i, ...] is the value of this feature for the i-th example.

  • The complete set of examples is a dict-like object, from str feature names, to the corresponding column values.

Traditionally, a row based representation is used for representing the entire dataset, and a column based representation is used for a single batch.

In the context of federated learning, an individual client dataset is small enough to easily fit into memory so the same representation is used for the entire dataset and a batch.
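
For illustration, here is a toy client dataset with two features in this column based representation. The values are made up for this sketch and are not from the Shakespeare dataset.

# A toy column based client dataset: one NumPy array per feature, where row i
# holds the value of that feature for the i-th example.
toy_examples = {
    'x': np.array([[1, 2], [3, 4], [5, 6]]),  # 3 examples, 2 feature values each
    'y': np.array([0, 1, 0]),                 # 1 label per example
}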

The preprocessed examples of a single client can be viewed by calling all_examples().

client_dataset.all_examples()
{'x': array([[ 1, 55, 67, 84, 67, 47,  7, 67],
        [48, 16, 13, 32, 33, 14, 11, 78],
        [76, 78, 33, 19, 16, 66, 47,  3],
        [16, 27, 67, 23, 26, 47,  3, 27],
        [16,  7,  4, 67, 16, 51, 48, 68],
        [ 7, 26, 47, 27, 42, 16,  7,  4],
        [67, 72, 16, 48, 67, 27, 23, 71],
        [67, 65, 29, 79, 76, 51, 74, 12],
        [75, 54, 74, 19, 16, 66, 47,  3],
        [16, 67,  8, 67, 71, 47,  7, 61],
        [16, 14,  4, 67, 47, 16, 48, 67],
        [84, 67, 47,  7, 67, 48, 16, 12],
        [78, 29, 75, 78, 33, 16, 66, 47],
        [ 3, 16, 75, 73, 29, 11, 75, 76],
        [32, 19, 65, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 16],
        [28, 68,  7,  4, 16, 75, 76, 32],
        [30, 74, 54, 65,  0,  0,  0,  0]], dtype=int32),
 'y': array([[55, 67, 84, 67, 47,  7, 67, 48],
        [16, 13, 32, 33, 14, 11, 78, 76],
        [78, 33, 19, 16, 66, 47,  3, 16],
        [27, 67, 23, 26, 47,  3, 27, 16],
        [ 7,  4, 67, 16, 51, 48, 68,  7],
        [26, 47, 27, 42, 16,  7,  4, 67],
        [72, 16, 48, 67, 27, 23, 71, 67],
        [65, 29, 79, 76, 51, 74, 12, 75],
        [54, 74, 19, 16, 66, 47,  3, 16],
        [67,  8, 67, 71, 47,  7, 61, 16],
        [14,  4, 67, 47, 16, 48, 67, 84],
        [67, 47,  7, 67, 48, 16, 12, 78],
        [29, 75, 78, 33, 16, 66, 47,  3],
        [16, 75, 73, 29, 11, 75, 76, 32],
        [19, 65, 16, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 28],
        [68,  7,  4, 16, 75, 76, 32, 30],
        [74, 54, 65,  2,  0,  0,  0,  0]], dtype=int32)}

For Shakespeare, we are training a character-level language model, where the task is next character prediction, so the features are:

  • x is a list of right-shifted sentences, e.g. sentence[:-1]

  • y is a list of left-shifted sentences, e.g. sentence[1:]

This way, y[i][j] corresponds to the next character after x[i][j].

examples = client_dataset.all_examples()
print('x', examples['x'][0])
print('y', examples['y'][0])
x [ 1 55 67 84 67 47  7 67]
y [55 67 84 67 47  7 67 48]

However, you probably noticed that x and y are arrays of integers, not text. This is because fedjax.datasets.shakespeare.load_data() does some minimal preprocessing, such as a simple character lookup that maps characters to integer IDs. Later, we’ll go over how this preprocessing was applied and how to add your own custom preprocessing.

For comparison, here’s the unprocessed version of the same client dataset.

# Unlike load_data(), load_split() always loads a single unprocessed split.
raw_fd = fedjax.datasets.shakespeare.load_split('train')
raw_fd.get_client(client_id).all_examples()
Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'
{'snippets': array([b'Re-enter POSTHUMUS, and seconds the Britons; they rescue\nCYMBELINE, and exeunt. Then re-enter LUCIUS and IACHIMO,\n                     with IMOGEN\n'],
       dtype=object)}

Accessing client datasets from fedjax.FederatedData

fedjax.FederatedData.get_client() works well for querying a single client’s data when exploring the dataset. However, we often want to query multiple client datasets at the same time. In most FL algorithms, tens to hundreds of clients participate in each training round. If we are not careful, our code can spend a lot of time loading data, leaving the accelerators (GPU or TPU) idle.

In light of this, fedjax.FederatedData offers more efficient methods for querying multiple client datasets.

We’ll go through each access method from the MOST efficient to the LEAST efficient.

clients() and shuffled_clients()

Fastest, sequential-read-friendly access. As we stated earlier, the client ids start with random bits. Hence, even sequential reads will go over clients in a pseudo-random order.

# clients() and shuffled_clients() are sequential read friendly.
clients = train_fd.clients()
shuffled_clients = train_fd.shuffled_clients(buffer_size=100, seed=0)
print('clients =', clients)
print('shuffled_clients =', shuffled_clients)
clients = <generator object SQLiteFederatedData.clients at 0x7ff730f3d950>
shuffled_clients = <generator object SQLiteFederatedData.shuffled_clients at 0x7ff7310a8950>

They are generators, so we iterate over them to get the individual client datasets as tuples of (client_id, client_dataset).

clients() returns clients in an unspecified deterministic order. It is useful for going over the entire federated dataset for evaluation.

# We use itertools.islice to select first three clients.
for client_id, client_dataset in itertools.islice(clients, 3):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 53
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 234
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 79

shuffled_clients() provides a stream of infinitely repeating shuffled client datasets, using buffered shuffling. It is suitable for training rounds where a nearly random shuffling is good enough.

print('shuffled_clients()')
for client_id, client_dataset in itertools.islice(shuffled_clients, 3):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))
shuffled_clients()
client_id = b'0a18c2501d441fef:THE_TRAGEDY_OF_KING_LEAR_FLUTE'
# examples = 115
client_id = b'136c5586b7271525:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_GLOUCESTER'
# examples = 3804
client_id = b'0d642a9b4bb27187:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_MESSENGER'
# examples = 775

get_clients()

Slower than clients() since it requires random reads, but it uses prefetching to hide the latency of random read access. This also returns a generator of (client_id, client_dataset) tuples, in the order of the input client_ids.

client_ids = [
    b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN',
    b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN',
    b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
]
for client_id, client_dataset in train_fd.get_clients(client_ids):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 483
client_id = b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN'
# examples = 5
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
# examples = 19

get_client()

Slowest way of accessing client datasets. We usually reserve this method only for interactive exploration of a small number of clients.

client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
print('client_id =', client_id)
print('# examples =', len(train_fd.get_client(client_id)))
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 483

Batching client datasets

Next we’ll go over different methods of iterating over a fedjax.ClientDataset as batched examples. All the following methods can be invoked in 2 ways:

  • Using a hyperparams object: This is the recommended way in library code. batch_fn(hparams).

  • Using keyword arguments: The keyword arguments are used to construct a new hyperparams object, or override an existing one. batch_fn(batch_size=2) or batch_fn(hparams, batch_size=2) to override batch_size.

For the most part, we’ll use method 2 for this colab, but method 1 is more suitable for writing library code.

client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
client_dataset = train_fd.get_client(client_id)
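
Before looking at each method, here is a quick sketch of the first (hyperparams object) way of calling these methods, using shuffle_repeat_batch(), which is covered in detail below. fedjax.ShuffleRepeatBatchHParams is the corresponding hyperparams type (it also appears in the algorithms tutorial), and keyword arguments can still override individual fields.

# Method 1: construct a hyperparams object once and reuse it.
hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=8)
batches = list(client_dataset.shuffle_repeat_batch(hparams))
# Individual fields can be overridden with keyword arguments.
smaller_batches = list(client_dataset.shuffle_repeat_batch(hparams, batch_size=2))
print('# batches =', len(batches), '# smaller batches =', len(smaller_batches))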

batch() for illustrations

Produces preprocessed batches in a fixed sequential order.

The final batch may contain fewer than batch_size examples. If used directly, that may result in a large number of JIT recompilations. Therefore we should use padded_batch() or shuffle_repeat_batch() instead in most scenarios.

Note here we are not talking about padding within an example, often done in processing sequence examples (e.g. the 0 labels below), but rather padding with “empty” examples in a batch.

batches = list(client_dataset.batch(batch_size=8))
batches[-1]
{'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
        [28, 68,  7,  4, 16, 75, 76, 32],
        [30, 74, 54, 65,  0,  0,  0,  0]], dtype=int32),
 'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
        [68,  7,  4, 16, 75, 76, 32, 30],
        [74, 54, 65,  2,  0,  0,  0,  0]], dtype=int32)}

padded_batch() for evaluation

Produces preprocessed padded batches in a fixed sequential order for evaluation.

When the number of examples in the dataset is not a multiple of batch_size, the final batch may be smaller than batch_size. This may lead to a large number of JIT recompilations. This can be circumvented by padding the final batch to a small number of fixed sizes controlled by num_batch_size_buckets.

# Use list() to consume the generator and store the batches in memory.
padded_batches = list(client_dataset.padded_batch(batch_size=8, num_batch_size_buckets=3))
print('# batches =', len(padded_batches))
padded_batches[0]
# batches = 3
{'x': array([[ 1, 55, 67, 84, 67, 47,  7, 67],
        [48, 16, 13, 32, 33, 14, 11, 78],
        [76, 78, 33, 19, 16, 66, 47,  3],
        [16, 27, 67, 23, 26, 47,  3, 27],
        [16,  7,  4, 67, 16, 51, 48, 68],
        [ 7, 26, 47, 27, 42, 16,  7,  4],
        [67, 72, 16, 48, 67, 27, 23, 71],
        [67, 65, 29, 79, 76, 51, 74, 12]], dtype=int32),
 'y': array([[55, 67, 84, 67, 47,  7, 67, 48],
        [16, 13, 32, 33, 14, 11, 78, 76],
        [78, 33, 19, 16, 66, 47,  3, 16],
        [27, 67, 23, 26, 47,  3, 27, 16],
        [ 7,  4, 67, 16, 51, 48, 68,  7],
        [26, 47, 27, 42, 16,  7,  4, 67],
        [72, 16, 48, 67, 27, 23, 71, 67],
        [65, 29, 79, 76, 51, 74, 12, 75]], dtype=int32),
 '__mask__': array([ True,  True,  True,  True,  True,  True,  True,  True])}

All batches contain an extra bool feature keyed by '__mask__'. batch['__mask__'][i] tells us whether the i-th example in this batch is an actual example (batch['__mask__'][i] == True), or a padding example (batch['__mask__'][i] == False).

We repeatedly halve the batch size up to num_batch_size_buckets - 1 times, until we find the smallest one that is also >= the size of the final batch. Therefore if batch_size < 2^num_batch_size_buckets, fewer bucket sizes will actually be used. We can see this below: the final batch is padded to 4 examples even though the original batch size was 8.

padded_batches[-1]
{'__mask__': array([ True,  True,  True, False]),
 'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
        [28, 68,  7,  4, 16, 75, 76, 32],
        [30, 74, 54, 65,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]], dtype=int32),
 'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
        [68,  7,  4, 16, 75, 76, 32, 30],
        [74, 54, 65,  2,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]], dtype=int32)}
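
Since padding examples should not contribute to any computation, the mask can be used to ignore them. As a small example, here is how to count the real examples in the final padded batch.

# Count the real (non-padding) examples in the final padded batch.
final_batch = padded_batches[-1]
print('# real examples =', np.sum(final_batch['__mask__']))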

shuffle_repeat_batch() for training

Produces preprocessed batches in a shuffled and repeated order for training.

Shuffling is done without replacement, therefore for a dataset of N examples, the first ceil(N/batch_size) batches are guaranteed to cover the entire dataset. Unlike batch() or padded_batch(), batches from shuffle_repeat_batch() always contain exactly batch_size examples. Unlike TensorFlow, this holds even when drop_remainder=False.

By default the iteration stops after the first epoch.

print('# batches')
len(list(client_dataset.shuffle_repeat_batch(batch_size=8)))
# batches
3

The number of batches produced from the iteration can be controlled by the (num_epochs, num_steps, drop_remainder) combination:

If both num_epochs and num_steps are None, the shuffle-repeat process continues forever.

infinite_bs = client_dataset.shuffle_repeat_batch(
    batch_size=8, num_epochs=None, num_steps=None)
for i, b in zip(range(6), infinite_bs):
  print(i)
0
1
2
3
4
5

If num_epochs is set and num_steps is None, just enough batches to make num_epochs passes over the dataset are produced. Further,

  • If drop_remainder is False (the default), the final batch is filled with additionally sampled examples to contain batch_size examples.

  • If drop_remainder is True, the final batch is dropped if it contains fewer than batch_size examples. This may result in examples being skipped when num_epochs=1.

print('# batches w/ drop_remainder=False')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None))))
print('# batches w/ drop_remainder=True')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None, drop_remainder=True))))
# batches w/ drop_remainder=False
3
# batches w/ drop_remainder=True
2

If num_steps is set and num_epochs is None, exactly this many batches are produced. drop_remainder has no effect in this case.

print('# batches w/ num_steps set and drop_remainder=True')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=None, num_steps=3, drop_remainder=True))))
# batches w/ num_steps set and drop_remainder=True
3

If both num_epochs and num_steps are set, the smaller number of batches implied by the two conditions is produced.

print('# batches w/ num_epochs and num_steps set')
print(len(list(client_dataset.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=6))))
# batches w/ num_epochs and num_steps set
3

If reproducible iteration order is desired, a fixed seed can be used. When seed is None, repeated iteration over the same object may produce batches in a different order.

# Random shuffling.
print(list(client_dataset.shuffle_repeat_batch(batch_size=2, seed=None))[0])
# Fixed shuffling.
print(list(client_dataset.shuffle_repeat_batch(batch_size=2, seed=0))[0])
{'x': array([[78, 29, 75, 78, 33, 16, 66, 47],
       [ 7, 26, 47, 27, 42, 16,  7,  4]], dtype=int32), 'y': array([[29, 75, 78, 33, 16, 66, 47,  3],
       [26, 47, 27, 42, 16,  7,  4, 67]], dtype=int32)}
{'x': array([[16, 14,  4, 67, 47, 16, 48, 67],
       [48, 16, 13, 32, 33, 14, 11, 78]], dtype=int32), 'y': array([[14,  4, 67, 47, 16, 48, 67, 84],
       [16, 13, 32, 33, 14, 11, 78, 76]], dtype=int32)}

Preprocessing

Going from the raw features to features in a batch of examples ready for use in training/evaluation often requires a few steps of preprocessing. Sometimes, we also want to add new preprocessing to an existing FederatedData.

Before going into the details of preprocessing, please note once again that all dataset related operations should be implemented in standard NumPy, not in JAX.

Preprocessing a batch of examples

The easiest way to add an additional preprocessing step is to append a function that transforms a batch of examples to a FederatedData’s list of batch-level preprocessing transformations.

Below, we add a new z feature that stores the parity of y for our Shakespeare examples.

# A preprocessing function should return a new dict of examples instead of
# modifying its input.
def parity_feature(examples):
  return {'z': examples['y'] % 2, **examples}

# preprocess_batch returns a new FederatedData object that has one more
# preprocessing step at the very end than the original.
train_fd_z = train_fd.preprocess_batch(parity_feature)
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
next(iter(train_fd_z.get_client(client_id).padded_batch(batch_size=4)))
{'z': array([[1, 1, 0, 1, 1, 1, 1, 0],
        [0, 1, 0, 1, 0, 1, 0, 0],
        [0, 1, 1, 0, 0, 1, 1, 0],
        [1, 1, 1, 0, 1, 1, 1, 0]], dtype=int32),
 'x': array([[ 1, 55, 67, 84, 67, 47,  7, 67],
        [48, 16, 13, 32, 33, 14, 11, 78],
        [76, 78, 33, 19, 16, 66, 47,  3],
        [16, 27, 67, 23, 26, 47,  3, 27]], dtype=int32),
 'y': array([[55, 67, 84, 67, 47,  7, 67, 48],
        [16, 13, 32, 33, 14, 11, 78, 76],
        [78, 33, 19, 16, 66, 47,  3, 16],
        [27, 67, 23, 26, 47,  3, 27, 16]], dtype=int32),
 '__mask__': array([ True,  True,  True,  True])}

Preprocessing at the client level

Sometimes we also need to do some preprocessing for the entire client dataset. For example, in the Shakespeare dataset, the raw features are just text strings, so we need to turn them into a NumPy array of chunks of integer labels just to be able to meaningfully batch them at all.

In most circumstances, adding a preprocessing step at the client level is unnecessary, and should be avoided, because the new preprocessing step added by preprocess_client is inserted into the middle of a chain of steps, just before all the batch level preprocessing steps registered by preprocess_batch. If not done carefully, a custom client level preprocessing can easily break the preprocessing chain.

Nevertheless, here’s an example of client level processing for certain rare cases.

# Load unpreprocessed data.
raw_fd = fedjax.datasets.shakespeare.load_split('train')
raw_fd.get_client(client_id).all_examples()
Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'
{'snippets': array([b'Re-enter POSTHUMUS, and seconds the Britons; they rescue\nCYMBELINE, and exeunt. Then re-enter LUCIUS and IACHIMO,\n                     with IMOGEN\n'],
       dtype=object)}

The actual client level preprocessing in the Shakespeare dataset is a bit involved, so let’s do something simpler: we shall join all the snippets, and then split and pad the integer byte values into bounded length sequences.

# We don't actually need client_id, but `FederatedData` supplies it so that
# different processing based on clients can be done.
def truncate_and_cast(client_id, examples, max_length=10):
  labels = list(b''.join(examples['snippets']))
  num_sequences = (len(labels) + max_length - 1) // max_length
  padded = np.zeros((num_sequences, max_length), dtype=np.int32)
  for i in range(num_sequences):
    chars = labels[i * max_length:(i + 1) * max_length]
    padded[i, :len(chars)] = chars
  return {'snippets': padded}


partial_fd = raw_fd.preprocess_client(truncate_and_cast)
partial_fd.get_client(client_id).all_examples()
{'snippets': array([[ 82, 101,  45, 101, 110, 116, 101, 114,  32,  80],
        [ 79,  83,  84,  72,  85,  77,  85,  83,  44,  32],
        [ 97, 110, 100,  32, 115, 101,  99, 111, 110, 100],
        [115,  32, 116, 104, 101,  32,  66, 114, 105, 116],
        [111, 110, 115,  59,  32, 116, 104, 101, 121,  32],
        [114, 101, 115,  99, 117, 101,  10,  67,  89,  77],
        [ 66,  69,  76,  73,  78,  69,  44,  32,  97, 110],
        [100,  32, 101, 120, 101, 117, 110, 116,  46,  32],
        [ 84, 104, 101, 110,  32, 114, 101,  45, 101, 110],
        [116, 101, 114,  32,  76,  85,  67,  73,  85,  83],
        [ 32,  97, 110, 100,  32,  73,  65,  67,  72,  73],
        [ 77,  79,  44,  10,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32, 119, 105, 116, 104,  32],
        [ 73,  77,  79,  71,  69,  78,  10,   0,   0,   0]], dtype=int32)}

Now, we can add another batch level preprocessor to produce x and y labels.

def snippets_to_xy(examples):
  snippets = examples['snippets']
  return {'x': snippets[:, :-1], 'y': snippets[:, 1:]}


partial_fd.preprocess_batch(snippets_to_xy).get_client(client_id).all_examples()
{'x': array([[ 82, 101,  45, 101, 110, 116, 101, 114,  32],
        [ 79,  83,  84,  72,  85,  77,  85,  83,  44],
        [ 97, 110, 100,  32, 115, 101,  99, 111, 110],
        [115,  32, 116, 104, 101,  32,  66, 114, 105],
        [111, 110, 115,  59,  32, 116, 104, 101, 121],
        [114, 101, 115,  99, 117, 101,  10,  67,  89],
        [ 66,  69,  76,  73,  78,  69,  44,  32,  97],
        [100,  32, 101, 120, 101, 117, 110, 116,  46],
        [ 84, 104, 101, 110,  32, 114, 101,  45, 101],
        [116, 101, 114,  32,  76,  85,  67,  73,  85],
        [ 32,  97, 110, 100,  32,  73,  65,  67,  72],
        [ 77,  79,  44,  10,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32, 119, 105, 116, 104],
        [ 73,  77,  79,  71,  69,  78,  10,   0,   0]], dtype=int32),
 'y': array([[101,  45, 101, 110, 116, 101, 114,  32,  80],
        [ 83,  84,  72,  85,  77,  85,  83,  44,  32],
        [110, 100,  32, 115, 101,  99, 111, 110, 100],
        [ 32, 116, 104, 101,  32,  66, 114, 105, 116],
        [110, 115,  59,  32, 116, 104, 101, 121,  32],
        [101, 115,  99, 117, 101,  10,  67,  89,  77],
        [ 69,  76,  73,  78,  69,  44,  32,  97, 110],
        [ 32, 101, 120, 101, 117, 110, 116,  46,  32],
        [104, 101, 110,  32, 114, 101,  45, 101, 110],
        [101, 114,  32,  76,  85,  67,  73,  85,  83],
        [ 97, 110, 100,  32,  73,  65,  67,  72,  73],
        [ 79,  44,  10,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32,  32,  32,  32,  32,  32],
        [ 32,  32,  32,  32, 119, 105, 116, 104,  32],
        [ 77,  79,  71,  69,  78,  10,   0,   0,   0]], dtype=int32)}

In memory federated datasets

In many scenarios, it is desirable to create a small custom federated dataset from a collection of NumPy arrays for quick experimentation. FedJAX provides fedjax.InMemoryFederatedData to create small custom datasets. fedjax.InMemoryFederatedData takes a dictionary of numpy examples keyed by client id and creates a fedjax.FederatedData that is compatible with the rest of the library. We illustrate this below with a simple example.

# Obtain MNIST dataset from tensorflow and convert to numpy format.
import tensorflow_datasets as tfds
(ds_train, ds_test) = tfds.load('mnist',
                                split=['train', 'test'],
                                shuffle_files=True,
                                as_supervised=True,
                                with_info=False)
features, labels = list(ds_train.batch(60000).as_numpy_iterator())[0]
print('features shape', features.shape)
print('labels shape', labels.shape)

# Randomly split dataset into 100 clients and load them to a dictionary.
indices = np.random.randint(100, size=60000)
client_id_to_dataset_mapping = {}
for i in range(100):
  client_id_to_dataset_mapping[i] = {'x': features[indices==i, :, : , :],
                                     'y': labels[indices==i]}

# Create fedjax.InMemoryDataset.
iid_mnist_federated_data = fedjax.InMemoryFederatedData(
    client_id_to_dataset_mapping)

print('number of clients in iid_mnist_data',
      iid_mnist_federated_data.num_clients())
features shape (60000, 28, 28, 1)
labels shape (60000,)
number of clients in iid_mnist_data 100

Recap

In this tutorial, we have covered the following:

  1. Using fedjax.FederatedData.

  2. Different ways of batching client datasets.

  3. Different ways of processing client datasets.

  4. Creating small custom federated datasets.

Working with models in FedJAX


In this chapter, we will learn about fedjax.Model. This notebook assumes you already have finished the “Datasets” chapter. We first overview centralized training and evaluation with fedjax.Model and then describe how to add new neural architectures and specify additional evaluation metrics.

# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
import itertools

import jax
import jax.numpy as jnp
from jax.example_libraries import stax

import fedjax

Centralized training & evaluation with fedjax.Model

Most federated learning algorithms are built upon common components from standard centralized learning. fedjax.Model holds these common components. In centralized learning, we are mostly concerned with two tasks:

  • Training: We want to optimize our model parameters on the training dataset.

  • Evaluation: We want to know the values of evaluation metrics (e.g. accuracy) of the current model parameters on a test dataset.

Let’s first see how we can carry out these two tasks on the EMNIST dataset with fedjax.Model.

# Load train/test splits of the EMNIST dataset.
train, test = fedjax.datasets.emnist.load_data()

# As a start, let's simply use a logistic regression model.
model = fedjax.models.emnist.create_logistic_model()

Random initialization, the JAX way

To start training, we need some randomly initialized parameters. In JAX, pseudo random number generation works slightly differently. For now, it is sufficient to know we call jax.random.PRNGKey() to seed the random number generator. JAX has a detailed introduction on this topic, if you are interested.

To create the initial model parameters, we simply call fedjax.Model.init() with a PRNGKey.

params_rng = jax.random.PRNGKey(0)
params = model.init(params_rng)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Here are our initial model parameters. With the same PRNGKey, we will always get the same random initialization. There are 2 parameters in our model, the weights w, and the bias b. They are organized into a FlatMapping, but in general any PyTree can be used to store model parameters.

params
FlatMapping({
  'linear': FlatMapping({
              'b': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0.], dtype=float32),
              'w': DeviceArray([[-0.04067196,  0.02348138, -0.0214883 , ...,  0.01055492,
                                 -0.06988288, -0.02952586],
                                [-0.03985253, -0.03804361,  0.01401524, ...,  0.02281437,
                                 -0.01771905,  0.06676884],
                                [ 0.00098182, -0.00844628,  0.01303554, ..., -0.05299249,
                                  0.01777634, -0.0006488 ],
                                ...,
                                [-0.05691862,  0.05192501,  0.01588603, ...,  0.0157204 ,
                                 -0.01854135,  0.00297953],
                                [ 0.01680706,  0.05579231,  0.0459589 , ...,  0.01990358,
                                 -0.01944044, -0.01710149],
                                [-0.00880739,  0.04229043,  0.00998938, ..., -0.00633441,
                                 -0.04824542,  0.01395545]], dtype=float32),
            }),
})

Evaluating model parameters

Before we start training, let’s first see how our initial parameters fare on the train and test sets. Unsurprisingly, they do not do very well. We evaluate using fedjax.evaluate_model(), which takes in the model, parameters, and batched datasets. As noted in the dataset tutorial, we batch using fedjax.padded_batch_federated_data() for efficiency. fedjax.padded_batch_federated_data() is very similar to fedjax.ClientDataset.padded_batch() but operates over the entire federated dataset.

# We select first 16 batches using itertools.islice.
batched_test_data = list(itertools.islice(
    fedjax.padded_batch_federated_data(test, batch_size=128), 16))
batched_train_data = list(itertools.islice(
    fedjax.padded_batch_federated_data(train, batch_size=128), 16))

print('eval_test', fedjax.evaluate_model(model, params, batched_test_data))
print('eval_train', fedjax.evaluate_model(model, params, batched_train_data))
eval_test {'accuracy': DeviceArray(0.01757812, dtype=float32), 'loss': DeviceArray(4.1253214, dtype=float32)}
eval_train {'accuracy': DeviceArray(0.02490234, dtype=float32), 'loss': DeviceArray(4.116228, dtype=float32)}

How does our model know what evaluation metrics to report? It is simply specified in the eval_metrics field. We will discuss evaluation metrics in more detail later.

model.eval_metrics
{'accuracy': Accuracy(target_key='y', pred_key=None),
 'loss': CrossEntropyLoss(target_key='y', pred_key=None)}

Since fedjax.evaluate_model() simply takes a stream of batches, we can also use it to evaluate multiple clients.

for client_id, dataset in itertools.islice(test.clients(), 4):
  print(
      client_id,
      fedjax.evaluate_model(model, params,
                            dataset.padded_batch(batch_size=128)))
b'002d084c082b8586:f0185_23' {'accuracy': DeviceArray(0.05, dtype=float32), 'loss': DeviceArray(4.1247168, dtype=float32)}
b'005fdad281234bc0:f0151_02' {'accuracy': DeviceArray(0.09375, dtype=float32), 'loss': DeviceArray(4.093891, dtype=float32)}
b'014c177da5b15a39:f1565_04' {'accuracy': DeviceArray(0., dtype=float32), 'loss': DeviceArray(4.127692, dtype=float32)}
b'0156df0c34a25944:f3772_10' {'accuracy': DeviceArray(0.05263158, dtype=float32), 'loss': DeviceArray(4.1521378, dtype=float32)}

The training objective

To train our model, we need two things: the objective function to minimize and an optimizer.

fedjax.Model contains two functions that can be used to arrive at the training objective:

  • apply_for_train(params, batch_example, rng) takes the current model parameters, a batch of examples, and a PRNGKey, and returns some output.

  • train_loss(batch_example, train_output) translates the output of apply_for_train() into a vector of per-example loss values.

In our example model, apply_for_train() produces a score for each class and train_loss() is simply the cross entropy loss. apply_for_train() in this case does not make use of a PRNGKey, so we can pass None instead for convenience. A different apply_for_train() might actually make use of the PRNGKey, for tasks such as dropout.

# train_batches is an infinite stream of shuffled batches of examples.
def train_batches():
  return fedjax.shuffle_repeat_batch_federated_data(
      train,
      batch_size=8,
      client_buffer_size=16,
      example_buffer_size=1024,
      seed=0)


# We obtain the first batch by using the `next` function.
example = next(train_batches())
output = model.apply_for_train(params, example, None)
per_example_loss = model.train_loss(example, output)

output.shape, per_example_loss
((8, 62), DeviceArray([4.0337796, 4.046219 , 3.9447758, 3.933005 , 4.116893 ,
              4.209843 , 4.060939 , 4.19899  ], dtype=float32))

Note that the output is per example predictions and has shape (8, 62), where 8 is the batch size and 62 is the number of classes. Alternatively, we can use model_per_example_loss() to get a function that gives us the same result. model_per_example_loss() is a convenience function that does exactly what we just did.

per_example_loss_fn = fedjax.model_per_example_loss(model)
per_example_loss_fn(params, example, None)
DeviceArray([4.0337796, 4.046219 , 3.9447758, 3.933005 , 4.116893 ,
             4.209843 , 4.060939 , 4.19899  ], dtype=float32)

The training objective is a scalar, so why does train_loss() return a vector of per-example loss values? First of all, the training objective in most cases is just the average of the per-example loss values, so arriving at the final training objective isn’t hard. Moreover, in certain algorithms, we not only use the train loss over a single batch of examples for a stochastic training step, but also need to estimate the average train loss over an entire (client) dataset. Having the per-example loss values there is instrumental in obtaining the correct estimate when the batch sizes may vary.

def train_objective(params, example):
  return jnp.mean(per_example_loss_fn(params, example, None))

train_objective(params, example)
DeviceArray(4.0680556, dtype=float32)
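
As a concrete example of this point, here is a short sketch (not a FedJAX API) that estimates the average train loss over the padded train batches from earlier. It accumulates per-example losses and example counts using the '__mask__' feature, so padding examples and varying batch sizes are handled correctly.

# Average per-example loss over padded batches, ignoring padding examples.
def average_loss(params, padded_batches):
  total_loss, total_count = 0., 0.
  for batch in padded_batches:
    per_example = per_example_loss_fn(params, batch, None)
    mask = batch['__mask__']
    total_loss += jnp.sum(per_example * mask)
    total_count += jnp.sum(mask)
  return total_loss / total_count

print('average train loss =', average_loss(params, batched_train_data))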

Optimizers

With the training objective at hand, we just need an optimizer to find some good model parameters that minimize it.

There are many optimizer implementations in JAX out there, but FedJAX doesn’t force one choice over any other. Instead, FedJAX provides a simple fedjax.optimizers.Optimizer interface so a new optimizer implementation can be wrapped. For convenience, FedJAX provides some common optimizers wrapped from optax.

optimizer = fedjax.optimizers.adam(1e-3)
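
If we want to use an optax optimizer that is not already wrapped for us, we can wrap it ourselves. A short sketch, assuming the create_optimizer_from_optax wrapper from fedjax.optimizers and that optax is installed (it is a FedJAX dependency):

import optax

# Wrap an optax gradient transformation into a FedJAX optimizer.
momentum_optimizer = fedjax.optimizers.create_optimizer_from_optax(
    optax.sgd(learning_rate=1e-2, momentum=0.9))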

An optimizer is simply a pair of functions:

  • init(params) returns the initial optimizer state, such as initial values for accumulators of gradients.

  • apply(grads, opt_state, params) applies the gradients to update the current optimizer state and model parameters.

Instead of modifying opt_state or params, apply() returns a new pair of optimizer state and model parameters. In JAX, it is common to express computations in this stateless/mutation free style, often referred to as functional programming, or pure functions. The pureness of functions is crucial to many features in JAX, so it is always good practice to write functions that do not modify their inputs. You have probably also noticed that all the functions of fedjax.Model we have seen so far do not modify the model object itself (for example, init() returns model parameters instead of setting some attribute of model; apply_for_train() takes model parameters as an input argument, instead of getting them from model). FedJAX does this to keep all functions pure.

However, in the top level training loop, it is fine to mutate states since we are not in a function that may be transformed by JAX. Let’s run our first training step, which results in a slight decrease in the objective on the same batch of examples.

To obtain the gradients, we use jax.grad() which returns the gradient function. More details about jax.grad() can be found from the JAX documentation.

opt_state = optimizer.init(params)
grads = jax.grad(train_objective)(params, example)
opt_state, params = optimizer.apply(grads, opt_state, params)
train_objective(params, example)
DeviceArray(4.0080366, dtype=float32)

Instead of using jax.grad() directly, FedJAX also provides a convenient fedjax.model_grad() which computes the gradient of a model with respect to the averaged fedjax.model_per_example_loss().

model_grads = fedjax.model_grad(model)(params, example, None)
opt_state, params = optimizer.apply(model_grads, opt_state, params)
train_objective(params, example)
DeviceArray(3.9482572, dtype=float32)

Let’s wrap everything into a single JIT compiled function and train a few more steps, and evaluate again.

@jax.jit
def train_step(example, opt_state, params):
  grads = jax.grad(train_objective)(params, example)
  return optimizer.apply(grads, opt_state, params)

for example in itertools.islice(train_batches(), 5000):
  opt_state, params = train_step(example, opt_state, params)

print('eval_test', fedjax.evaluate_model(model, params, batched_test_data))
print('eval_train', fedjax.evaluate_model(model, params, batched_train_data))
eval_test {'accuracy': DeviceArray(0.6152344, dtype=float32), 'loss': DeviceArray(1.5562292, dtype=float32)}
eval_train {'accuracy': DeviceArray(0.59765625, dtype=float32), 'loss': DeviceArray(1.6278805, dtype=float32)}

Building a custom model

fedjax.Model was designed with customization in mind. We have already seen how to switch to a different training loss. In this section, we will discuss how the rest of a fedjax.Model can be customized.

Training loss

Because train_loss() is separate from apply_for_train(), it is easy to switch to a different loss function.

def hinge_loss(example, output):
  label = example['y']
  num_classes = output.shape[-1]
  mask = jax.nn.one_hot(label, num_classes)
  label_score = jnp.sum(output * mask, axis=-1)
  best_score = jnp.max(output + 1 - mask, axis=-1)
  return best_score - label_score


hinge_model = model.replace(train_loss=hinge_loss)
fedjax.model_per_example_loss(hinge_model)(params, example, None)
DeviceArray([4.306656  , 0.        , 0.        , 0.4375435 , 0.96986485,
             0.        , 0.3052401 , 1.3918507 ], dtype=float32)

Evaluation metrics

We have already seen that the eval_metrics field of a fedjax.Model tells the model what metrics to evaluate. eval_metrics is a mapping from metric names to fedjax.metrics.Metric objects. A fedjax.metrics.Metric object tells us how to calculate a metric’s value from multiple batches of examples. Like fedjax.Model, a fedjax.metrics.Metric is stateless.

To customize the metrics to evaluate on, or what names to give to each, simply specify a different mapping.

only_accuracy = model.replace(
    eval_metrics={'accuracy': fedjax.metrics.Accuracy()})
fedjax.evaluate_model(only_accuracy, params, batched_test_data)
{'accuracy': DeviceArray(0.6152344, dtype=float32)}

There are already some concrete Metrics in fedjax.metrics. It is also easy to implement a new one. You can read more about how to implement a Metric in its own introduction.

The bit of fedjax.Model that is directly relevant to evaluation is apply_for_eval(). The relation between apply_for_eval() and an evaluation metric is similar to that between apply_for_train() and train_loss(): apply_for_eval(params, example) takes the model parameters and a batch of examples (notice there is no randomness in evaluation so we don’t need a PRNGKey), and produces some prediction that evaluation metrics can consume. In our example, the outputs from apply_for_eval() and apply_for_train() are identical, but they don’t have to be.

jnp.all(
    model.apply_for_train(params, example, None) == model.apply_for_eval(
        params, example))
DeviceArray(True, dtype=bool)

What apply_for_eval() needs to produce really just depends on which evaluation fedjax.metrics.Metrics will be used. In our case, we are using fedjax.metrics.Accuracy and fedjax.metrics.CrossEntropyLoss. They are similar in their requirements on the inputs:

  • They both need to know the true label from the example, using a target_key that defaults to "y".

  • They both need to know the predicted scores from apply_for_eval(), customizable as pred_key. If pred_key is None, apply_for_eval() should return just a vector of per-class scores; otherwise pred_key can be a string key, and apply_for_eval() should return a mapping (e.g. dict) that maps the key to a vector of per-class scores.

fedjax.metrics.Accuracy()
Accuracy(target_key='y', pred_key=None)
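
For instance, if apply_for_eval() returned a mapping of outputs instead of a plain score vector, we could point the metric at the right entry via pred_key. Here is a short sketch; 'logits' is just a key name we made up for illustration.

# Wrap the existing apply_for_eval() so it returns a dict keyed by 'logits',
# and tell the accuracy metric where to find the per-class scores.
dict_output_model = model.replace(
    apply_for_eval=lambda params, example: {
        'logits': model.apply_for_eval(params, example)
    },
    eval_metrics={'accuracy': fedjax.metrics.Accuracy(pred_key='logits')})
fedjax.evaluate_model(dict_output_model, params, batched_test_data)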

Neural network architectures

We have now covered all five parts of a fedjax.Model, namely init(), apply_for_train(), apply_for_eval(), train_loss(), and eval_metrics. train_loss() and eval_metrics are easy to customize since they are mostly agnostic to the actual neural network architecture of the model. init(), apply_for_train(), and apply_for_eval() on the other hand, are closely related.

In principle, as long as these three functions meet the interface we have seen so far, they can be used to build a custom model. Let’s try to build a model that uses a multi-layer perceptron and cross entropy loss.

def cross_entropy_loss(example, output):
  label = example['y']
  num_classes = output.shape[-1]
  mask = jax.nn.one_hot(label, num_classes)
  return -jnp.sum(jax.nn.log_softmax(output) * mask, axis=-1)

def mlp_model(num_input_units, num_units, num_classes):

  def mlp_init(rng):
    w0_rng, w1_rng = jax.random.split(rng)
    w0 = jax.random.uniform(w0_rng, [num_input_units, num_units])
    b0 = jnp.zeros([num_units])
    w1 = jax.random.uniform(w1_rng, [num_units, num_classes])
    b1 = jnp.zeros([num_classes])
    return w0, b0, w1, b1

  def mlp_apply(params, batch, rng=None):
    w0, b0, w1, b1 = params
    x = batch['x']
    batch_size = x.shape[0]
    h = jax.nn.relu(x.reshape([batch_size, -1]) @ w0 + b0)
    return h @ w1 + b1

  return fedjax.Model(
      init=mlp_init,
      apply_for_train=mlp_apply,
      apply_for_eval=mlp_apply,
      train_loss=cross_entropy_loss,
      eval_metrics={'accuracy': fedjax.metrics.Accuracy()})


# There are 28*28 input pixels, and 62 classes in EMNIST.
mlp = mlp_model(28 * 28, 128, 62)

@jax.jit
def mlp_train_step(example, opt_state, params):

  @jax.grad
  def grad_fn(params, example):
    return jnp.mean(fedjax.model_per_example_loss(mlp)(params, example, None))

  grads = grad_fn(params, example)
  return optimizer.apply(grads, opt_state, params)


params = mlp.init(jax.random.PRNGKey(0))
opt_state = optimizer.init(params)
print('eval_test before training:',
      fedjax.evaluate_model(mlp, params, batched_test_data))
for example in itertools.islice(train_batches(), 5000):
  opt_state, params = mlp_train_step(example, opt_state, params)
print('eval_test after training:',
      fedjax.evaluate_model(mlp, params, batched_test_data))
eval_test before training: {'accuracy': DeviceArray(0.05078125, dtype=float32)}
eval_test after training: {'accuracy': DeviceArray(0.4951172, dtype=float32)}

While writing custom neural network architectures from scratch is possible, most of the time, it is much more convenient to use a neural network library such as Haiku or jax.example_libraries.stax. The two functions fedjax.create_model_from_haiku and fedjax.create_model_from_stax can convert a neural network expressed in the respective framework into a fedjax.Model. Let’s build a convolutional network using jax.example_libraries.stax this time.

def stax_cnn_model(input_shape, num_classes):
  stax_init, stax_apply = stax.serial(
      stax.Conv(
          out_chan=64, filter_shape=(3, 3), strides=(1, 1), padding='SAME'),
      stax.Relu,
      stax.Flatten,
      stax.Dense(256),
      stax.Relu,
      stax.Dense(num_classes),
  )
  return fedjax.create_model_from_stax(
      stax_init=stax_init,
      stax_apply=stax_apply,
      sample_shape=input_shape,
      train_loss=cross_entropy_loss,
      eval_metrics={'accuracy': fedjax.metrics.Accuracy()})


stax_cnn = stax_cnn_model([-1, 28, 28, 1], 62)


@jax.jit
def stax_cnn_train_step(example, opt_state, params):

  @jax.grad
  def grad_fn(params, example):
    return jnp.mean(
        fedjax.model_per_example_loss(stax_cnn)(params, example, None))

  grads = grad_fn(params, example)
  return optimizer.apply(grads, opt_state, params)


params = stax_cnn.init(jax.random.PRNGKey(0))
opt_state = optimizer.init(params)
print('eval_test before training:',
      fedjax.evaluate_model(stax_cnn, params, batched_test_data))
for example in itertools.islice(train_batches(), 1000):
  opt_state, params = stax_cnn_train_step(example, opt_state, params)
print('eval_test after training:',
      fedjax.evaluate_model(stax_cnn, params, batched_test_data))
eval_test before training: {'accuracy': DeviceArray(0.03076172, dtype=float32)}
eval_test after training: {'accuracy': DeviceArray(0.72558594, dtype=float32)}

Recap

In this chapter, we have covered the following:

  • Components of fedjax.Model: init(), apply_for_train(), apply_for_eval(), train_loss(), and eval_metrics.

  • Optimizers in fedjax.optimizers.

  • Standard centralized learning with a fedjax.Model.

  • Specifying evaluation metrics in eval_metrics.

  • Building a custom fedjax.Model.

Federated learning algorithms


This tutorial introduces algorithms for federated learning in FedJAX. By completing this tutorial, we’ll learn how to write clear and efficient algorithms that follow best practices. This tutorial assumes that we have finished the tutorials on datasets and models.

In order to keep the code pseudo-code-like, we avoid using jax primitives directly while writing algorithms, with the notable exceptions of the jax.random and jax.tree_util libraries. Since the lower-level functions described in the model tutorial, such as fedjax.optimizers and fedjax.model_grad, are already JIT compiled, the algorithms will still be efficient.

# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
import jax
import jax.numpy as jnp
import numpy as np

import fedjax
# We only use TensorFlow for datasets, so we restrict it to CPU only to avoid
# issues with certain ops not being available on GPU/TPU.
fedjax.training.set_tf_cpu_only()

Introduction

A federated algorithm trains a machine learning model over decentralized data distributed over several clients. At a high level, the server first randomly initializes the model parameters and other learning components. Then at each round the following happens:

  1. Client selection: The server selects a few clients at each round, typically at random.

  2. The server transmits the model parameters and other necessary components to the selected clients.

  3. Client update: The clients update the model parameters using a subroutine, which typically involves a few epochs of SGD on their local examples.

  4. The clients transmit the updates to the server.

  5. Server aggregation: The server combines the clients’ updates to produce new model parameters.

Pseudo-code for a common federated learning algorithm can be found in Algorithm 1 of Kairouz et al. (2020). Since FedJAX focuses on federated simulation and there is no actual transmission between clients and the server, we only focus on steps 1, 3, and 5 and ignore steps 2 and 4. Before we describe each of these modules, we will first describe how to use algorithms that are already implemented in FedJAX.

Federated algorithm overview

We implement federated learning algorithms using the fedjax.FederatedAlgorithm interface. The fedjax.FederatedAlgorithm interface has two functions init and apply. Broadly, our implementation has three parts.

  1. ServerState: This contains all the information available at the server at any given round. It includes model parameters and can also include other parameters that are used during optimization. At every round, a subset of ServerState is passed to the clients for federated learning. ServerState is also used in checkpointing and evaluation. Hence it is crucial that all the parameters that are modified during the course of federated learning are stored as part of the ServerState; do not store them as mutable attributes of the fedjax.FederatedAlgorithm itself.

  2. init: Initializes the server state.

  3. apply: Takes the ServerState and a set of client_ids, corresponding datasets, and random keys and returns a new ServerState along with any information we need from the clients in the form of client_diagnostics.

We demonstrate fedjax.FederatedAlgorithm using Federated Averaging (FedAvg) and the emnist dataset. We first initialize the model, datasets and the federated algorithm.

train, test = fedjax.datasets.emnist.load_data(only_digits=False)
model = fedjax.models.emnist.create_conv_model(only_digits=False)

rng = jax.random.PRNGKey(0)
init_params = model.init(rng)
# The federated algorithm requires a gradient function, client optimizer,
# server optimizer, and hyperparameters for batching at the client level.
grad_fn = fedjax.model_grad(model)
client_optimizer = fedjax.optimizers.sgd(0.1)
server_optimizer = fedjax.optimizers.sgd(1.0)
batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=10)
fed_alg = fedjax.algorithms.fed_avg.federated_averaging(grad_fn,
                                                        client_optimizer,
                                                        server_optimizer,
                                                        batch_hparams)

Note that similar to the rest of the library, we only pass on the necessary functions and parameters to the federated algorithm. Hence, to initialize the federated algorithm, we only passed the grad_fn and did not pass the entire model. With this, we now initialize the server state.

init_server_state = fed_alg.init(init_params)

To run the federated algorithm, we pass the server state and client data to the apply function. To this end, we pass client data as tuples of client id, client data, and a random key. Including client ids and random keys has multiple advantages: client ids let us track per-client diagnostics, which helps with debugging, and random keys ensure deterministic execution and repeatability. Furthermore, as we discuss later, they help us with fast implementations. We first format the data in this necessary format and then run one round of federated learning.

# Select 5 client_ids and their data
client_ids = list(train.client_ids())[:5]
clients_ids_and_data = list(train.get_clients(client_ids))

client_inputs = []
for i in range(5):
  rng, use_rng = jax.random.split(rng)
  client_id, client_data = clients_ids_and_data[i]
  client_inputs.append((client_id, client_data, use_rng))

updated_server_state, client_diagnostics = fed_alg.apply(init_server_state,
                                                         client_inputs)
# Prints the l2 norm of the client updates as part of client_diagnostics.
print(client_diagnostics)
{b'002d084c082b8586:f0185_23': {'delta_l2_norm': DeviceArray(1.9278834, dtype=float32)}, b'005fdad281234bc0:f0151_02': {'delta_l2_norm': DeviceArray(1.8239512, dtype=float32)}, b'014c177da5b15a39:f1565_04': {'delta_l2_norm': DeviceArray(1.6514685, dtype=float32)}, b'0156df0c34a25944:f3772_10': {'delta_l2_norm': DeviceArray(1.5863262, dtype=float32)}, b'01725f8a648ceeb6:f3408_47': {'delta_l2_norm': DeviceArray(1.613201, dtype=float32)}}

As we see above, the client diagnostics provide the delta_l2_norm of each client’s update, which can potentially be used for debugging purposes.

Writing federated algorithms

With this background on how to use existing implementations, we are now going to describe how to write your own federated algorithms in FedJAX. As discussed above, this involves three steps:

  1. Client selection

  2. Client update

  3. Server aggregation

Client selection

At each round of federated learning, typically clients are sampled uniformly at random. This can be done using numpy as follows.

all_client_ids = list(train.client_ids())
print("Total number of client ids: ", len(all_client_ids))

sampled_client_ids = np.random.choice(all_client_ids, size=2, replace=False)
print("Sampled client ids: ", sampled_client_ids)
Total number of client ids:  3400
Sampled client ids:  [b'3abf53413107a36f:f3901_45' b'7b15a2a43e2c5097:f0452_37']

However, the above code is not desirable due to the following reasons:

  1. For reproducibility, it is desirable to have a fixed seed just for sampling clients.

  2. Across rounds, different clients need to be sampled.

  3. For I/O efficiency reasons, it might be better to do an approximately uniform sampling, where clients whose data is stored together are sampled together.

  4. Federated algorithms typically require additional randomness for batching, or dropout that needs to be sent to clients.

To incorporate these features, FedJAX provides a few client samplers.

  1. fedjax.client_samplers.UniformShuffledClientSampler

  2. fedjax.client_samplers.UniformGetClientSampler

fedjax.client_samplers.UniformShuffledClientSampler is preferred for efficiency reasons, but if we need to sample clients truly randomly, fedjax.client_samplers.UniformGetClientSampler can be used. Both of them have a sample function that returns a list of client_ids, client_data, and client_rng.

efficient_sampler = fedjax.client_samplers.UniformShuffledClientSampler(
    train.shuffled_clients(buffer_size=100), num_clients=2)
print("Sampling from the efficient sampler.")
for round in range(3):
  sampled_clients_with_data = efficient_sampler.sample()
  for client_id, client_data, client_rng in sampled_clients_with_data:
    print(round, client_id)

perfect_uniform_sampler = fedjax.client_samplers.UniformGetClientSampler(
    train, num_clients=2, seed=1)
print("Sampling from the perfect uniform sampler.")
for round in range(3):
  sampled_clients_with_data = perfect_uniform_sampler.sample()
  for client_id, client_data, client_rng in sampled_clients_with_data:
    print(round, client_id)
Sampling from the efficient sampler.
0 b'1ec29e39b10521aa:f3848_28'
0 b'0156df0c34a25944:f3772_10'
1 b'0e9e55d97351b8a6:f1424_45'
1 b'1cb4c2b15c501e91:f0936_34'
2 b'1ef43f1933de69ae:f2086_25'
2 b'0a544c86e9731fc1:f0629_39'
Sampling from the perfect uniform sampler.
0 b'0eb0b8eb6ab0fcd6:f0174_06'
0 b'2b766db84ee57ba6:f2380_76'
1 b'b268b64fea71f3a7:f0766_01'
1 b'786d98ea95e5a59b:f1729_47'
2 b'35b6c7951d56d353:f3574_30'
2 b'98355b0cd97cdbb8:f3990_02'

Client update

After selecting the clients, the next step is to run a model update step on each client, typically a few epochs of SGD. We pass only the parts of the algorithm that are necessary for the client update.

The client update typically requires a set of parameters from the server (init_params in this example), the client dataset, and a source of randomness (rng). The randomness can be used for dropout or other model update steps. Finally, instead of passing the entire model to the client update, since our code only depends on the gradient function, we pass grad_fn to client_update.

def client_update(init_params, client_dataset, client_rng, grad_fn):
  opt_state = client_optimizer.init(init_params)
  params = init_params
  for batch in client_dataset.shuffle_repeat_batch(batch_size=10):
    client_rng, use_rng = jax.random.split(client_rng)
    grads = grad_fn(params, batch, use_rng)
    opt_state, params = client_optimizer.apply(grads, opt_state, params)
  delta_params = jax.tree_util.tree_multimap(lambda a, b: a - b,
                                             init_params, params)
  return delta_params, len(client_dataset)

client_sampler = fedjax.client_samplers.UniformGetClientSampler(
    train, num_clients=2, seed=1)
sampled_clients_with_data = client_sampler.sample()
for client_id, client_data, client_rng in sampled_clients_with_data:
  delta_params, num_samples = client_update(init_params,client_data, 
                                            client_rng, grad_fn)
  print(client_id, num_samples, delta_params.keys())
b'0eb0b8eb6ab0fcd6:f0174_06' 348 KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])
b'2b766db84ee57ba6:f2380_76' 152 KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])

Server aggregation

The outputs of the clients are typically aggregated by computing the weighted mean of the updates, where the weight is the number of client examples. This can easily be done using the fedjax.tree_util.tree_mean function.
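As a quick numeric sanity check (a sketch with toy pytrees), tree_mean weights each tree by its associated weight and divides by the total weight:

import jax.numpy as jnp

tree_a = {'w': jnp.array([1., 1.])}
tree_b = {'w': jnp.array([3., 5.])}
# Weighted mean with weights 1 and 3: (1*1 + 3*3)/4 = 2.5 and (1*1 + 5*3)/4 = 4.
print(fedjax.tree_util.tree_mean([(tree_a, 1.), (tree_b, 3.)]))
# Roughly {'w': [2.5, 4.]}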

sampled_clients_with_data = client_sampler.sample()
client_updates = []
for client_id, client_data, client_rng in sampled_clients_with_data:
  delta_params, num_samples = client_update(init_params, client_data,
                                            client_rng, grad_fn)
  client_updates.append((delta_params, num_samples))
updated_output = fedjax.tree_util.tree_mean(client_updates)
print(updated_output.keys())
KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])

Combining the above steps gives the FedAvg algorithm, which can be found in the example FedJAX implementation of FedAvg.
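For reference, below is a minimal sketch of one such round; the run_one_round, server_optimizer, and server_opt_state names are illustrative assumptions, and the full version lives in the FedJAX FedAvg example.

def run_one_round(server_params, server_opt_state, client_sampler, grad_fn):
  # Client selection.
  client_updates = []
  for client_id, client_data, client_rng in client_sampler.sample():
    # Client update.
    delta_params, num_samples = client_update(server_params, client_data,
                                              client_rng, grad_fn)
    client_updates.append((delta_params, num_samples))
  # Server aggregation: weighted mean of client deltas.
  mean_delta = fedjax.tree_util.tree_mean(client_updates)
  # Treat the mean delta as a "gradient" for the server optimizer
  # (server_optimizer is assumed to be a fedjax.optimizers.Optimizer
  # defined elsewhere).
  server_opt_state, server_params = server_optimizer.apply(
      mean_delta, server_opt_state, server_params)
  return server_params, server_opt_state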

Efficient implementation

The above implementation would be efficient enough for running on single machines. However, JAX provides primitives such as jax.pmap and jax.vmap for efficient parallelization across multiple accelerators. FedJAX provides support for them in federated learning by distributing client computation across several accelerators.

To take advantage of the faster implementation, we need to implement client_update in a specific format. It has three functions:

  1. client_init

  2. client_step

  3. client_final

client_init

This function takes the inputs from the server and outputs a client_step_state which will be passed in between client steps. It is desirable for the client_step_state to be a dictionary. In this example, it just copies the parameters, optimizer_state and the current state of client randomness.

We can think of the inputs from the server as “shared inputs” that are shared across all clients and the client_step_state as client-specific inputs that are separate per client.

def client_init(server_params, client_rng):
  opt_state = client_optimizer.init(server_params)
  client_step_state = {
      'params': server_params,
      'opt_state': opt_state,
      'rng': client_rng,
  }
  return client_step_state

client_step

client_step takes the current client_step_state and a batch of examples and updates the client_step_state. In this example, we run one step of SGD using the batch of examples and update client_step_state to reflect the new parameters, optimization state, and randomness.

def client_step(client_step_state, batch):
  rng, use_rng = jax.random.split(client_step_state['rng'])
  grads = grad_fn(client_step_state['params'], batch, use_rng)
  opt_state, params = client_optimizer.apply(grads,
                                             client_step_state['opt_state'],
                                             client_step_state['params'])
  next_client_step_state = {
      'params': params,
      'opt_state': opt_state,
      'rng': rng,
  }
  return next_client_step_state

client_final

client_final modifies the final client_step_state and returns the desired parameters. In this example, we compute the difference between the initial parameters and the final updated parameters in the client_final function.

def client_final(server_params, client_step_state):
  delta_params = jax.tree_util.tree_multimap(lambda a, b: a - b,
                                             server_params,
                                             client_step_state['params'])
  return delta_params

fedjax.for_each_client

Once we have these three functions, we can combine them into a client_update function using fedjax.for_each_client, which returns a function that runs client updates. Sample usage is shown below.

for_each_client_update = fedjax.for_each_client(client_init,
                                                client_step,
                                                client_final)

client_sampler = fedjax.client_samplers.UniformGetClientSampler(
    train, num_clients=2, seed=1)
sampled_clients_with_data = client_sampler.sample()
batched_clients_data = [
      (cid, cds.shuffle_repeat_batch(batch_size=10), crng)
      for cid, cds, crng in sampled_clients_with_data
  ]
for client_id, delta_params in for_each_client_update(init_params,
                                                      batched_clients_data):
  print(client_id, delta_params.keys())
b'0eb0b8eb6ab0fcd6:f0174_06' KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])
b'2b766db84ee57ba6:f2380_76' KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])

Note that for_each_client_update requires the client data to be already batched. This is necessary for performance gains while using multiple accelerators. Furthermore, the batch size needs to be the same across all clients.

By default, fedjax.for_each_client uses the standard JIT backend. To enable parallelism on TPUs, or for debugging, we can switch backends with fedjax.set_for_each_client_backend(backend), where backend is 'pmap' or 'debug', respectively.
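For example, to run client computation with jax.pmap across available devices (a one-line sketch):

fedjax.set_for_each_client_backend('pmap')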

The for_each_client function can also return additional per-step results, which can be used for debugging. This requires changing the client_step function.

def client_step_with_log(client_step_state, batch):
  rng, use_rng = jax.random.split(client_step_state['rng'])
  grads = grad_fn(client_step_state['params'], batch, use_rng)
  opt_state, params = client_optimizer.apply(grads,
                                             client_step_state['opt_state'],
                                             client_step_state['params'])
  next_client_step_state = {
      'params': params,
      'opt_state': opt_state,
      'rng': rng,
  }
  grad_norm = fedjax.tree_util.tree_l2_norm(grads)
  return next_client_step_state, grad_norm


for_each_client_update = fedjax.for_each_client(
    client_init, client_step_with_log, client_final, with_step_result=True)

for client_id, delta_params, grad_norms in for_each_client_update(
    init_params, batched_clients_data):

  print(client_id, list(delta_params.keys()))
  print(client_id, np.array(grad_norms))
b'0eb0b8eb6ab0fcd6:f0174_06' ['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1']
b'0eb0b8eb6ab0fcd6:f0174_06' [4.599997  3.9414525 4.8312078 4.450683  5.971922  2.4694602 2.6183944
 2.1996734 2.7145145 2.9750984 3.0633514 2.4050198 2.612233  2.672571
 3.1303792 3.236007  3.3968801 2.986587  2.5775976 2.8625555 3.1062818
 4.4250994 2.7431202 3.2192783 2.7670481 3.6075711 3.7296255 5.190155
 3.4366677 4.5394745 3.2277424 3.1362765 2.8626535 3.7905648 3.5686817]
b'2b766db84ee57ba6:f2380_76' ['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1']
b'2b766db84ee57ba6:f2380_76' [4.3958793 4.4705005 4.84529   4.4623365 5.3809257 5.4818144 3.0325508
 4.1576953 8.666909  2.067883  2.1935403 2.5095372 2.2202325 2.8493588
 3.2463503 3.6446893]

Recap

In this tutorial, we have covered the following:

  • Using existing algorithms in fedjax.algorithms.

  • Writing new algorithms using fedjax.FederatedAlgorithm.

  • Efficient implementation using fedjax.for_each_client in the presence of accelerators.

Contributing

Everyone can contribute to FedJAX, and we value everyone's contributions. There are several ways to contribute, described in the sections below.

The FedJAX project follows Google’s Open Source Community Guidelines.

Ways to contribute

We welcome pull requests, in particular for those issues marked with contributions welcome or good first issue.

For other proposals, we ask that you first open a GitHub Issue or Discussion to seek feedback on your planned contribution.

Contributing code using pull requests

We do all of our development using git, so basic knowledge is assumed.

Follow these steps to contribute code:

  1. Fork the FedJAX repository by clicking the Fork button on the repository page. This creates a copy of the FedJAX repository in your own account.

  2. Install a supported Python version, as listed in https://github.com/google/fedjax/blob/main/setup.py.

  3. pip install your fork from source. This allows you to modify the code and immediately test it out:

    git clone https://github.com/YOUR_USERNAME/fedjax
    cd fedjax
    pip install -r requirements-test.txt # Installs all testing requirements.
    pip install -e .  # Installs FedJAX from the current directory in editable mode.
    
  4. Add the FedJAX repo as an upstream remote, so you can use it to sync your changes.

    git remote add upstream http://www.github.com/google/fedjax
    
  5. Create a branch where you will develop from:

    git checkout -b name-of-change
    

    And implement your changes using your favorite editor. If you are adding a new algorithm, please add unit tests and an associated binary in the experiments folder. The binary should use the EMNIST dataset with a convolution neural model example and have reasonable default hyperparameters that ideally reproduce results from a published paper. We strongly recommend using fedjax.for_each_client in your algorithm implementations for computational efficiency.

  6. Make sure the tests pass by running the following command from the top of the repository:

pytest -n auto -q \
  -k "not SubsetFederatedDataTest and not SQLiteFederatedDataTest and not ForEachClientPmapTest and not DownloadsTest and not CheckpointTest and not LoggingTest" \
  fedjax --ignore=fedjax/legacy/

-q will reduce verbosity levels, -k selects/deselects specific tests, and --ignore=fedjax/legacy/ is used to skip the entire fedjax.legacy module. If there are errors or failures, you can run those specific tests using the commands in the next section to see more focused details.

pytest -n auto tests/

If you know the specific test file that covers your changes, you can limit the tests to that; for example:

# Run all tests in algorithms
pytest -n auto fedjax/algorithms

# Run only fedjax/core/metrics_test.py
pytest -n auto fedjax/core/metrics_test.py

  7. Once you are satisfied with your change, create a commit as follows (see how to write a commit message):

    git add file1.py file2.py ...
    git commit -m "Your commit message"
    

    Then sync your code with the main repo:

    git fetch upstream
    git rebase upstream/main
    

    Finally, push your commit on your development branch and create a remote branch in your fork that you can use to create a pull request from:

    git push --set-upstream origin name-of-change
    
  8. Create a pull request from the FedJAX repository and send it for review. Check the FedJAX pull request checklist for considerations when preparing your PR, and consult GitHub Help if you need more information on using pull requests.

FedJAX pull request checklist

As you prepare a FedJAX pull request, here are a few things to keep in mind:

Google contributor license agreement

Contributions to this project must be accompanied by a Google Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to https://cla.developers.google.com/ to see your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you’ve already submitted one (even if it was for a different project), you probably don’t need to do it again. If you’re not certain whether you’ve signed a CLA, you can open your PR and our friendly CI bot will check for you.

Single-change commits and pull requests

A git commit ought to be a self-contained, single change with a descriptive message. This helps with review and with identifying or reverting changes if issues are uncovered later on.

Pull requests typically comprise a single git commit. (In some cases, for instance for large refactors or internal rewrites, they may contain several.) In preparing a pull request for review, you may need to squash together multiple commits. We ask that you do this prior to sending the PR for review if possible. The git rebase -i command might be useful to this end.

Linting and Type-checking

Please follow the Google Python style guide and check code quality with pylint, as described at https://google.github.io/styleguide/pyguide.html.

Full GitHub test suite

Your PR will automatically be run through a full test suite on GitHub CI, which covers a range of Python versions, dependency versions, and configuration options. It’s normal for these tests to turn up failures that you didn’t catch locally; to fix the issues you can push new commits to your branch.

Restricted test suite

Once your PR has been reviewed, a FedJAX maintainer will mark it as Pull Ready. This will trigger a larger set of tests, including tests on GPU and TPU backends that are not available via standard GitHub CI. Detailed results of these tests are not publicly viewable, but the FedJAX maintainer assigned to your PR will communicate with you regarding any failures these might uncover; it's not uncommon, for example, that numerical tests need different tolerances on TPU than on CPU.

Building from source

First, clone the FedJAX source code:

git clone https://github.com/google/fedjax
cd fedjax

Then install the fedjax Python package:

pip install -e .

To upgrade to the latest version from GitHub, inside of the repository root, just run

git pull

You shouldn’t have to reinstall fedjax because pip install -e sets up symbolic links from site-packages into the repository.

Running the tests

We created a simple run_tests.sh script for running tests. See its comments for examples. Before creating a pull request, we recommend running all the FedJAX tests (i.e. running run_tests.sh with no arguments) to verify the correctness of a change.

Updating the docs

Install the requirements

pip install -r docs/requirements.txt

Then run

sphinx-autobuild -b html --watch . --open-browser docs docs/_build/html

sphinx-autobuild will watch for file changes and auto build the HTML for you, so all you’ll have to do is refresh the page. If you don’t want to use the auto builder, you can just use:

sphinx-build -b html docs docs/_build/html

and then navigate to docs/_build/html/index.html in your browser.

How to write code documentation

Our documentation is written in reStructuredText for Sphinx. This is a markup language that is compiled into online documentation. For more details, see Sphinx's documentation.

We also rely heavily on sphinx.ext.autodoc to convert docstrings in source to rst for Sphinx, so it would be best to be familiar with its directives.

As a result, our docstrings adhere to a specific syntax that has to be kept in mind. Below we provide some guidelines.

How to use “code font”

When writing code font in a docstring, please use double backticks.

# This returns a ``str`` object.

How to write math

We’re using sphinx.ext.mathjax. Then we use the math directive and role to either inline or block notation.

# Blocked notation
# .. math::
#		x + y

# Inline notation :math:`x + y`

Examples

Take a look at docs/fedjax.metrics.rst to get a good idea of what to do.

Update notebooks

It is easiest to edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface, open http://colab.research.google.com and Upload ipynb from your local repo. Update it as needed, Run all cells then Download ipynb to your local repo. You may want to test that it executes properly, using sphinx-build as explained above. We recommend making changes this way to avoid introducing format errors into the .ipynb files.

In the future, we may build and re-execute the notebooks as part of the Read the docs build. However, for now, we exclude all notebooks from the build due to long durations (downloading dataset files, expensive model training, etc.). See exclude_patterns in conf.py

fedjax core

FedJAX API.

Subpackages

fedjax.metrics

A small library for working with evaluation metrics such as accuracy.

Stats

fedjax.metrics.Stat

Stat keeps some statistic, along with operations over them.

fedjax.metrics.MeanStat

Statistic for weighted mean calculation.

fedjax.metrics.SumStat

Statistic for summing values.

Metrics

fedjax.metrics.Metric

Metric is the conceptual metric (like accuracy).

fedjax.metrics.CrossEntropyLoss

Metric for cross entropy loss.

fedjax.metrics.Accuracy

Metric for accuracy.

fedjax.metrics.TopKAccuracy

Metric for top k accuracy.

fedjax.metrics.SequenceTokenCrossEntropyLoss

Metric for token cross entropy loss for a sequence example.

fedjax.metrics.SequenceCrossEntropyLoss

Metric for total cross entropy loss for a sequence example.

fedjax.metrics.SequenceTokenAccuracy

Metric for token accuracy for a sequence example.

fedjax.metrics.SequenceTokenTopKAccuracy

Metric for token top k accuracy for a sequence example.

fedjax.metrics.SequenceTokenCount

Metric for count of non masked tokens for a sequence example.

fedjax.metrics.SequenceCount

Metric for count of non masked sequences.

fedjax.metrics.SequenceTruncationRate

Metric for truncation rate for a sequence example.

fedjax.metrics.SequenceTokenOOVRate

Metric for out-of-vocabulary (OOV) rate for a sequence example.

fedjax.metrics.SequenceLength

Metric for length for a sequence example.

fedjax.metrics.PerDomainMetric

Turns a base metric into one that groups results by domain.

fedjax.metrics.ConfusionMatrix

Metric for making a Confusion Matrix.

Miscellaneous

fedjax.metrics.unreduced_cross_entropy_loss

Returns unreduced cross entropy loss.

fedjax.metrics.evaluate_batch

Evaluates a batch using a metric.

Quick Overview

To evaluate model predictions, use a Metric object such as Accuracy. In most scenarios we recommend fedjax.core.models.evaluate_model(), which runs model prediction and evaluation on batches of N examples at a time for greater computational efficiency:

# Mock out Model.
model = fedjax.Model(
    init=lambda _: None,  # Unused.
    apply_for_train=lambda *_: None,  # Unused.
    apply_for_eval=lambda _, batch: batch.get('pred'),
    train_loss=lambda *_: None,  # Unused.
    eval_metrics={'accuracy': fedjax.metrics.Accuracy()})
params = None  # Unused.
batches = [{'y': np.array([1, 0]),
            'pred': np.array([[1.2, 0.4], [2.3, 0.1]])},
           {'y': np.array([1, 1]),
            'pred': np.array([[0.3, 3.2], [2.1, 4.3]])}]
results = fedjax.evaluate_model(model, params, batches)
print(results)
# {'accuracy': 0.75}

A Metric object has 2 methods:

  • zero() : Initial value for accumulating the statistic for this metric.

  • evaluate_example() : Returns the statistic from evaluating a single example, given the training example and the model prediction.

For convenience, most Metric subclasses follow this convention:

  • example is a dict-like object from str to jnp.ndarray.

  • prediction is either a single jnp.ndarray, or a dict-like object from str to jnp.ndarray.

Conceptually, we can also use a simple for loop to evaluate a collection of examples and model predictions:

# By default, the `Accuracy` metric treats `example['y']` as the true label,
# and `prediction` as a single `jnp.ndarray` of class scores.
metric = Accuracy()
stat = metric.zero()
# We are iterating over individual examples, not batches.
for example, prediction in [({'y': jnp.array(1)}, jnp.array([0., 1.])),
                            ({'y': jnp.array(0)}, jnp.array([1., 0.])),
                            ({'y': jnp.array(1)}, jnp.array([1., 0.])),
                            ({'y': jnp.array(0)}, jnp.array([2., 0.]))]:
  stat = stat.merge(metric.evaluate_example(example, prediction))
print(stat.result())
# 0.75

In practice, for greater computational efficiency, we run model prediction not on a single example but on a batch of N examples at a time. fedjax.core.models.evaluate_model() provides a simple way to do so. Under the hood, it calls evaluate_batch():

metric = Accuracy()
stat = metric.zero()
# We are iterating over batches.
for batch_example, batch_prediction in [
  ({'y': jnp.array([1, 0])}, jnp.array([[0., 1.], [1., 0.]])),
  ({'y': jnp.array([1, 0])}, jnp.array([[1., 0.], [2., 0.]]))]:
  stat = stat.merge(evaluate_batch(metric, batch_example, batch_prediction))
print(stat.result())
# 0.75

Under the hood

For most users, it is sufficient to know how to use existing Metric subclasses such as Accuracy with fedjax.core.models.evaluate_model() . This section is intended for those who would like to write new metrics.

From algebraic structures to Metric and Stat

There are 2 abstraction in this library, Metric and Stat . Before going into details of these classes, let’s first consider a few abstract properties related to evaluation metrics, using accuracy as an example.

When evaluating accuracy on a dataset, we wish to know the proportion of examples that are correctly predicted by the model. Because a dataset might be too large to fit into memory, we need to divide the work by partitioning the dataset into smaller subsets, evaluate each separately, and finally somehow combine the results. Assuming the subsets can be of different sizes, although the accuracy value is a single number, we cannot just average the accuracy values from each partition to arrive at the overall accuracy. Instead, we need 2 numbers from each subset:

  • The number of examples in this subset,

  • How many of them are correctly predicted.

We call these two numbers from each subset a statistic. The domain (the set of possible values) of the statistic in the case of accuracy is

\[\{(0, 0)\} \cup \{(a, b) \mid a \ge 0, b > 0\}\]

With the numbers of examples and correct predictions from 2 disjoint subsets, we add the numbers up to get the number of examples and correct predictions for the union of the 2 subsets. We call this operation from 2 statistics into 1 a \(merge\) operation.

Let \(f(S)\) be the function that gives us the statistic from a subset of examples. It is easy to see that for two disjoint subsets \(A\) and \(B\), \(merge(f(A), f(B))\) should be equal to \(f(A \cup B)\). If no such \(merge\) exists, we cannot evaluate the dataset by partitioning the work. This requirement alone implies that the domain of a statistic, together with the \(merge\) operation, forms a specific algebraic structure (a commutative monoid).

  • \(I := f(empty set)\) is one and the only identity element w.r.t.
    \(merge\) (i.e. \(merge(I, x) == merge(x, I) == x\) ).
  • \(merge()\) is commutative and associative.

Further, we can see \(f(S)\) can be defined just knowing two types of values:

  • \(f(empty set)\), i.e. \(I\) ;

  • \(f({x})\) for any single example \(x\) .

For any other subset \(S\), we can derive the value of \(f(S)\) using these values and \(merge\). Metric is simply the \(f(S)\) function above, defined in 2 corresponding parts: zero() gives \(f(empty set)\), i.e. \(I\), and evaluate_example() gives \(f(\{x\})\) for a single example \(x\).

On the other hand, Stat stores a single statistic, a merge() method for combining two, and a result() method for producing the final metric value.

To implement Accuracy as a subclass of Metric , we first need to know what Stat to use. In this case, the statistic domain and merge is implemented by a MeanStat . A MeanStat holds two values:

  • accum is the weighted sum of values, i.e. the number of correct
    predictions in the case of accuracy.
  • weight is the sum of weights, i.e. the number of examples in
    the case of accuracy.

merge() adds up the respective accum and weight from two MeanStat objects.

Sometimes, a new Stat subclass is necessary. In that case, it is very important to make sure the implementation has a clear definition of the domain, and that the merge() operation adheres to the properties regarding identity element, commutativity, and associativity (e.g. if we unknowingly allow pairs \((x, 0)\) for \(x \neq 0\) into the domain of a MeanStat, \(merge((x, 0), (a, b))\) will produce a statistic that leads to an incorrect final metric value, i.e. \((a+x)/b\) instead of \(a/b\)).

Batching Stat s

In most cases, the final value of an evaluation is simply a scalar, and the corresponding statistic is also a tuple of a few scalar values. However, for the same reason jax.vmap() is a lot more efficient than a for loop, it is a lot more efficient to store multiple Stat values as a Stat of arrays instead of a list of Stat objects. Thus instead of a list [MeanStat(1, 2), MeanStat(3, 4), MeanStat(5, 6)] of “rank 0” Stats, we can batch the 3 statistics as a single MeanStat(jnp.array([1, 3, 5]), jnp.array([2, 4, 6])). A Stat object holding a single statistic is a “rank 0” Stat. A Stat object holding a vector of statistics is a “rank 1” Stat. Similarly, a Stat object may also hold a matrix, a 3D array, etc., of statistics. These are higher rank Stats. In most cases, the rank 1 implementation of merge() and result() automatically generalizes to higher ranks as elementwise operations.
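For instance (a small sketch), a rank 1 MeanStat batches three statistics as arrays, and result() generalizes elementwise, giving one mean per batched statistic:

batched_stat = MeanStat.new(jnp.array([1., 3., 5.]), jnp.array([2., 4., 6.]))
print(batched_stat.result())
# Roughly [0.5, 0.75, 0.83]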

In the end, we want just 1 final metric value instead of a length 3 vector, or a 2x2 matrix, of metric values. To go back to a single statistic, we need to merge() the statistics stored in these arrays. Each Stat subclass provides a reduce() method to do just that. The combination of jax.vmap() over Metric.evaluate_example() and Stat.reduce() is how we get an efficient evaluate_batch() function (of course, the real evaluate_batch() is jax.jit()'d so that the same jax.vmap() transformation etc. does not need to happen over and over):

metric = Accuracy()
stat = metric.zero()
# We are iterating over batches.
for batch_example, batch_prediction in [
  ({'y': jnp.array([1, 0])}, jnp.array([[0., 1.], [1., 0.]])),
  ({'y': jnp.array([1, 0])}, jnp.array([[1., 0.], [2., 0.]]))]:
  # Get a batch of statistics as a single Stat object.
  batch_stat = jax.vmap(metric.evaluate_example)(batch_example,
                                                 batch_prediction)
  # Merge the reduced single statistic onto the accumulator.
  stat = stat.merge(batch_stat.reduce())
print(stat.result())
# 0.75

Being able to batch Stat s also allows us to do other interesting things, for example,

  • evaluate_batch() accepts an optional per-example mask so it can work
    on padded batches.
  • We can define a PerDomainMetric metric for any base metric so that we can get
    accuracy where examples are partitioned by a domain id.

Creating a new Metric

Most likely, a new metric will just return a MeanStat or a SumStat. If that’s the case, simply follow the guidelines in Metric’s class docstring.

If a new Stat is necessary, follow the guidelines in Stat’s docstring.
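As an illustration only (not part of the library), a minimal custom metric might look like the following sketch, which reports mean absolute error using MeanStat and the fedjax.dataclass decorator mentioned in the Stat docstring; the MeanAbsoluteError name is hypothetical.

import jax.numpy as jnp
from fedjax import dataclass, metrics

@dataclass
class MeanAbsoluteError(metrics.Metric):
  """Mean absolute error, treating example['y'] as the target (a sketch)."""
  target_key: str = 'y'

  def zero(self):
    # Identity statistic: merging with it leaves any MeanStat unchanged.
    return metrics.MeanStat.new(0., 0.)

  def evaluate_example(self, example, prediction):
    # One example contributes its absolute error with weight 1.
    error = jnp.abs(example[self.target_key] - prediction)
    return metrics.MeanStat.new(error, 1.)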


class fedjax.metrics.Stat

Stat keeps some statistic, along with operations over them.

Most users will only need to interact with a Stat object via result()

For those who need to create new metrics, please first read the Under the hood section of the module docstring.

Most Stats' domains (the sets of possible statistic values) have constraints, so it is usually good practice to offer and use factory methods to construct new Stat objects instead of directly assigning the fields.

To work with various jax constructs, a concrete Stat should be a PyTree. This is easily achieved with fedjax.dataclass.

A Stat may hold either a single statistic (a rank 0 Stat), or an array of statistics (a higher rank Stat). result() and merge() only need to work on a rank 0 Stat. reduce() only needs to work on a higher rank Stat.

abstract merge(other)

Merges two Stat objects into a new Stat with merged statistics.

Parameters:

other (Stat) – Another Stat object of the same type.

Return type:

Stat

Returns:

A new Stat object of the same type with merged statistics.

abstract reduce(axis=0)

Reduces a higher rank statistic along a given axis.

See the class docstring for details.

Parameters:

axis (Optional[int]) – An integer axis index, or None.

Return type:

Stat

Returns:

A new Stat object of the same type.

abstract result()

Calculates the metric value from the statistic value.

For example, MeanStat.result() calculates a weighted average.

Return type:

Array

Returns:

The return value must be a jnp.ndarray.

class fedjax.metrics.MeanStat(accum, weight)

Bases: Stat

Statistic for weighted mean calculation.

Prefer using the MeanStat.new() factory method instead of directly assigning to fields.

Example:

stat_0 = MeanStat.new(accum=1, weight=2)
stat_1 = MeanStat.new(accum=2, weight=3)
merged_stat = stat_0.merge(stat_1)
print(merged_stat)
# MeanStat(accum=3, weight=5) => 0.6

stat = MeanStat.new(jnp.array([1, 2, 4]), jnp.array([1, 1, 0]))
reduced_stat = stat.reduce()
print(reduced_stat)
# MeanStat(accum=3, weight=2) => 1.5
accum

The weighted sum.

Type:

jax.Array

weight

The sum of weights.

Type:

jax.Array

classmethod new(accum, weight)

Creates a sanitized MeanStat.

The domain of a weighted mean statistic is:

\[\{(0, 0)\} \cup \{(a, b) \mid a \ge 0, b > 0\}\]

new() sanitizes values outside the domain into the identity (zeros).

Parameters:
  • accum – A value convertible to jnp.ndarray.

  • weight – A value convertible to jnp.ndarray.

Return type:

MeanStat

Returns:

The sanitized MeanStat.

class fedjax.metrics.SumStat(accum)

Bases: Stat

Statistic for summing values.

Example:

stat_0 = SumStat.new(accum=1)
stat_1 = SumStat.new(accum=2)
merged_stat = stat_0.merge(stat_1)
print(merged_stat)
# SumStat(accum=3) => 3

stat = SumStat.new(jnp.array([1, 2, 1]))
reduced_stat = stat.reduce()
print(reduced_stat)
# SumStat(accum=4) => 4
accum

Sum of values.

Type:

jax.Array

classmethod new(accum)

Creates a sanitized SumStat.

Return type:

SumStat

class fedjax.metrics.Metric

Metric is the conceptual metric (like accuracy).

It defines two methods:

Given a Metric object m, let

  • u = m.zero()

  • v = m.evaluate_example(...)

We require that

  • type(u) == type(v).

  • u.merge(v) == v.merge(u) == v.

  • Components of u have the same shape as their counterparts in v.
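As a quick illustration of this contract (a sketch using the Accuracy metric documented below), merging with zero() leaves a statistic unchanged:

metric = Accuracy()
u = metric.zero()
v = metric.evaluate_example({'y': jnp.array(1)}, jnp.array([0., 1.]))
print(u.merge(v))
# MeanStat(accum=1, weight=1) => 1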

abstract evaluate_example(example, prediction)

Evaluates a single example.

e.g. for accuracy: MeanStat.new(num_correct, num_total)

Parameters:
  • example (Mapping[str, Array]) – A single input example (e.g. one sentence for language).

  • prediction (Union[Array, Mapping[str, Array]]) – Output for example from fedjax.core.models.Model.apply_for_eval().

Return type:

Stat

Returns:

Stat value.

abstract zero()

Returns a Stat such that merging with it is an identity operation.

e.g. for accuracy: MeanStat.new(0., 0.)

Return type:

Stat

Returns:

Stat identity value.

class fedjax.metrics.CrossEntropyLoss(target_key='y', pred_key=None)

Bases: Metric

Metric for cross entropy loss.

Example:

example = {'y': jnp.array(1)}
prediction = jnp.array([1.2, 0.4])
metric = CrossEntropyLoss()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1.1711007, weight=1) => 1.1711007
target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

class fedjax.metrics.Accuracy(target_key='y', pred_key=None)

Bases: Metric

Metric for accuracy.

Example:

example = {'y': jnp.array(2)}
prediction = jnp.array([0, 0, 1])
metric = Accuracy()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1, weight=1) => 1
target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

class fedjax.metrics.TopKAccuracy(k, target_key='y', pred_key=None)

Bases: Metric

Metric for top k accuracy.

This metric computes the number of times where the correct class is among the top k classes predicted.

Example: top 3 accuracy

  • Dog => [Dog, Cat, Bird, Mouse, Penguin] ✓

  • Cat => [Bird, Mouse, Cat, Penguin, Dog] ✓

  • Dog => [Dog, Cat, Bird, Penguin, Mouse] ✓

  • Bird => [Bird, Cat, Mouse, Penguin, Dog] ✓

  • Cat => [Cat, Bird, Mouse, Dog, Penguin] ✓

  • Cat => [Cat, Mouse, Dog, Penguin, Bird] ✓

  • Mouse => [Penguin, Cat, Dog, Mouse, Bird] x

  • Penguin => [Dog, Mouse, Cat, Penguin, Bird] x

6 correct predictions in top 3 predicted classes / 8 total examples = .75 top 3 accuracy

Top k accuracy, also known as top n accuracy, is a useful metric when it comes to recommendations. One example would be the word recommendations on a virtual keyboard where three suggested words are displayed.

For k=1, we strongly recommend using Accuracy to avoid an unnecessary argsort. k < 1 will return 0. and k >= num_classes will return 1.

If two or more classes have the same prediction, the classes will be considered in order of lowest to highest indices.

Example:

example = {'y': jnp.array(2)}
prediction = jnp.array([0, 0.5, 0.2])
metric = TopKAccuracy(k=2)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1, weight=1) => 1
k

Number of top elements to look at for computing accuracy.

Type:

int

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

class fedjax.metrics.SequenceTokenCrossEntropyLoss(target_key='y', pred_key=None, masked_target_values=(0,), per_position=False)

Bases: Metric

Metric for token cross entropy loss for a sequence example.

Example:

example = {'y': jnp.array([1, 0, 1])}
prediction = jnp.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
metric = SequenceTokenCrossEntropyLoss()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1.2246635, weight=2) => 0.61233175

per_position_metric = SequenceTokenCrossEntropyLoss(per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[1.1711007, 0., 0.05356275], weight=[1., 0., 1.]) => [1.1711007, 0., 0.05356275]
target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

per_position

Whether to keep output statistic per position or sum across positions for the entire sequence.

Type:

bool

class fedjax.metrics.SequenceCrossEntropyLoss(target_key='y', pred_key=None, masked_target_values=(0,))

Bases: Metric

Metric for total cross entropy loss for a sequence example.

Example:

example = {'y': jnp.array([1, 0, 1])}
prediction = jnp.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
metric = SequenceCrossEntropyLoss()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1.2246635, weight=1) => 1.2246635
target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.SequenceTokenAccuracy(target_key='y', pred_key=None, masked_target_values=(0,), logits_mask=None, per_position=False)

Bases: Metric

Metric for token accuracy for a sequence example.

Example:

example = {'y': jnp.array([1, 2, 2, 1, 3, 0])}
# prediction = [1, 0, 2, 1, 3, 0].
prediction = jnp.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0],
                        [0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]])
logits_mask = (0., 0., 0., jnp.NINF)
metric = SequenceTokenAccuracy(logits_mask=logits_mask)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=3, weight=5) => 0.6

per_position_metric = SequenceTokenAccuracy(logits_mask=logits_mask, per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[1., 0., 1., 1., 0., 0.], weight=[1., 1., 1., 1., 1., 0.]) => [1., 0., 1., 1., 0., 0.]
target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

logits_mask

Mask of shape [num_classes] to be applied for preds. This is typically used to discount predictions for out-of-vocabulary tokens.

Type:

Optional[Tuple[float, …]]

per_position

Whether to keep output statistic per position or sum across positions for the entire sequence.

Type:

bool

class fedjax.metrics.SequenceTokenTopKAccuracy(k, target_key='y', pred_key=None, masked_target_values=(0,), logits_mask=None, per_position=False)

Bases: Metric

Metric for token top k accuracy for a sequence example.

For more information on the top k accuracy metric, refer to the TopKAccuracy docstring.

Example:

example = {'y': jnp.array([1, 2, 2, 1, 3, 0])}
prediction = jnp.array([[0, 1, 0.5, 0], [1, 0.5, 0, 0], [0.8, 0, 0.7, 0],
                        [0.5, 1, 0, 0], [0, 0.5, 0, 1], [0.5, 0, 0.9, 0]])
logits_mask = (0., 0., 0., jnp.NINF)
metric = SequenceTokenTopKAccuracy(k=2, logits_mask=logits_mask)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=3, weight=5) => 0.6

per_position_metric = SequenceTokenTopKAccuracy(k=2, logits_mask=logits_mask, per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[1., 0., 1., 1., 0., 0.], weight=[1., 1., 1., 1., 1., 0.]) => [1., 0., 1., 1., 0., 0.]
k

Number of top elements to look at for computing accuracy.

Type:

int

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

logits_mask

Mask of shape [num_classes] to be applied for preds. This is typically used to discount predictions for out-of-vocabulary tokens.

Type:

Optional[Tuple[float, …]]

per_position

Whether to keep output statistic per position or sum across positions for the entire sequence.

Type:

bool

class fedjax.metrics.SequenceTokenCount(target_key='y', masked_target_values=(0,))

Bases: Metric

Metric for count of non masked tokens for a sequence example.

Example:

example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceTokenCount(masked_target_values=(0, 2))
print(metric.evaluate_example(example, prediction))
# SumStat(accum=3) => 3
target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.SequenceCount(target_key='y', masked_target_values=(0,))

Bases: Metric

Metric for count of non masked sequences.

Example:

example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])}
empty_example = {'y': jnp.array([0, 0, 0, 0, 0, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = metrics.SequenceCount(masked_target_values=(0, 2))
print(metric.evaluate_example(example, prediction))
# SumStat(accum=1)
print(metric.evaluate_example(empty_example, prediction))
# SumStat(accum=0)
target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.SequenceTruncationRate(eos_target_value, target_key='y', masked_target_values=(0,))

Bases: Metric

Metric for truncation rate for a sequence example.

Example:

example = {'y': jnp.array([1, 2, 2, 3, 3, 3, 4])}
truncated_example = {'y': jnp.array([1, 2, 2, 3, 3, 3, 3])}
prediction = jnp.array([])  # Unused.
metric = SequenceTruncationRate(eos_target_value=4)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=0, weight=1) => 0
print(metric.evaluate_example(truncated_example, prediction))
# MeanStat(accum=1, weight=1) => 1
eos_target_value

Target value denoting end of sequence. Truncated sequences will not have this value.

Type:

int

target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.SequenceTokenOOVRate(oov_target_values, target_key='y', masked_target_values=(0,), per_position=False)

Bases: Metric

Metric for out-of-vocabulary (OOV) rate for a sequence example.

Example:

example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceTokenOOVRate(oov_target_values=(2,))
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=2, weight=5) => 0.4

per_position_metric = SequenceTokenOOVRate(oov_target_values=(2,), per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[0., 1., 1., 0., 0., 0., 0.], weight=[1., 1., 1., 1., 1., 0., 0.]) => [0. 1. 1. 0. 0. 0. 0.]
oov_target_values

Target values denoting out-of-vocabulary values.

Type:

Tuple[int, …]

target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

per_position

Whether to keep output statistic per position or sum across positions for the entire sequence.

Type:

bool

class fedjax.metrics.SequenceLength(target_key='y', masked_target_values=(0,))

Bases: Metric

Metric for length for a sequence example.

Example:

example = {'y': jnp.array([1, 2, 3, 4, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceLength()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=4, weight=1) => 4
target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.PerDomainMetric(base, num_domains, domain_id_key='domain_id')

Bases: Metric

Turns a base metric into one that groups results by domain.

This is useful in algorithms such as AgnosticFedAvg.

example is expected to contain a feature named domain_id_key, which stores the integer domain id in [0, num_domains). PerDomainMetric accumulates the base metric's Stat within each domain. If the base Metric returns a Stat whose result is of shape X, then the Stat returned by PerDomainMetric will produce a result of shape (num_domains,) + X. See Batching Stats for the higher rank Stat mechanism enabling this.

Example:

per_domain_accuracy = PerDomainMetric(metrics.Accuracy(), num_domains=3)
batch_example = {
    'domain_id': jnp.array([0, 0, 1, 2]),
    'y': jnp.array([0, 1, 0, 1])
}
batch_prediction = jnp.array([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])
print(
    evaluate_batch(per_domain_accuracy, batch_example,
                   batch_prediction).result())
# [0.5 0.  1. ]
class fedjax.metrics.ConfusionMatrix(num_classes, target_key='y', pred_key=None)

Bases: Metric

Metric for making a Confusion Matrix.

A confusion matrix is an n x n matrix often used to describe the performance of a classification model on a set of test data for which the true values are known. The model's predictions are represented through the columns, and the known data values through the rows. This allows one to view in which areas the model is doing well, as well as where there is room for improvement. For each row in the confusion matrix, if there are a lot of counts outside of the main diagonal, the model is not doing well with respect to the class corresponding to that row.

Theoretical Example:

            Predicted P     Predicted N

Actual P       TP               FN

Actual N       FP               TN

Note: this is for a binary classification model, but the same concept applies to any model with n outputs. Notice that the TPs and TNs will always lie on the main diagonal of the matrix.

Example:

example = {'y': jnp.array(2)}
prediction = jnp.array([0., 1., 0.])
metric = ConfusionMatrix(num_classes=3)
print(metric.evaluate_example(example, prediction))
# SumStat(accum=DeviceArray([[0., 0., 0.],
#                            [0., 0., 0.],
#                            [0., 1., 0.]], dtype=float32)) => [[0. 0. 0.]
#                                                               [0. 0. 0.]
#                                                               [0. 1. 0.]]
target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

num_classes

Number of output classes of the model. Used to generate a matrix of shape [num_classes, num_classes].

Type:

int

fedjax.metrics.unreduced_cross_entropy_loss(targets, preds, is_sparse_targets=True)

Returns unreduced cross entropy loss.

Return type:

Array

fedjax.metrics.evaluate_batch(metric, batch_example, batch_prediction, batch_mask=None)

Evaluates a batch using a metric.

Return type:

Stat

fedjax.optimizers

Lightweight library for working with optimizers.

fedjax.optimizers.Optimizer

Wraps different optimizer libraries in a common interface.

fedjax.optimizers.create_optimizer_from_optax

Creates optimizer from optax gradient transformation chain.

fedjax.optimizers.ignore_grads_haiku

Modifies optimizer to ignore gradients for non_trainable_names.

fedjax.optimizers.adagrad

The Adagrad optimizer.

fedjax.optimizers.adam

The classic Adam optimizer.

fedjax.optimizers.rmsprop

A flexible RMSProp optimizer.

fedjax.optimizers.sgd

A canonical Stochastic Gradient Descent optimizer.


class fedjax.optimizers.Optimizer(init, apply)

Wraps different optimizer libraries in a common interface.

Works with optax.

The expected usage of Optimizer is as follows:

# One step of SGD.
params = {'w': jnp.array([1, 1, 1])}
grads = {'w': jnp.array([2, 3, 4])}
optimizer = fedjax.optimizers.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
opt_state, params = optimizer.apply(grads, opt_state, params)
print(params)
# {'w': DeviceArray([0.8, 0.7, 0.6], dtype=float32)}
init

Initializes (possibly empty) PyTree of statistics (optimizer state) given the input model parameters.

Type:

Callable[[Any], Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]

apply

Transforms and applies the input gradients to update the optimizer state and model parameters.

Type:

Callable[[Any, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], Any], Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], Any]]

fedjax.optimizers.create_optimizer_from_optax(opt)

Creates optimizer from optax gradient transformation chain.

Return type:

Optimizer
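
For example (a sketch assuming optax is installed), a gradient-clipped Adam chain can be wrapped as follows:

import optax

optimizer = fedjax.optimizers.create_optimizer_from_optax(
    optax.chain(optax.clip_by_global_norm(1.0),
                optax.adam(learning_rate=1e-3)))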

fedjax.optimizers.ignore_grads_haiku(optimizer, non_trainable_names)

Modifies optimizer to ignore gradients for non_trainable_names.

Non-trainable parameters will have their values set to None when passed as input into the Optimizer to prevent any updates.

NOTE: This will only work with models implemented in haiku.

Parameters:
  • optimizer (Optimizer) – Base Optimizer.

  • non_trainable_names (List[Tuple[str, str]]) – List of tuples of haiku module names and names of given entries in the module data bundle (e.g. parameter name). This list of names will be used to select the non-trainable parameters.

Return type:

Optimizer

Returns:

Optimizer that will ignore gradients for the non-trainable parameters.


fedjax.optimizers.adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-06)

The Adagrad optimizer.

Adagrad is an algorithm for gradient based optimisation that anneals the learning rate for each parameter during the course of training.

WARNING: Adagrad’s main limit is the monotonic accumulation of squared gradients in the denominator: since all terms are >0, the sum keeps growing during training and the learning rate eventually becomes vanishingly small.

References

[Duchi et al, 2011](https://jmlr.org/papers/v12/duchi11a.html)

Parameters:
  • learning_rate (Union[float, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – This is a fixed global scaling factor.

  • initial_accumulator_value (float) – Initialisation for the accumulator.

  • eps (float) – A small constant applied to denominator inside of the square root (as in RMSProp) to avoid dividing by zero when rescaling.

Return type:

Optimizer

Returns:

The corresponding Optimizer.

fedjax.optimizers.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0)

The classic Adam optimizer.

Adam is an SGD variant with learning rate adaptation. The learning_rate used for each weight is computed from estimates of first- and second-order moments of the gradients (using suitable exponential moving averages).

References

[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

Parameters:
  • learning_rate (Union[float, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – This is a fixed global scaling factor.

  • b1 (float) – The exponential decay rate to track the first moment of past gradients.

  • b2 (float) – The exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

Return type:

Optimizer

Returns:

The corresponding Optimizer.

fedjax.optimizers.rmsprop(learning_rate, decay=0.9, eps=1e-08, initial_scale=0.0, centered=False, momentum=None, nesterov=False)

A flexible RMSProp optimizer.

RMSProp is an SGD variant with learning rate adaptation. The learning_rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy to configure RMSProp optimizer that can be used to switch between several of these variants.

References

[Tieleman and Hinton, 2012](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) [Graves, 2013](https://arxiv.org/abs/1308.0850)

Parameters:
  • learning_rate (Union[float, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – This is a fixed global scaling factor.

  • decay (float) – The decay used to track the magnitude of previous gradients.

  • eps (float) – A small numerical constant to avoid dividing by zero when rescaling.

  • initial_scale (float) – Initialisation of accumulators tracking the magnitude of previous updates. PyTorch uses 0, TF1 uses 1. When reproducing results from a paper, verify the value used by the authors.

  • centered (bool) – Whether the second moment or the variance of the past gradients is used to rescale the latest gradients.

  • momentum (Optional[float]) – The decay rate used by the momentum term, when it is set to None, then momentum is not used at all.

  • nesterov (bool) – Whether nesterov momentum is used.

Return type:

Optimizer

Returns:

The corresponding Optimizer.

fedjax.optimizers.sgd(learning_rate, momentum=None, nesterov=False)

A canonical Stochastic Gradient Descent optimizer.

This implements stochastic gradient descent. It also includes support for momentum, and nesterov acceleration, as these are standard practice when using stochastic gradient descent to train deep neural networks.

References

[Sutskever et al, 2013](http://proceedings.mlr.press/v28/sutskever13.pdf)

Parameters:
  • learning_rate (Union[float, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – This is a fixed global scaling factor.

  • momentum (Optional[float]) – The decay rate used by the momentum term, when it is set to None, then momentum is not used at all.

  • nesterov (bool) – Whether nesterov momentum is used.

Return type:

Optimizer

Returns:

The corresponding Optimizer.

fedjax.tree_util

Utilities for working with tree-like container data structures.

In JAX, the term pytree refers to a tree-like structure built out of container-like Python objects. For more details, see https://jax.readthedocs.io/en/latest/pytrees.html.
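
For example (a small sketch), model parameters stored as nested dictionaries of arrays form a pytree:

import jax.numpy as jnp

params = {'linear': {'w': jnp.zeros((2, 3)), 'b': jnp.zeros(3)}}
print(fedjax.tree_util.tree_size(params))
# 9 (6 weights + 3 biases)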

fedjax.tree_util.tree_mean(pytrees_and_weights)

Returns (weighted) mean of input trees and weights.

Parameters:

pytrees_and_weights (Iterable[Tuple[Any, float]]) – Iterable of tuples of pytrees and associated weights.

Return type:

Any

Returns:

(Weighted) mean of input trees and weights.

fedjax.tree_util.tree_weight(pytree, weight)

Weights tree leaves by weight.

Return type:

Any

fedjax.tree_util.tree_sum(pytrees)

Sums multiple trees together.

Return type:

Any

fedjax.tree_util.tree_add(left, right)

Adds two trees together.

Return type:

Any

fedjax.tree_util.tree_zeros_like(pytree)

Creates a tree with zeros with same structure as the input.

Return type:

Any

fedjax.tree_util.tree_inverse_weight(pytree, weight)

Weights tree leaves by 1 / weight.

Return type:

Any

fedjax.tree_util.tree_l2_norm(pytree)

Returns l2 norm of tree.

Return type:

Array

fedjax.tree_util.tree_size(pytree)

Returns total size of all tree leaves.

Return type:

int

fedjax.tree_util.tree_clip_by_global_norm(pytree, max_norm)

Clips a pytree of arrays using their global norm.

References

[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

Parameters:
  • pytree (Any) – A pytree to be potentially clipped.

  • max_norm (float) – The maximum global norm for a pytree.

Return type:

Any

Returns:

A potentially clipped pytree.
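
As a quick numeric check (a sketch), a pytree with global l2 norm 5 clipped to max_norm=1.0 is scaled by 1/5:

import jax.numpy as jnp

pytree = {'a': jnp.array([3.]), 'b': jnp.array([4.])}
print(fedjax.tree_util.tree_clip_by_global_norm(pytree, max_norm=1.0))
# Roughly {'a': [0.6], 'b': [0.8]}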

Federated algorithm

fedjax.FederatedAlgorithm

Container for all federated algorithms.

Federated data

fedjax.FederatedData

FederatedData interface for providing access to a federated dataset.

fedjax.SubsetFederatedData

A simple wrapper over a concrete FederatedData for restricting to a subset of client ids.

fedjax.SQLiteFederatedData

Federated dataset backed by SQLite.

fedjax.InMemoryFederatedData

A simple wrapper over a concrete fedjax.FederatedData for small in memory datasets.

fedjax.FederatedDataBuilder

FederatedDataBuilder interface.

fedjax.SQLiteFederatedDataBuilder

Builds SQLite files from a python dictionary containing an arbitrary mapping of client IDs to NumPy examples.

fedjax.ClientPreprocessor

A chain of preprocessing functions on all examples of a client dataset.

fedjax.shuffle_repeat_batch_federated_data

Shuffle-repeat-batch all client datasets in a federated dataset for training a centralized baseline.

fedjax.padded_batch_federated_data

Padded batch all client datasets, useful for evaluation on the entire federated dataset.

Client dataset

fedjax.ClientDataset

In memory client dataset backed by numpy ndarrays.

fedjax.BatchPreprocessor

A chain of preprocessing functions on batched examples.

fedjax.buffered_shuffle_batch_client_datasets

Shuffles and batches examples from multiple client datasets.

fedjax.padded_batch_client_datasets

Batches examples from multiple client datasets.

For each client

fedjax.for_each_client

Creates a function which maps over clients.

fedjax.for_each_client_backend

A context manager for switching to a given ForEachClientBackend in the current thread.

fedjax.set_for_each_client_backend

Sets the for_each_client backend for the current thread.

Model

fedjax.Model

Container class for models.

fedjax.create_model_from_haiku

Creates Model after applying defaults and haiku specific preprocessing.

fedjax.create_model_from_stax

Creates Model after applying defaults and stax specific preprocessing.

fedjax.evaluate_model

Evaluates model for multiple batches and returns final results.

fedjax.model_grad

A standard gradient function derived from a model and an optional regularizer.

fedjax.model_per_example_loss

Convenience function for constructing a per-example loss function from a model.

fedjax.evaluate_average_loss

Evaluates the average per example loss over multiple batches.

fedjax.ModelEvaluator

Evaluates model for each client dataset, either using global params, or per client params.

fedjax.AverageLossEvaluator

Evaluates average loss for each client dataset, either using global params, or per client params.

fedjax.grad

A standard gradient function derived from per-example loss and an optional regularizer.


class fedjax.FederatedAlgorithm(init, apply)[source]

Container for all federated algorithms.

FederatedAlgorithm defines the required methods that need to be implemented by all custom federated algorithms. Defining federated algorithms in this structure will allow implementations to work seamlessly with the convenience methods in the fedjax.training API, like checkpointing.

Example toy implementation:

# Federated algorithm that just counts the total number of points across
# clients across rounds.

import jax
import jax.numpy as jnp
from fedjax import ClientDataset, FederatedAlgorithm

def count_federated_algorithm():

  def init(init_count):
    return {'count': init_count}

  def apply(state, clients):
    count = 0
    client_diagnostics = {}
    # Count sizes across clients.
    for client_id, client_dataset, _ in clients:
      # Summation across clients in one round.
      client_count = len(client_dataset)
      count += client_count
      client_diagnostics[client_id] = client_count
    # Summation across rounds.
    state = {'count': state['count'] + count}
    return state, client_diagnostics

  return FederatedAlgorithm(init, apply)

rng = jax.random.PRNGKey(0)  # Unused.
all_clients = [
    [
      (b'cid0', ClientDataset({'x': jnp.array([1, 2, 1, 2, 3, 4])}), rng),
      (b'cid1', ClientDataset({'x': jnp.array([1, 2, 3, 4, 5])}), rng),
      (b'cid2', ClientDataset({'x': jnp.array([1, 1, 2])}), rng),
    ],
    [
      (b'cid3', ClientDataset({'x': jnp.array([1, 2, 3, 4])}), rng),
      (b'cid4', ClientDataset({'x': jnp.array([1, 1, 2, 1, 2, 3])}), rng),
      (b'cid5', ClientDataset({'x': jnp.array([1, 2, 3, 4, 5, 6, 7])}), rng),
    ],
]
algorithm = count_federated_algorithm()
state = algorithm.init(0)
for round_num in range(2):
  state, client_diagnostics = algorithm.apply(state, all_clients[round_num])
  print(round_num, state)
  print(round_num, client_diagnostics)
# 0 {'count': 14}
# 0 {b'cid0': 6, b'cid1': 5, b'cid2': 3}
# 1 {'count': 31}
# 1 {b'cid3': 4, b'cid4': 6, b'cid5': 7}
init

Initializes the ServerState. Typically, the input to this method will be the initial model Params. This should only be run once at the beginning of training.

Type:

Callable[[…], Any]

apply

Completes one round of federated training given an input ServerState and a sequence of tuples of client identifier, client dataset, and client rng. The output will be a new, updated ServerState and accumulated per step results keyed by client identifier (e.g. train metrics).

Type:

Callable[[Any, Sequence[Tuple[bytes, fedjax.core.client_datasets.ClientDataset, jax.Array]]], Tuple[Any, Mapping[bytes, Any]]]


class fedjax.FederatedData[source]

FederatedData interface for providing access to a federated dataset.

A FederatedData object serves as a mapping from client ids to client datasets and client metadata.

Access methods with better I/O efficiency

For large federated datasets, it is not feasible to load all client datasets into memory at once (whereas loading a single client dataset is assumed to be feasible). Different implementations exist for different on-disk storage formats. Since sequential read is much faster than random read for most storage technologies, FederatedData provides the following methods for accessing client datasets, in roughly decreasing order of I/O efficiency:

  1. clients() and shuffled_clients() are sequential read friendly, and thus recommended whenever appropriate.

  2. get_clients() requires random read, but prefetching is possible. This should be preferred over get_client().

  3. get_client() is usually the slowest way of accessing client datasets, and is mostly intended for interactive exploration of a small number of clients.

Preprocessing

ClientDataset produced by FederatedData can hold a BatchPreprocessor, customizable via preprocess_batch(). Additionally, another “client” level ClientPreprocessor, customizable via preprocess_client(), can be used to apply transformations on examples from the entire client dataset before a ClientDataset is constructed.

abstract client_ids()[source]

Returns an iterator of client ids as bytes.

There is no requirement on the order of iteration.

Return type:

Iterator[bytes]

abstract client_size(client_id)[source]

Returns the number of examples in a client dataset.

Return type:

int

abstract client_sizes()[source]

Returns an iterator of all (client id, client size) pairs.

This is often more efficient than making multiple client_size() calls. There is no requirement on the order of iteration.

Return type:

Iterator[Tuple[bytes, int]]

abstract clients()[source]

Iterates over clients in a deterministic order.

Implementations can choose whatever order makes iteration efficient.

Return type:

Iterator[Tuple[bytes, ClientDataset]]

abstract get_client(client_id)[source]

Gets one single client dataset.

Prefer clients(), shuffled_clients(), or get_clients() when possible.

Parameters:

client_id (bytes) – Client id to load.

Return type:

ClientDataset

Returns:

The corresponding ClientDataset.

abstract get_clients(client_ids)[source]

Gets multiple clients in order with one call.

Clients are returned in the order of client_ids.

Parameters:

client_ids (Iterable[bytes]) – Client ids to load.

Return type:

Iterator[Tuple[bytes, ClientDataset]]

Returns:

Iterator.

abstract num_clients()[source]

Returns the number of clients.

If it is too expensive or otherwise impossible to obtain the result, an implementation may raise an exception.

Return type:

int

abstract preprocess_batch(fn)[source]

Registers a preprocessing function to be called after batching in ClientDatasets.

Return type:

FederatedData

abstract preprocess_client(fn)[source]

Registers a preprocessing function to be called on all examples of a client before passing them to construct a ClientDataset.

Return type:

FederatedData

abstract shuffled_clients(buffer_size, seed=None)[source]

Iterates over clients with a repeated buffered shuffling.

Shuffling should use a buffer size of at least buffer_size clients. The iteration should repeat forever, with usually a different order in each pass.

Parameters:
  • buffer_size (int) – Buffer size for shuffling.

  • seed (Optional[int]) – Optional random number generator seed.

Return type:

Iterator[Tuple[bytes, ClientDataset]]

Returns:

Iterator.

abstract slice(start=None, stop=None)[source]

Returns a new FederatedData restricted to client ids in the given range.

The returned FederatedData includes clients whose ids are,

  • Greater than or equal to start when start is not None;

  • Less than stop when stop is not None.

Parameters:
  • start (Optional[bytes]) – Start of client id range.

  • stop (Optional[bytes]) – Stop of client id range.

Return type:

FederatedData

Returns:

FederatedData.
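
As a concrete sketch, the access methods above can be exercised on a small fedjax.InMemoryFederatedData (a concrete implementation described further below); the client ids and feature values here are hypothetical:

import numpy as np
import fedjax

fd = fedjax.InMemoryFederatedData({
    b'clientA': {'x': np.array([1.0, 2.0, 3.0])},
    b'clientB': {'x': np.array([4.0, 5.0])},
})
print(fd.num_clients())            # 2
print(fd.client_size(b'clientA'))  # 3

# Sequential-read friendly iteration over all clients.
for client_id, client_dataset in fd.clients():
  print(client_id, len(client_dataset))

# Random access to a chosen subset of clients.
for client_id, client_dataset in fd.get_clients([b'clientB']):
  print(client_id, client_dataset.all_examples())

# Registering a batch-level preprocessing function returns a new FederatedData.
fd2 = fd.preprocess_batch(lambda examples: {**examples, 'x2': examples['x'] * 2})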

class fedjax.SubsetFederatedData(base, client_ids, validate=True)[source]

Bases: FederatedData

A simple wrapper over a concrete FederatedData for restricting to a subset of client ids.

This is useful when we wish to create a smaller FederatedData out of arbitrary client ids, where slicing is not possible.

__init__(base, client_ids, validate=True)[source]

Initializes the subset federated dataset.

Parameters:
  • base (FederatedData) – Base concrete FederatedData.

  • client_ids (Iterable[bytes]) – Client ids to include in the subset. All client ids must be in base.client_ids(), otherwise the behavior of SubsetFederatedData is undefined when validate=False.

  • validate – Whether to validate client ids.

class fedjax.SQLiteFederatedData(connection, parse_examples, start=None, stop=None, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]

Bases: FederatedData

Federated dataset backed by SQLite.

The SQLite database should contain a table named “federated_data” created with the following command:

CREATE TABLE federated_data (
  client_id BLOB NOT NULL PRIMARY KEY,
  data BLOB NOT NULL,
  num_examples INTEGER NOT NULL
);

where,

  • client_id is the bytes client id.

  • data is the serialized client dataset examples.

  • num_examples is the number of examples in the client dataset.

By default we use zlib compressed msgpack blobs for data (see decompress_and_deserialize()).

__init__(connection, parse_examples, start=None, stop=None, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]
static new(path, parse_examples=<function decompress_and_deserialize>)[source]

Opens a federated dataset stored as an SQLite3 database.

Parameters:
  • path (str) – Path to the SQLite database file.

  • parse_examples (Callable[[bytes], Mapping[str, ndarray]]) – Function for deserializing client dataset examples.

Return type:

SQLiteFederatedData

Returns:

The opened SQLiteFederatedData.

class fedjax.InMemoryFederatedData(client_to_data_mapping, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]

Bases: FederatedData

A simple wrapper over a concrete fedjax.FederatedData for small in memory datasets.

This is useful when we wish to create a smaller FederatedData that fits in memory. Here is a simple example to create a fedjax.InMemoryFederatedData,

client_a_data = {
    'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    'y': np.array([7, 8])
}
client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
client_to_data_mapping = {'a': client_a_data, 'b': client_b_data}

fedjax.InMemoryFederatedData(client_to_data_mapping)

This returns a fedjax.InMemoryFederatedData corresponding to client_to_data_mapping.

__init__(client_to_data_mapping, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]

Initializes the in memory federated dataset.

Data of each client is a mapping from feature names to numpy arrays. For example, for emnist image classification, {'x': X, 'y': y}, where X is a matrix of shape (num_data_points, 28, 28) and y is a vector of shape (num_data_points,).

Parameters:
  • client_to_data_mapping (Mapping[bytes, Mapping[str, ndarray]]) – A mapping from client_id to data of each client.

  • preprocess_client (ClientPreprocessor) – federated_data.ClientPreprocessor to preprocess each client data.

  • preprocess_batch (BatchPreprocessor) – client_datasets.BatchPreprocessor to preprocess batch of data.

class fedjax.FederatedDataBuilder[source]

FederatedDataBuilder interface.

To be implemented as a context manager for building file formats from pairs of client IDs and client NumPy examples.

It is relevant to note that the add method below does not specify any raised exceptions. One could imagine some formats where add can fail in some way: out-of-order or duplicate inputs, remote files and network failures, individual entries too big for a format, etc. In order to address this we let implementations throw whatever they see relevant and fit to their particular use cases. The same is relevant when it comes to the __init__, __enter__, and __exit__ methods, where implementations are left with the responsibility of raising exceptions as they see fit to their particular use cases. For example, if an invalid file path is passed, or there were any issues finalizing the builder, etc.

Example of expected behavior:

with FederatedDataBuilder(path) as builder:
  builder.add(b'k1', np.array([b'v1'], dtype=object))
  builder.add(b'k2', np.array([b'v2'], dtype=object))
abstract add_many(client_ids_examples)[source]

Bulk adds multiple client IDs and client NumPy examples pairs to file format.

Parameters:

client_ids_examples (Iterable[Tuple[bytes, Mapping[str, ndarray]]]) – Iterable of tuples of client id and examples.

class fedjax.SQLiteFederatedDataBuilder(path)[source]

Bases: FederatedDataBuilder

Builds SQLite files from a python dictionary containing an arbitrary mapping of client IDs to NumPy examples.

__init__(path)[source]

Initializes SQLiteBuilder by opening a connection and setting up the database with columns.

Parameters:

path (str) – Path of file to write to (e.g. /tmp/sqlite_federated_data.sqlite).
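
A minimal sketch of building and then re-opening such a file via add_many(); the path and example values are hypothetical:

import numpy as np
import fedjax

path = '/tmp/my_federated_data.sqlite'
with fedjax.SQLiteFederatedDataBuilder(path) as builder:
  builder.add_many([
      (b'clientA', {'x': np.array([1.0, 2.0, 3.0])}),
      (b'clientB', {'x': np.array([4.0, 5.0])}),
  ])

fd = fedjax.SQLiteFederatedData.new(path)
print(fd.num_clients())  # 2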

class fedjax.ClientPreprocessor(fns=())[source]

A chain of preprocessing functions on all examples of a client dataset.

This is very similar to fedjax.BatchPreprocessor, with the main difference being that ClientPreprocessor also takes client_id as input.

See the discussion in fedjax.BatchPreprocessor regarding when to use which.

__call__(client_id, examples)[source]

Call self as a function.

Return type:

Mapping[str, ndarray]

__init__(fns=())[source]
append(fn)[source]

Creates a new ClientPreprocessor with fn added to the end.

Return type:

ClientPreprocessor
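
A short sketch of a client-level preprocessing function; unlike a BatchPreprocessor function, it also receives the client_id (the capping logic and feature names here are hypothetical):

import numpy as np
import fedjax

def cap_examples(client_id, examples):
  del client_id  # Unused here, but available for per-client logic.
  return {k: v[:100] for k, v in examples.items()}

preprocessor = fedjax.ClientPreprocessor([cap_examples])
capped = preprocessor(b'clientA', {'x': np.arange(200)})
print(capped['x'].shape)  # (100,)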

fedjax.shuffle_repeat_batch_federated_data(fd, batch_size, client_buffer_size, example_buffer_size, seed=None)[source]

Shuffle-repeat-batch all client datasets in a federated dataset for training a centralized baseline.

Shuffling is done using two levels of buffered shuffling, first at the client level, then at the example level.

This produces an infinite stream of batches. itertools.islice() can be used to cap the number of batches, if so desired.

Parameters:
  • fd (FederatedData) – Federated dataset.

  • batch_size (int) – Desired batch size.

  • client_buffer_size (int) – Buffer size for client level shuffling.

  • example_buffer_size (int) – Buffer size for example level shuffling.

  • seed (Optional[int]) – Optional RNG seed.

Yields:

Batches of preprocessed examples.

Return type:

Iterator[Mapping[str, ndarray]]
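
Since this produces an infinite stream, a typical pattern is to cap it with itertools.islice(); a brief sketch on a small hypothetical in-memory federated dataset:

import itertools
import numpy as np
import fedjax

fd = fedjax.InMemoryFederatedData({
    b'clientA': {'x': np.array([1.0, 2.0, 3.0])},
    b'clientB': {'x': np.array([4.0, 5.0])},
})
batches = fedjax.shuffle_repeat_batch_federated_data(
    fd, batch_size=2, client_buffer_size=4, example_buffer_size=16, seed=0)
for batch in itertools.islice(batches, 5):
  print(batch['x'].shape)  # (2,)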

fedjax.padded_batch_federated_data(fd, hparams=None, **kwargs)[source]

Padded batch all client datasets, useful for evaluation on the entire federated dataset.

Parameters:
  • fd (FederatedData) – Federated dataset.

  • hparams (Optional[PaddedBatchHParams]) – Batching hyperparameters like those in ClientDataset.padded_batch().

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Yields:

Batches of preprocessed examples.

Return type:

Iterator[Mapping[str, ndarray]]

class fedjax.RepeatableIterator(base)[source]

Repeats a base iterable after the end of the first pass of iteration.

Because this is a stateful object, it is not thread safe, and all the usual caveats about accessing the same iterator at different locations apply. For example, if we make two map calls on the same RepeatableIterator, we must make sure we do not interleave next() calls on them. The following is safe because we finish iterating on m1 before starting to iterate on m2:

it = RepeatableIterator(range(4))
m1 = map(lambda x: x + 1, it)
m2 = map(lambda x: x * x, it)
# We finish iterating on m1 before starting to iterate on m2.
print(list(m1), list(m2))
# [1, 2, 3, 4] [0, 1, 4, 9]

Whereas interleaved access leads to confusing results,

it = RepeatableIterator(range(4))
m1 = map(lambda x: x + 1, it)
m2 = map(lambda x: x * x, it)
print(next(m1), next(m2))
# 1 1
print(next(m1), next(m2))
# 3 9
print(next(m1), next(m2))
# StopIteration!

In the first pass of iteration, values fetched from the base iterator will be copied onto an internal buffer (except for a few builtin containers where copying is unnecessary). When each pass of iteration finishes (i.e. when __next__() raises StopIteration), the iterator resets itself to the start of the buffer, thus allowing a subsequent pass of repeated iteration.

In most cases, if repeated iterations are required, it is sufficient to simply copy values from an iterator into a list. However, sometimes an iterator produces values via potentially expensive I/O operations (e.g. loading client datasets). In such cases, RepeatableIterator can interleave I/O and JAX compute to decrease accelerator idle time.


Preprocessing and batching operations over client datasets.

Column based representation

The examples in a client dataset can be viewed as a table, where the rows are the individual examples, and the columns are the features (labels are viewed as a feature in this context).

We use a column based representation when loading a dataset into memory.

  • Each column is a NumPy array x of rank at least 1, where x[i, ...] is the value of this feature for the i-th example.

  • The complete set of examples is a dict-like object, from str feature names, to the corresponding column values.

Traditionally, a row based representation is used for representing the entire dataset, and a column based representation is used for a single batch. In the context of federated learning, an individual client dataset is small enough to easily fit into memory so the same representation is used for the entire dataset and a batch.

Preprocessor

Preprocessing on a batch of examples can be easily done via a chain of functions. A Preprocessor object holds the chain of functions, and applies the transformation on a batch of examples.

ClientDataset: examples + preprocessor

A ClientDataset is simply some examples in the column based representation, accompanied by a Preprocessor. Its batch() method produces batches of examples in a sequential order, suitable for evaluation. Its shuffle_repeat_batch() method adds shuffling and repeating, making it suitable for training.

class fedjax.ClientDataset(raw_examples, preprocessor=BatchPreprocessor(()))[source]

In memory client dataset backed by numpy ndarrays.

Custom preprocessing on batches can be added via a preprocessor. A ClientDataset is stored as the unpreprocessed raw_examples, along with its preprocessor.

This is only intended for efficient access to small datasets that fit in memory.

__getitem__(index)[source]

Returns a new ClientDataset with sliced raw examples.

Return type:

ClientDataset

__init__(raw_examples, preprocessor=BatchPreprocessor(()))[source]
__len__()[source]

Returns the number of raw examples in this dataset.

Return type:

int

all_examples()[source]

Returns the result of feeding all raw examples through the preprocessor.

This is mostly intended for interactive exploration of a small subset of a client dataset. For example, to see the first 4 examples in a client dataset,

dataset = ClientDataset(my_raw_examples, my_preprocessor)
dataset[:4].all_examples()
Return type:

Mapping[str, ndarray]

Returns:

Preprocessed examples from all the raw examples in this client dataset.

batch(hparams=None, **kwargs)[source]

Produces preprocessed batches in a fixed sequential order.

The final batch may contain fewer than batch_size examples. If used directly, that may result in a large number of JIT recompilations. Therefore we recommend using padded_batch() instead in most scenarios.

This function can be invoked in 2 ways:

  1. Using a hyperparams object. This is the recommended way in library code if you have to use batch (prefer padded_batch() if possible). Example:

    def a_library_function(client_dataset, hparams):
      for batch in client_dataset.batch(hparams):
        ...
    
  2. Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,

    client_dataset.batch(batch_size=2)
    # Overrides the default drop_remainder value.
    client_dataset.batch(hparams, drop_remainder=True)
    
Parameters:
  • hparams (Optional[BatchHParams]) – Batching hyperparameters.

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Return type:

Iterable[Mapping[str, ndarray]]

Returns:

An iterable object that can be iterated over multiple times.

padded_batch(hparams=None, **kwargs)[source]

Produces preprocessed padded batches in a fixed sequential order.

This function can be invoked in 2 ways:

  1. Using a hyperparams object. This is the recommended way in library code. Example:

    def a_library_function(client_dataset, hparams):
      for batch in client_dataset.padded_batch(hparams):
        ...
    
  2. Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,

    client_dataset.padded_batch(batch_size=2)
    # Overrides the default num_batch_size_buckets value.
    client_dataset.padded_batch(hparams, num_batch_size_buckets=2)
    

When the number of examples in the dataset is not a multiple of batch_size, the final batch may be smaller than batch_size. This may lead to a large number of JIT recompilations. This can be circumvented by padding the final batch to a small number of fixed sizes controlled by num_batch_size_buckets.

All batches contain an extra bool feature keyed by EXAMPLE_MASK_KEY. batch[EXAMPLE_MASK_KEY][i] tells us whether the i-th example in this batch is an actual example (batch[EXAMPLE_MASK_KEY][i] == True), or a padding example (batch[EXAMPLE_MASK_KEY][i] == False).

We repeatedly halve the batch size up to num_batch_size_buckets-1 times, until we find the smallest one that is also >= the size of the final batch. Therefore if batch_size < 2^num_batch_size_buckets, fewer bucket sizes will actually be used.

Parameters:
  • hparams (Optional[PaddedBatchHParams]) – Batching hyperparameters.

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Return type:

Iterable[Mapping[str, ndarray]]

Returns:

An iterable object that can be iterated over multiple times.
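
A small sketch of padded batches over a 5-example dataset with batch_size=2; the final batch is padded and the mask feature marks the padding (this assumes the mask key constant is importable from fedjax.core.client_datasets, as referenced above):

import numpy as np
import fedjax
from fedjax.core import client_datasets

dataset = fedjax.ClientDataset({'x': np.arange(5)})
for batch in dataset.padded_batch(batch_size=2, num_batch_size_buckets=1):
  print(batch[client_datasets.EXAMPLE_MASK_KEY])
# [ True  True]
# [ True  True]
# [ True False]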

shuffle_repeat_batch(hparams=None, **kwargs)[source]

Produces preprocessed batches in a shuffled and repeated order.

This function can be invoked in 2 ways:

  1. Using a hyperparams object. This is the recommended way in library code. Example:

    def a_library_function(client_dataset, hparams):
      for batch in client_dataset.shuffle_repeat_batch(hparams):
        ...
    
  2. Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,

    client_dataset.shuffle_repeat_batch(batch_size=2)
    # Overrides the default num_epochs value.
    client_dataset.shuffle_repeat_batch(hparams, num_epochs=2)
    

Shuffling is done without replacement, therefore for a dataset of N examples, the first ceil(N/batch_size) batches are guaranteed to cover the entire dataset.

By default the iteration stops after the first epoch. The number of batches produced from the iteration can be controlled by the (num_epochs, num_steps, drop_remainder) combination:

  • If both num_epochs and num_steps are None, the shuffle-repeat process continues forever.

  • If num_epochs is set and num_steps is None, just enough batches are produced to go over the dataset this many passes. Further,

    • If drop_remainder is False (the default), the final batch is filled with additionally sampled examples to contain batch_size examples.

    • If drop_remainder is True, the final batch is dropped if it contains fewer than batch_size examples. This may result in examples being skipped when num_epochs=1.

  • If num_steps is set and num_epochs is None, exactly this many batches are produced. drop_remainder has no effect in this case.

  • If both num_epochs and num_steps are set, the smaller number of batches implied by the two conditions is produced.

If reproducible iteration order is desired, a fixed seed can be used. When seed is None, repeated iteration over the same object may produce batches in a different order.

Unlike batch() or padded_batch(), batches from shuffle_repeat_batch() always contain exactly batch_size examples. Also unlike TensorFlow, that holds even when drop_remainder=False.

Parameters:
  • hparams (Optional[ShuffleRepeatBatchHParams]) – Batching hyperparameters.

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Return type:

Iterable[Mapping[str, ndarray]]

Returns:

An iterable object that can be iterated over multiple times.

class fedjax.BatchPreprocessor(fns=())[source]

A chain of preprocessing functions on batched examples.

BatchPreprocessor holds a chain of preprocessing functions, and applies them in order on batched examples. Each individual preprocessing function operates over multiple examples, instead of just 1 example. For example,

preprocessor = BatchPreprocessor([
  # Flattens `pixels`.
  lambda x: {**x, 'pixels': x['pixels'].reshape([-1, 28 * 28])},
  # Introduce `binary_label`.
  lambda x: {**x, 'binary_label': x['label'] % 2},
])
fake_emnist = {
  'pixels': np.random.uniform(size=(10, 28, 28)),
  'label': np.random.randint(10, size=(10,))
}
preprocessor(fake_emnist)
# Produces a dict of [10, 28*28] "pixels", [10,] "label" and "binary_label".

Given a BatchPreprocessor, a new BatchPreprocessor can be created with an additional preprocessing function appended to the chain,

# Continuing from the previous example.
new_preprocessor = preprocessor.append(
  lambda x: {**x, 'sum_pixels': np.sum(x['pixels'], axis=1)})
new_preprocessor(fake_emnist)
# Produces a dict of [10, 28*28] "pixels", [10,] "sum_pixels", "label" and
# "binary_label".

The main difference of this preprocessor and fedjax.ClientPreprocessor is that fedjax.ClientPreprocessor also takes client_id as input. Because of the identical representation between batched examples and all examples in a client dataset, certain preprocessing can be done with either BatchPreprocessor or ClientPreprocessor.

Examples of preprocessing possible at either the client dataset level, or the batch level

Such preprocessing is deterministic, and strictly per-example.

  • Casting a feature from int8 to float32.

  • Adding a new feature derived from existing features.

  • Remove a feature (although the better place to do so is at the dataset level).

A simple rule for deciding where to carry out the preprocessing in this case is the following,

  • Does this make batching cheaper (e.g. removing features)? If so, do it at the dataset level.

  • Otherwise, do it at the batch level.

Assuming preprocessing time is linear in the number of examples, preprocessing at the batch level has the benefit of evenly distributing host compute work, which may overlap better with asynchronous JAX compute work on GPU/TPU.

Examples of preprocessing only possible at the batch level

  • Data augmentation (e.g. random cropping).

  • Padding at the batch size dimension.

Examples of preprocessing only possible at the dataset level

  • Those that require knowing the client id.

  • Capping the number of examples.

  • Altering what it means to be an example: e.g. in certain language model setups, sentences are concatenated and then split into equal sized chunks.

__init__(fns=())[source]
append(fn)[source]

Creates a new BatchPreprocessor with fn added to the end.

Return type:

BatchPreprocessor

fedjax.buffered_shuffle_batch_client_datasets(datasets, batch_size, buffer_size, rng)[source]

Shuffles and batches examples from multiple client datasets.

This just makes 1 pass over the examples. To achieve repeated iterations, create an infinite shuffled stream of datasets first (e.g. using buffered_shuffle()).

Parameters:
  • datasets (Iterable[ClientDataset]) – ClientDatasets to be batched. All ClientDatasets must have the same Preprocessor object attached.

  • batch_size (int) – Desired batch size.

  • buffer_size (int) – Number of examples to buffer during shuffling.

  • rng (RandomState) – Source of randomness.

Yields:

Batches of examples. For a finite stream of datasets, the final batch might be smaller than batch_size.

Raises:
  • ValueError – If any 2 client datasets have different Preprocessors.

  • ValueError – If any 2 client datasets have different features.

Return type:

Iterator[Mapping[str, ndarray]]

fedjax.padded_batch_client_datasets(datasets, hparams=None, **kwargs)[source]

Batches examples from multiple client datasets.

This is useful when we want to evaluate on the combined dataset consisting of multiple client datasets. Unlike batching each client dataset individually, this reduces the number of batches that are smaller than batch_size.

This function can be invoked in 2 ways:

  1. Using a hyperparams object. This is the recommended way in library code. Example:

    def a_library_function(datasets, hparams):
      for batch in padded_batch_client_datasets(datasets, hparams):
        ...
    
  2. Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,

    padded_batch_client_datasets(datasets, hparams)
    # Overrides the default num_batch_size_buckets value.
    padded_batch_client_datasets(datasets, hparams, num_batch_size_buckets=2)
    
Parameters:
  • datasets (Iterable[ClientDataset]) – ClientDatasets to be batched. All ClientDatasets must have the same Preprocessor object attached.

  • hparams (Optional[PaddedBatchHParams]) – Batching hyperparams like those in ClientDataset.padded_batch().

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Yields:

Batches of examples. The final batch might be padded. All batches contain a bool feature keyed by EXAMPLE_MASK_KEY.

Raises:
  • ValueError – If any 2 client datasets have different Preprocessors.

  • ValueError – If any 2 client datasets have different features.

Return type:

Iterator[Mapping[str, ndarray]]


fedjax.for_each_client(client_init, client_step, client_final=<function <lambda>>, with_step_result=False)[source]

Creates a function which maps over clients.

For example, for_each_client could be used to define how to run client updates for each client in a federated training round. Another common use case of for_each_client is to run evaluation per client for a given set of model parameters.

The underlying backend for for_each_client is customizable. For example, if multiple devices are available (e.g. TPU), a jax.pmap() based backend can be used to parallelize across devices. It’s also possible to manually specify which backend to use (for debugging).

The expected usage of for_each_client is as follows:

# Map over clients and count how many points are greater than `limit` for
# each client. Each client also has a different `start` that is specified
# via client input.

def client_init(shared_input, client_input):
  client_step_state = {
      'limit': shared_input['limit'],
      'count': client_input['start']
  }
  return client_step_state

def client_step(client_step_state, batch):
  num = jnp.sum(batch['x'] > client_step_state['limit'])
  client_step_state = {
      'limit': client_step_state['limit'],
      'count': client_step_state['count'] + num
  }
  return client_step_state

def client_final(shared_input, client_step_state):
  del shared_input  # Unused.
  return client_step_state['count']

# Three clients with different data and starting counts.
# clients = [(client_id, client_batches, client_input)]
clients = [
    (b'cid0',
    [{'x': jnp.array([1, 2, 3, 4])}, {'x': jnp.array([1, 2, 3])}],
    {'start': jnp.array(2)}),
    (b'cid1',
    [{'x': jnp.array([1, 2])}, {'x': jnp.array([1, 2, 3, 4, 5])}],
    {'start': jnp.array(0)}),
    (b'cid2',
    [{'x': jnp.array([1])}],
    {'start': jnp.array(1)}),
]
shared_input = {'limit': jnp.array(2)}

func = fedjax.for_each_client(client_init, client_step, client_final)
print(list(func(shared_input, clients)))
# [(b'cid0', 5), (b'cid1', 3), (b'cid2', 1)]

Here’s the same example with per step results.

# We'll also keep track of the `num` per step in our step results.

def client_step_with_result(client_step_state, batch):
  num = jnp.sum(batch['x'] > client_step_state['limit'])
  client_step_state = {
      'limit': client_step_state['limit'],
      'count': client_step_state['count'] + num
  }
  client_step_result = {'num': num}
  return client_step_state, client_step_result

func = fedjax.for_each_client(
    client_init, client_step_with_result, client_final, with_step_result=True)
print(list(func(shared_input, clients)))
# [
#   (b'cid0', 5, [{'num': 2}, {'num': 1}]),
#   (b'cid1', 3, [{'num': 0}, {'num': 3}]),
#   (b'cid2', 1, [{'num': 0}]),
# ]
Parameters:
  • client_init (Callable[[Any, Any], Any]) – Function that initializes the internal intermittent client step state from the shared input and per client input. The shared input contains information like the global model parameters that are shared across all clients. The per client input is per client information. The initialized internal client step state is fed as intermittent input and output from client_step and client_final. This client step state usually contains the model parameters and optimizer state for each client that are updated at each client_step. This will be run once for each client.

  • client_step (Union[Callable[[Any, Mapping[str, Array]], Tuple[Any, Any]], Callable[[Any, Mapping[str, Array]], Any]]) – Function that takes the client step state and a batch of examples as input and outputs a (possibly updated) client step state. Optionally, per step results can also be returned as the second element if with_step_result is True. Per step results are usually diagnostics like gradient norm. This will be run for each batch for each client.

  • client_final (Callable[[Any, Any], Any]) – Function that applies the final transformation on the internal client step state to the desired final client output. More meaningful transformations can be done here, like model update clipping. Defaults to just returning the client step state. This will be run once for each client.

  • with_step_result (bool) – Indicates whether client_step returns a pair where the first element is considered the client output and the second element is the client step result.

Returns:

A for each client function that takes shared_input and the per client inputs as tuple (client_id, batched_client_data, client_rng) to map over and returns the outputs per client as specified in client_final along with optional per client per step results.

fedjax.for_each_client_backend(backend)[source]

A context manager for switching to a given ForEachClientBackend in the current thread.

Example:

with for_each_client_backend('pmap'):
  # We will be using the pmap based for_each_client backend within this block.
  pass
# We will be using the default for_each_client backend from now on.
Parameters:

backend (Union[ForEachClientBackend, str, None]) – See set_for_each_client_backend().

Yields:

Nothing.

fedjax.set_for_each_client_backend(backend)[source]

Sets the for_each_client backend for the current thread.

Parameters:

backend (Union[ForEachClientBackend, str, None]) –

One of the following,

  • None: uses the default backend for the current environment.

  • ’debug’: uses the debugging backend.

  • ’jit’: uses the JIT backend.

  • ’pmap’: uses the pmap-based backend.

  • A concrete ForEachClientBackend object.
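
For example, to force the debugging backend for the current thread (a short sketch):

fedjax.set_for_each_client_backend('debug')
# Subsequent for_each_client calls in this thread use the debug backend.
fedjax.set_for_each_client_backend(None)  # Restore the default choice.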


class fedjax.Model(init, apply_for_train, apply_for_eval, train_loss, eval_metrics)[source]

Container class for models.

Model exists to provide easy access to predefined neural network models. It is meant to contain all the information needed for standard centralized training and evaluation. Non-standard training methods can be built upon the information available in Model along with any additional information (e.g. interpolation can be implemented as a composition of two models along with an interpolation weight).

Works for Haiku and jax.example_libraries.stax.

The expected usage of Model is as follows:

# Training.
step_size = 0.1
rng = jax.random.PRNGKey(0)
params = model.init(rng)

def loss(params, batch, rng):
  preds = model.apply_for_train(params, batch, rng)
  return jnp.sum(model.train_loss(batch, preds))

grad_fn = jax.grad(loss)
for batch in batches:
  rng, use_rng = jax.random.split(rng)
  grads = grad_fn(params, batch, use_rng)
  params = jax.tree_util.tree_map(lambda a, b: a - step_size * b,
                                       params, grads)

# Evaluation.
print(fedjax.evaluate_model(model, params, batches))
# Example output:
# {'loss': 2.3, 'accuracy': 0.2}

The following is an example using Model compositionally as a building block to implement model interpolation:

def interpolate(model_1, model_2, init_weight):

  @jax.jit
  def init(rng):
    rng_1, rng_2 = jax.random.split(rng)
    params_1 = model_1.init(rng_1)
    params_2 = model_2.init(rng_2)
    return params_1, params_2, init_weight

  @jax.jit
  def apply_for_train(params, input, rng):
    rng_1, rng_2 = jax.random.split(rng)
    params_1, params_2, weight = params
    return (model_1.apply_for_train(params_1, input, rng_1) * weight +
            model_2.apply_for_train(params_2, input, rng_2) * (1 - weight))

  @jax.jit
  def apply_for_eval(params, input):
    params_1, params_2, weight = params
    return (model_1.apply_for_eval(params_1, input) * weight +
            model_2.apply_for_eval(params_2, input) * (1 - weight))

  return fedjax.Model(init,
                      apply_for_train,
                      apply_for_eval,
                      model_1.train_loss,
                      model_1.eval_metrics)

model = interpolate(model_1, model_2, init_weight=0.5)
init

Initialization function that takes a seed PRNGKey and returns a PyTree of initialized parameters (i.e. model weights). These parameters will be passed as input into apply_for_train() and apply_for_eval(). Any trainable weights for a model that are modified in the training loop should be contained inside of these parameters.

Type:

Callable[[jax.Array], Any]

apply_for_train

Function that takes the parameters PyTree, batch of examples, and PRNGKey as inputs and outputs the model predictions for training that are then passed into train_loss(). This considers strategies such as dropout.

Type:

Callable[[Any, Mapping[str, jax.Array], jax.Array], jax.Array]

apply_for_eval

Function that usually takes the parameters PyTree and batch of examples as inputs and outputs the model predictions for evaluation that are then passed to eval_metrics. This is defined separately from apply_for_train() to avoid having to specify inputs like PRNGKey that are not used in evaluation.

Type:

Callable[[Any, Mapping[str, jax.Array]], Union[jax.Array, Mapping[str, jax.Array]]]

train_loss

Loss function for training that takes a batch of examples and the model output from apply_for_train() as input and outputs per example loss. This will typically be called inside a jax.grad() wrapped function to compute gradients.

Type:

Callable[[Mapping[str, jax.Array], jax.Array], jax.Array]

eval_metrics

Ordered mapping of evaluation metric names to Metric. These Metric objects are defined for single examples and will be used in evaluate_model().

Type:

Mapping[str, fedjax.core.metrics.Metric]

fedjax.create_model_from_haiku(transformed_forward_pass, sample_batch, train_loss, eval_metrics=None, train_kwargs=None, eval_kwargs=None)[source]

Creates Model after applying defaults and haiku specific preprocessing.

Parameters:
  • transformed_forward_pass (Transformed) – Transformed forward pass from hk.transform()

  • sample_batch (Mapping[str, Array]) – Example input used to determine model parameter shapes.

  • train_loss (Callable[[Mapping[str, Array], Array], Array]) – Loss function for training that outputs per example loss.

  • eval_metrics (Optional[Mapping[str, Metric]]) – Mapping of evaluation metric names to Metric. These metrics are defined for single examples and will be consumed in evaluate_model().

  • train_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for training.

  • eval_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for evaluation.

Return type:

Model

Returns:

Model
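
A minimal sketch of wiring a Haiku network into a Model; the network, loss, and metric choices below are illustrative assumptions (they mirror the pattern used by the packaged fedjax models):

import haiku as hk
import jax
import jax.numpy as jnp
import fedjax

def forward_pass(batch):
  network = hk.Sequential([hk.Flatten(), hk.Linear(10)])
  return network(batch['x'])

sample_batch = {
    'x': jnp.zeros((1, 28, 28)),
    'y': jnp.zeros((1,), dtype=jnp.int32),
}
model = fedjax.create_model_from_haiku(
    transformed_forward_pass=hk.transform(forward_pass),
    sample_batch=sample_batch,
    train_loss=lambda batch, preds: fedjax.metrics.unreduced_cross_entropy_loss(
        batch['y'], preds),
    eval_metrics={'accuracy': fedjax.metrics.Accuracy()})
params = model.init(jax.random.PRNGKey(0))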

fedjax.create_model_from_stax(stax_init, stax_apply, sample_shape, train_loss, eval_metrics=None, train_kwargs=None, eval_kwargs=None, input_key='x')[source]

Creates Model after applying defaults and stax specific preprocessing.

Parameters:
  • stax_init (Callable[…, Any]) – Initialization function returned from stax.serial().

  • stax_apply (Callable[…, Array]) – Model forward_pass pass function returned from stax.serial.

  • sample_shape (Tuple[int, …]) – The expected shape of the input to the model.

  • train_loss (Callable[[Mapping[str, Array], Array], Array]) – Loss function for training that outputs per example loss.

  • eval_metrics (Optional[Mapping[str, Metric]]) – Mapping of evaluation metric names to Metric. These metrics are defined for single examples and will be consumed in evaluate_model().

  • train_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for training.

  • eval_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for evaluation.

  • input_key (str) – Key name for the input in batch mapping.

Return type:

Model

Returns:

Model

fedjax.evaluate_model(model, params, batches)[source]

Evaluates model for multiple batches and returns final results.

This is the recommended way to compute evaluation metrics for a given model.

Parameters:
  • model (Model) – Model container.

  • params (Any) – Pytree of model parameters to be evaluated.

  • batches (Iterable[Mapping[str, Array]]) – Multiple batches to compute and aggregate evaluation metrics over. Each batch can optionally contain a feature keyed by client_datasets.MASK_KEY (see ClientDataset.padded_batch()).

Return type:

Dict[str, Array]

Returns:

A dictionary of evaluation Metric results.

fedjax.model_grad(model, regularizer=None)[source]

A standard gradient function derived from a model and an optional regularizer.

The scalar loss function being differentiated is simply:

mean(model’s per-example loss) + regularizer term

The returned gradient function supports both unpadded batches and padded batches with the mask feature keyed by client_datasets.EXAMPLE_MASK_KEY.

Parameters:
  • model (Model) – A Model.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer.

Return type:

Callable[[Any, Mapping[str, Array], Array], Any]

Returns:

A function from (params, batch_example, rng) to gradients.

fedjax.model_per_example_loss(model)[source]

Convenience function for constructing a per-example loss function from a model.

Parameters:

model (Model) – Model.

Return type:

Callable[[Any, Mapping[str, Array], Array], Array]

Returns:

A function from (params, batch_example, rng) to a vector of loss values for each example in the batch.

fedjax.evaluate_average_loss(params, batches, rng, per_example_loss, regularizer=None)[source]

Evaluates the average per example loss over multiple batches.

Parameters:
  • params (Any) – PyTree of model parameters to be evaluated.

  • batches (Iterable[Mapping[str, Array]]) – Multiple batches to compute and aggregate evaluation metrics over. Each batch can optionally contain a feature keyed by client_datasets.MASK_KEY (see ClientDataset.padded_batch).

  • rng (Array) – Initial PRNGKey for making per_example_loss calls.

  • per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – Per example loss function.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer function.

Return type:

Array

Returns:

The average per example loss, plus the regularizer term when specified.

class fedjax.ModelEvaluator(model)[source]

Evaluates model for each client dataset, either using global params, or per client params.

To evaluate a Model on a single dataset, use evaluate_model() instead.

__init__(model)[source]
evaluate_global_params(params, clients)[source]

Evaluates batches from each client using global params.

Parameters:
  • params (Any) – Model params to evaluate.

  • clients (Iterable[Tuple[bytes, Iterable[Mapping[str, Array]]]]) – Client batches.

Yields:

Pairs of the client id and a dictionary of evaluation Metric results for each client.

Return type:

Iterator[Tuple[bytes, Dict[str, Array]]]

evaluate_per_client_params(clients)[source]

Evaluates batches from each client using per client params.

Parameters:

clients (Iterable[Tuple[bytes, Iterable[Mapping[str, Array]], Any]]) – Client batches and the per client params.

Yields:

Pairs of the client id and a dictionary of evaluation Metric results for each client.

Return type:

Iterator[Tuple[bytes, Dict[str, Array]]]

class fedjax.AverageLossEvaluator(per_example_loss, regularizer=None)[source]

Evaluates average loss for each client dataset, either using global params, or per client params.

The average loss is defined as the average per example loss, plus the regularizer term when specified. To evaluate average loss on a single dataset, use evaluate_average_loss() instead.

__init__(per_example_loss, regularizer=None)[source]
evaluate_global_params(params, clients)[source]

Evaluates batches from each client using global params.

Parameters:
  • params (Any) – Model params to evaluate.

  • clients (Iterable[Tuple[bytes, Iterable[Mapping[str, Array]], Array]]) – Client batches.

Yields:

Pairs of the client id and the client’s average loss.

Return type:

Iterator[Tuple[bytes, Array]]

evaluate_per_client_params(clients)[source]

Evaluates batches from each client using per client params.

Parameters:

clients (Iterable[Tuple[bytes, Iterable[Mapping[str, Array]], Array, Any]]) – Client batches and the per client params.

Yields:

Pairs of the client id and the client’s average loss.

Return type:

Iterator[Tuple[bytes, Array]]

fedjax.grad(per_example_loss, regularizer=None)[source]

A standard gradient function derived from per-example loss and an optional regularizer.

The scalar loss function being differentiated is simply:

mean(per-example loss) + regularizer term

The returned gradient function supports both unpadded batches and padded batches with the mask feature keyed by client_datasets.EXAMPLE_MASK_KEY.

Parameters:
  • per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.

Return type:

Callable[[Any, Mapping[str, Array], Array], Any]

Returns:

A function from (params, batch_example, rng) to gradients.
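
The two gradient helpers compose naturally with model_per_example_loss(); a brief sketch, where model, params, batch, and rng are assumed to come from the Model usage example above:

per_example_loss = fedjax.model_per_example_loss(model)
grad_fn = fedjax.grad(per_example_loss)
grads = grad_fn(params, batch, rng)

# model_grad() builds the same kind of gradient function directly from a model.
grads = fedjax.model_grad(model)(params, batch, rng)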

fedjax.aggregators

FedJAX aggregators.

class fedjax.aggregators.Aggregator(init, apply)[source]

Interface for algorithms to aggregate.

This interface defines aggregator algorithms that are used at each round. Aggregator state contains any round specific parameters (e.g. number of bits) that will be passed from round to round. This state is initialized by init and passed as input into and returned as output from aggregate. We strongly recommend using fedjax.dataclass to define state as this provides immutability, type hinting, and works by default with JAX transformations.

The expected usage of Aggregator is as follows:

aggregator = mean_aggregator()
state = aggregator.init()
for i in range(num_rounds):
  clients_params_and_weights = compute_client_outputs(i)
  aggregated_params, state = aggregator.apply(clients_params_and_weights,
                                              state)
init

Returns initial state of aggregator.

Type:

Callable[[], Any]

apply

Returns the new aggregator state and aggregated params.

Type:

Callable[[Iterable[Tuple[bytes, Any, float]], Any], Tuple[Any, Any]]

Mean

fedjax.aggregators.mean_aggregator()[source]

Builds (weighted) mean aggregator.

Return type:

Aggregator

Quantization

fedjax.aggregators.uniform_stochastic_quantizer(num_levels, rng, encode_algorithm=None)[source]

Returns an aggregator that computes the (weighted) mean of input trees after uniform stochastic quantization, using the algorithm in https://arxiv.org/pdf/1611.00429.pdf.

Parameters:
  • num_levels (int) – Number of levels of quantization.

  • rng (Array) – PRNGKey used for compression.

  • encode_algorithm (Optional[str]) – None or 'arithmetic'.

Return type:

Aggregator

Returns:

Compression aggregator.
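
A brief sketch mirroring the Aggregator usage above, with hypothetical client params and weights:

import jax
import jax.numpy as jnp
import fedjax

aggregator = fedjax.aggregators.uniform_stochastic_quantizer(
    num_levels=256, rng=jax.random.PRNGKey(0))
state = aggregator.init()
clients_params_and_weights = [
    (b'cid0', {'w': jnp.array([1.0, 2.0])}, 4.0),
    (b'cid1', {'w': jnp.array([3.0, 4.0])}, 2.0),
]
aggregated_params, state = aggregator.apply(clients_params_and_weights, state)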

fedjax.algorithms

Federated learning algorithm implementations.

fedjax.algorithms.agnostic_fed_avg

AgnosticFedAvg implementation.

fedjax.algorithms.fed_avg

Federated averaging implementation using fedjax.core.

fedjax.algorithms.hyp_cluster

Federated hypothesis-based clustering implementation.

fedjax.algorithms.mime

Mime implementation.

fedjax.algorithms.mime_lite

Mime Lite implementation.

AgnosticFedAvg

AgnosticFedAvg implementation.

Communication-Efficient Agnostic Federated Averaging

Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh https://arxiv.org/abs/2104.02748

class fedjax.algorithms.agnostic_fed_avg.ServerState(params, opt_state, domain_weights, domain_window)[source]

State of server for AgnosticFedAvg passed between rounds.

params

A pytree representing the server model parameters.

Type:

Any

opt_state

A pytree representing the server optimizer state.

Type:

Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

domain_weights

Weights per domain applied to weight in weighted average.

Type:

jax.Array

domain_window

Sliding window keeping track of domain number of examples for the last window size rounds.

Type:

List[jax.Array]

replace(**updates)

Returns a new object replacing the specified fields with new values.

fedjax.algorithms.agnostic_fed_avg.agnostic_federated_averaging(per_example_loss, client_optimizer, server_optimizer, client_batch_hparams, domain_batch_hparams, init_domain_weights, domain_learning_rate, domain_algorithm='eg', domain_window_size=1, init_domain_window=None, regularizer=None)[source]

Builds agnostic federated averaging.

Agnostic federated averaging requires input fedjax.core.client_datasets.ClientDataset examples to contain a feature named “domain_id”, which stores the integer domain id in [0, num_domains). For example, for Stack Overflow, each example post can be either a question or an answer, so there are two possible domain ids (question = 0; answer = 1).

Parameters:
  • per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the domain metrics computation and gradient descent training.

  • client_optimizer (Optimizer) – Optimizer for local client training.

  • server_optimizer (Optimizer) – Optimizer for server update.

  • client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for client dataset for training.

  • domain_batch_hparams (PaddedBatchHParams) – Hyperparameters for client dataset domain metrics calculation.

  • init_domain_weights (Sequence[float]) – Initial weights per domain that must sum to 1.

  • domain_learning_rate (float) – Learning rate for domain weight update.

  • domain_algorithm (str) – Algorithm used to update domain weights each round. One of ‘eg’, ‘none’.

  • domain_window_size (int) – Size of sliding window keeping track of number of examples per domain over multiple rounds.

  • init_domain_window (Optional[Sequence[float]]) – Initial values for domain window. Defaults to ones.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.

Return type:

FederatedAlgorithm

Returns:

FederatedAlgorithm.

Raises:

ValueError – If init_domain_weights does not sum to 1 or if init_domain_weights and init_domain_window are unequal lengths.

FedAvg

Federated averaging implementation using fedjax.core.

This is the more performant implementation that matches what would be used in fedjax.algorithms.fed_avg. The key difference between this and the basic version is the use of fedjax.core.for_each_client.

Communication-Efficient Learning of Deep Networks from Decentralized Data

H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017. https://arxiv.org/abs/1602.05629

Adaptive Federated Optimization

Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. ICLR 2021. https://arxiv.org/abs/2003.00295

class fedjax.algorithms.fed_avg.ServerState(params, opt_state)[source]

State of server passed between rounds.

params

A pytree representing the server model parameters.

Type:

Any

opt_state

A pytree representing the server optimizer state.

Type:

Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

replace(**updates)

Returns a new object replacing the specified fields with new values.

fedjax.algorithms.fed_avg.federated_averaging(grad_fn, client_optimizer, server_optimizer, client_batch_hparams)[source]

Builds federated averaging.

Parameters:
  • grad_fn (Callable[[Any, Mapping[str, Array], Array], Any]) – A function from (params, batch_example, rng) to gradients. This can be created with fedjax.core.model.model_grad().

  • client_optimizer (Optimizer) – Optimizer for local client training.

  • server_optimizer (Optimizer) – Optimizer for server update.

  • client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for train.

Return type:

FederatedAlgorithm

Returns:

FederatedAlgorithm
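
A short sketch of assembling the pieces above into a training-loop skeleton; model is any fedjax.Model (see the Model section above) and the hyperparameter values are illustrative:

import jax
import fedjax
from fedjax.algorithms import fed_avg

grad_fn = fedjax.model_grad(model)
client_optimizer = fedjax.optimizers.sgd(learning_rate=0.1)
server_optimizer = fedjax.optimizers.sgd(learning_rate=1.0)
client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=20, num_epochs=1)
algorithm = fed_avg.federated_averaging(
    grad_fn, client_optimizer, server_optimizer, client_batch_hparams)

state = algorithm.init(model.init(jax.random.PRNGKey(0)))
# Each round, sample clients as (client_id, client_dataset, client_rng) tuples
# and run one round of federated averaging:
# state, client_diagnostics = algorithm.apply(state, clients)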

HypCluster

Federated hypothesis-based clustering implementation.

Three Approaches for Personalization with Applications to Federated Learning

Yishay Mansour, Mehryar Mohri, Jae Ro, Ananda Theertha Suresh https://arxiv.org/abs/2002.10619

class fedjax.algorithms.hyp_cluster.ServerState(cluster_params, opt_states)[source]

State of server for HypCluster passed between rounds.

cluster_params

A list of pytrees representing the server model parameters per cluster.

Type:

List[Any]

opt_states

A list of pytrees representing the server optimizer state per cluster.

Type:

List[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]

replace(**updates)

Returns a new object replacing the specified fields with new values.

fedjax.algorithms.hyp_cluster.hyp_cluster(per_example_loss, client_optimizer, server_optimizer, maximization_batch_hparams, expectation_batch_hparams, regularizer=None)[source]

Federated hypothesis-based clustering algorithm.

Parameters:
  • per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch, rng) to a vector of per example loss values.

  • client_optimizer (Optimizer) – Client side optimizer.

  • server_optimizer (Optimizer) – Server side optimizer.

  • maximization_batch_hparams (PaddedBatchHParams) – Batching hyperparameters for the maximization step.

  • expectation_batch_hparams (ShuffleRepeatBatchHParams) – Batching hyperparameters for the expectation step.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer.

Return type:

FederatedAlgorithm

Returns:

A FederatedAlgorithm. A notable difference from other common FederatedAlgorithms such as fed_avg is that init() takes a list of cluster_params, which can be obtained using random_init(), ModelKMeansInitializer, or kmeans_init() in this module.

fedjax.algorithms.hyp_cluster.kmeans_init(num_clusters, init_params, clients, trainer, train_batch_hparams, evaluator, eval_batch_hparams, rng)[source]

Initializes cluster params for HypCluster using a k-means++ variant.

See ModelKMeansInitializer for a more convenient initializer from a Model.

Given a set of input clients, we train parameters for each client. The initial cluster parameters are chosen out of this set of client parameters. At a high level, the cluster parameters are selected by choosing the client parameters that are the “furthest away” from each other (greatest difference in loss).

Parameters:
  • num_clusters (int) – Desired number of cluster centers.

  • init_params (Any) – Initial params for training each client.

  • clients (Sequence[Tuple[bytes, ClientDataset, Array]]) – Clients to train for generating candidate cluster center params.

  • trainer (ClientParamsTrainer) – Client trainer for carrying out client training.

  • train_batch_hparams (ShuffleRepeatBatchHParams) – Batching hyperparameters for client training.

  • evaluator (AverageLossEvaluator) – Average loss evaluator for evaluating clients on chosen cluster center params.

  • eval_batch_hparams (PaddedBatchHParams) – Batching hyperparameters for client evaluation.

  • rng (Array) – RNG used for choosing the initial cluster center.

Return type:

List[Any]

Returns:

Cluster center params.

Raises:

ValueError if some input arguments are invalid.

fedjax.algorithms.hyp_cluster.random_init(num_clusters, init, rng)[source]

Randomly initializes cluster params.

Return type:

List[Any]

class fedjax.algorithms.hyp_cluster.ModelKMeansInitializer(model, client_optimizer, regularizer=None)[source]

Initializes cluster params for HypCluster using a k-means++ variant.

This is a thin wrapper for initializing from a Model. See kmeans_init() for a more general version of the initializer, and details of the initialization algorithm.

__init__(model, client_optimizer, regularizer=None)[source]
cluster_params(num_clusters, rng, clients, train_batch_hparams, eval_batch_hparams)[source]
Return type:

List[Any]

class fedjax.algorithms.hyp_cluster.HypClusterEvaluator(model, regularizer=None)[source]

Evaluates cluster params on a Model.

__init__(model, regularizer=None)[source]

Initializes some reusable components.

Because we need to make multiple passes over each client during evaluation, the number of clients that can be evaluated at once is limited. Therefore multiple calls to evaluate_clients() are needed to evaluate a large federated dataset. We factor out some reusable components so that the same computation can be jit compiled.

Parameters:
  • model (Model) – Model being evaluated.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer.

evaluate_clients(cluster_params, train_clients, test_clients, batch_hparams)[source]

Evaluates each client on the cluster with best average loss.

Return type:

Iterator[Tuple[bytes, Dict[str, Array]]]

Mime

Mime implementation.

Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning

Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606

class fedjax.algorithms.mime.ServerState(params, opt_state)[source]

State of server passed between rounds.

params

A pytree representing the server model parameters.

Type:

Any

opt_state

A pytree representing the base optimizer state.

Type:

Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

replace(**updates)

Returns a new object replacing the specified fields with new values.

fedjax.algorithms.mime.mime(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None)[source]

Builds mime.

Parameters:
  • per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.

  • base_optimizer (Optimizer) – Base optimizer to mimic.

  • client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for train.

  • grads_batch_hparams (PaddedBatchHParams) – Hyperparameters for batching client dataset for server gradient computation.

  • server_learning_rate (float) – Server learning rate.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.

Return type:

FederatedAlgorithm

Returns:

FederatedAlgorithm
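
For example, a minimal construction sketch (task choice and hyperparameters are illustrative):

import fedjax

train_fd, test_fd, model = fedjax.training.get_task('EMNIST_CONV')
algorithm = fedjax.algorithms.mime.mime(
    per_example_loss=fedjax.model_per_example_loss(model),
    base_optimizer=fedjax.optimizers.sgd(learning_rate=0.1),
    client_batch_hparams=fedjax.ShuffleRepeatBatchHParams(batch_size=20),
    grads_batch_hparams=fedjax.PaddedBatchHParams(batch_size=64),
    server_learning_rate=1.0)
# The resulting FederatedAlgorithm is driven the same way as fed_avg:
# state = algorithm.init(params); state, _ = algorithm.apply(state, clients).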

MimeLite

Mime Lite implementation.

Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning

Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606

Reuses fedjax.algorithms.mime.ServerState

fedjax.algorithms.mime_lite.mime_lite(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None, client_delta_clip_norm=None)[source]

Builds mime lite.

Parameters:
  • per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.

  • base_optimizer (Optimizer) – Base optimizer to mimic.

  • client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for train.

  • grads_batch_hparams (PaddedBatchHParams) – Hyperparameters for batching client dataset for server gradient computation.

  • server_learning_rate (float) – Server learning rate.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.

  • client_delta_clip_norm (Optional[float]) – Maximum allowed global norm per client update. Defaults to no clipping.

Return type:

FederatedAlgorithm

Returns:

FederatedAlgorithm

fedjax.datasets

fedjax datasets.

fedjax.datasets.cifar100

Federated cifar100.

fedjax.datasets.emnist

Federated EMNIST.

fedjax.datasets.shakespeare

Federated Shakespeare.

fedjax.datasets.stackoverflow

Federated stackoverflow.

CIFAR-100

Federated cifar100.

fedjax.datasets.cifar100.cite()[source]

Returns BibTeX citation for the dataset.

fedjax.datasets.cifar100.load_data(mode='sqlite', cache_dir=None)[source]

Loads partially preprocessed cifar100 splits.

Features:

  • x: [N, 32, 32, 3] uint8 pixels.

  • y: [N] int32 labels in the range [0, 100).

Additional preprocessing (e.g. centering and normalizing) depends on whether a split is used for training or eval. For example:

import functools
from fedjax.datasets import cifar100
# Load partially preprocessed splits.
train, test = cifar100.load_data()
# Preprocessing for training.
train_for_train = train.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=True))
# Preprocessing for eval.
train_for_eval = train.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=False))
test = test.preprocess_batch(
    functools.partial(cifar100.preprocess_batch, is_train=False))

Features after this preprocessing:

  • x: [N, 32, 32, 3] float32 preprocessed pixels.

  • y: [N] int32 labels in the range [0, 100).

Alternatively, you can apply the same preprocessing as TensorFlow Federated following tff.simulation.baselines.cifar100.create_image_classification_task. For example:

from fedjax.datasets import cifar100
train, test = cifar100.load_data()
train = train.preprocess_batch(cifar100.preprocess_batch_tff)
test = test.preprocess_batch(cifar100.preprocess_batch_tff)

Features after this preprocessing:

  • x: [N, 24, 24, 3] float32 preprocessed pixels.

  • y: [N] int32 labels in the range [0, 100).

Note: preprocess_batch and preprocess_batch_tff are just convenience wrappers around preprocess_image() and preprocess_image_tff(), respectively, for use with fedjax.FederatedData.preprocess_batch().

Parameters:
  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData]

Returns:

A (train, test) tuple of federated data.

fedjax.datasets.cifar100.load_split(split, mode='sqlite', cache_dir=None)[source]

Loads a cifar100 split.

Features:

  • image: [N, 32, 32, 3] uint8 pixels.

  • coarse_label: [N] int64 coarse labels in the range [0, 20).

  • label: [N] int64 labels in the range [0, 100).

Parameters:
  • split (str) – Name of the split. One of SPLITS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

FederatedData

Returns:

FederatedData.

fedjax.datasets.cifar100.preprocess_image(image, is_train)[source]

Augments and preprocesses CIFAR-100 images by cropping, flipping, and normalizing.

Preprocessing procedure and values taken from pytorch-cifar.

Parameters:
  • image (ndarray) – [N, 32, 32, 3] uint8 pixels.

  • is_train (bool) – Whether we are preprocessing for training or eval.

Return type:

ndarray

Returns:

Processed [N, 32, 32, 3] float32 pixels.

EMNIST

Federated EMNIST.

fedjax.datasets.emnist.cite()[source]

Returns BibTeX citation for the dataset.

fedjax.datasets.emnist.domain_id(client_id)[source]

Returns domain id for client id.

Domain ids are based on the NIST data source, where examples were collected from two sources: Bethesda high school (HIGH_SCHOOL) and Census Bureau in Suitland (CENSUS). For more details, see the NIST documentation.

Parameters:

client_id (bytes) – Client id of the format [16-byte hex hash]:f[4-digit integer]_[2-digit integer] or f[4-digit integer]_[2-digit integer].

Return type:

int

Returns:

Domain id that is 0 (HIGH_SCHOOL) or 1 (CENSUS).

fedjax.datasets.emnist.load_data(only_digits=False, mode='sqlite', cache_dir=None)[source]

Loads processed EMNIST train and test splits.

Features:

  • x: [N, 28, 28, 1] float32 flipped image pixels.

  • y: [N] int32 classification label.

  • domain_id: [N] int32 domain id (see domain_id()).

Parameters:
  • only_digits (bool) – Whether to only load the digits data.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData]

Returns:

Train and test splits as FederatedData.
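
For example, a short sketch of loading the splits and inspecting a single client (the batch size is arbitrary):

from fedjax.datasets import emnist

train, test = emnist.load_data(only_digits=False)
client_id = next(train.client_ids())
client_dataset = train.get_client(client_id)
for batch in client_dataset.padded_batch(batch_size=16):
  # batch['x'] is [B, 28, 28, 1] float32; batch['y'] and batch['domain_id']
  # are [B] int32 (plus a mask entry marking padded examples).
  break
# The domain can also be computed directly from the client id.
print(emnist.domain_id(client_id))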

fedjax.datasets.emnist.load_split(split, only_digits=False, mode='sqlite', cache_dir=None)[source]

Loads an unprocessed federated emnist split.

Features:

  • pixels: [N, 28, 28] float32 image pixels.

  • label: [N] int32 classification label.

Parameters:
  • split (str) – Name of the split. One of SPLITS.

  • only_digits (bool) – Whether to only load the digits data.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

FederatedData

Returns:

FederatedData.

Shakespeare

Federated Shakespeare.

fedjax.datasets.shakespeare.cite()[source]

Returns BibTeX citation for the dataset.

fedjax.datasets.shakespeare.load_data(sequence_length=80, mode='sqlite', cache_dir=None)[source]

Loads preprocessed shakespeare splits.

Preprocessing is done using fedjax.FederatedData.preprocess_client() and preprocess_client().

Features (M below is possibly different from N in load_split):

  • x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)

  • y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)

Parameters:
  • sequence_length (int) – The fixed sequence length after preprocessing.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData]

Returns:

A (train, test) tuple of federated data.

fedjax.datasets.shakespeare.load_split(split, mode='sqlite', cache_dir=None)[source]

Loads a shakespeare split.

Features:

  • snippets: [N] bytes array of snippet text.

Parameters:
  • split (str) – Name of the split. One of SPLITS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

FederatedData

Returns:

FederatedData.

fedjax.datasets.shakespeare.preprocess_client(client_id, examples, sequence_length)[source]

Turns snippets into sequences of integer labels.

Features (M below is possibly different from N in load_split):

  • x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)

  • y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)

All snippets in a client dataset are first joined into a single sequence (with BOS/EOS added), and then split into pairs of sequence_length chunks for language model training. For example, with sequence_length=3, [b’ABCD’, b’E’] becomes:

Input sequences:  [[BOS, A, B], [C, D, EOS],   [BOS, E, PAD]]
Output sequences: [[A, B, C],   [D, EOS, BOS], [E, EOS, PAD]]

Note: This is not equivalent to the TensorFlow Federated text generation tutorial (The processing logic there loses ~1/sequence_length portion of the tokens).

Parameters:
  • client_id (bytes) – Not used.

  • examples (Mapping[str, ndarray]) – Unprocessed examples (e.g. from load_split()).

  • sequence_length (int) – The fixed sequence length after preprocessing.

Return type:

Mapping[str, ndarray]

Returns:

Processed examples.
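
For example, a small sketch mirroring the chunking shown above (the snippet values are made up):

import numpy as np
from fedjax.datasets import shakespeare

examples = {'snippets': np.array([b'ABCD', b'E'], dtype=np.object_)}
processed = shakespeare.preprocess_client(b'unused_id', examples, sequence_length=3)
# processed['x'] and processed['y'] are [3, 3] int32 arrays holding the input
# and output label sequences shown above (BOS/EOS/PAD included).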

Stack Overflow

Federated stackoverflow.

fedjax.datasets.stackoverflow.cite()[source]

Returns BibTeX citation for the dataset.

fedjax.datasets.stackoverflow.load_data(mode='sqlite', cache_dir=None)[source]

Loads partially preprocessed stackoverflow splits.

Features:

  • domain_id: [N] int32 domain id derived from type (question = 0; answer = 1).

  • tokens: [N] bytes array. Space separated list of tokens.

To convert tokens into padded/truncated integer labels, use a StackoverflowTokenizer. For example:

from fedjax.datasets import stackoverflow
# Load partially preprocessed splits.
train, held_out, test = stackoverflow.load_data()
# Apply tokenizer during batching.
tokenizer = stackoverflow.StackoverflowTokenizer()
train_max_length, eval_max_length = 20, 30
train_for_train = train.preprocess_batch(
    tokenizer.as_preprocess_batch(train_max_length))
train_for_eval = train.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
held_out = held_out.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
test = test.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))

Features after tokenization:

  • domain_id: Same as before.

  • x: [N, max_length] int32 array of padded/truncated input labels.

  • y: [N, max_length] int32 array of padded/truncated output labels.

Parameters:
  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData, FederatedData]

Returns:

A (train, held_out, test) tuple of federated data.

fedjax.datasets.stackoverflow.load_split(split, mode='sqlite', cache_dir=None)[source]

Loads a stackoverflow split.

All bytes arrays are stored with dtype=np.object.

Features:

  • creation_date: [N] bytes array. Textual timestamp, e.g. b’2018-02-28 19:06:18.34 UTC’.

  • title: [N] bytes array. The title of a post.

  • score: [N] int64 array. The score of a post.

  • tags: [N] bytes array. ‘|’ separated list of tags, e.g. b’mysql|join’.

  • tokens: [N] bytes array. Space separated list of tokens.

  • type: [N] bytes array. Either b’question’ or b’answer’.

Parameters:
  • split (str) – Name of the split. One of SPLITS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

FederatedData

Returns:

FederatedData.

class fedjax.datasets.stackoverflow.StackoverflowTokenizer(vocab=None, default_vocab_size=10000, num_oov_buckets=1)[source]

Tokenizer for the tokens feature in stackoverflow.

See load_data() for examples.

__init__(vocab=None, default_vocab_size=10000, num_oov_buckets=1)[source]

Initializes a tokenizer.

Parameters:
  • vocab (Optional[List[str]]) – Optional vocabulary. If specified, default_vocab_size is ignored. If None, default_vocab_size is used to load the standard vocabulary. This vocabulary should NOT have special tokens PAD, EOS, BOS, and OOV. The special tokens are added and handled automatically by the tokenizer. The preprocessed examples will have vocabulary size len(vocab) + 3 + num_oov_buckets.

  • default_vocab_size (Optional[int]) – Number of words in the default vocabulary. This is only used when vocab is not specified. The preprocessed examples will have vocabulary size default_vocab_size + 3 + num_oov_buckets with 3 special labels: 0 (PAD), 1 (BOS), 2 (EOS), and num_oov_buckets OOV labels starting at default_vocab_size + 3.

  • num_oov_buckets (int) – Number of out of vocabulary buckets.
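
For example, a small sketch with a custom vocabulary (the vocabulary and sentence are made up; the resulting vocabulary size is len(vocab) + 3 + num_oov_buckets = 8):

import numpy as np
from fedjax.datasets import stackoverflow

tokenizer = stackoverflow.StackoverflowTokenizer(
    vocab=['how', 'to', 'join', 'tables'], num_oov_buckets=1)
preprocess = tokenizer.as_preprocess_batch(6)
batch = preprocess({
    'domain_id': np.array([0], dtype=np.int32),
    'tokens': np.array([b'how to join two tables'], dtype=np.object_)})
# batch['x'] and batch['y'] are [1, 6] int32 padded/truncated label arrays;
# 'two' falls into the OOV bucket since it is not in the custom vocabulary.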

fedjax.models

Model implementations for FedJAX experimental API.

fedjax.models.emnist

EMNIST models.

fedjax.models.shakespeare

Shakespeare recurrent models.

fedjax.models.stackoverflow

Stack Overflow recurrent models.

fedjax.models.toy_regression

Toy regression models.

EMNIST

EMNIST models.

class fedjax.models.emnist.ConvDropoutModule(*args, **kwargs)[source]

Bases: Module

Custom haiku module for CNN with dropout.

This must be defined as a custom hk.Module because only a single positional argument is allowed when using hk.Sequential.

class fedjax.models.emnist.Dropout(rate=0.5)[source]

Bases: Module

Dropout haiku module.

fedjax.models.emnist.create_conv_model(only_digits=False)[source]

Creates EMNIST CNN model with dropout with haiku.

Matches the model used in:

Adaptive Federated Optimization

Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. https://arxiv.org/abs/2003.00295

Parameters:

only_digits (bool) – Whether to use only digit classes [0-9] or include lower and upper case characters for a total of 62 classes.

Return type:

Model

Returns:

Model
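
For example, a short sketch of creating the model and initializing its parameters (pairing it with fedjax.datasets.emnist is the typical use):

import jax
import fedjax

model = fedjax.models.emnist.create_conv_model(only_digits=False)
params = model.init(jax.random.PRNGKey(0))
# params is a parameter pytree usable with any fedjax.algorithms
# FederatedAlgorithm, e.g. algorithm.init(params).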

fedjax.models.emnist.create_dense_model(only_digits=False, hidden_units=200)[source]

Creates EMNIST dense net with haiku.

Return type:

Model

fedjax.models.emnist.create_logistic_model(only_digits=False)[source]

Creates EMNIST logistic model with haiku.

Return type:

Model

fedjax.models.emnist.create_stax_dense_model(only_digits=False, hidden_units=200)[source]

Creates EMNIST dense net with stax.

Return type:

Model

Shakespeare

Shakespeare recurrent models.

fedjax.models.shakespeare.create_lstm_model(vocab_size=86, embed_size=8, lstm_hidden_size=256, lstm_num_layers=2)[source]

Creates LSTM language model.

Character-level LSTM for Shakespeare language model. Defaults to the model used in:

Communication-Efficient Learning of Deep Networks from Decentralized Data

H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017. https://arxiv.org/abs/1602.05629

Parameters:
  • vocab_size (int) – The number of possible output characters. This does not include special tokens like PAD, BOS, EOS, or OOV.

  • embed_size (int) – Embedding size for each character.

  • lstm_hidden_size (int) – Hidden size for LSTM cells.

  • lstm_num_layers (int) – Number of LSTM layers.

Return type:

Model

Returns:

Model.

Stack Overflow

Stack Overflow recurrent models.

fedjax.models.stackoverflow.create_lstm_model(vocab_size=10000, embed_size=96, lstm_hidden_size=670, lstm_num_layers=1, share_input_output_embeddings=False, expected_length=None)[source]

Creates LSTM language model.

Word-level language model for Stack Overflow. Defaults to the model used in:

Adaptive Federated Optimization

Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. https://arxiv.org/abs/2003.00295

Parameters:
  • vocab_size (int) – The number of possible output words. This does not include special tokens like PAD, BOS, EOS, or OOV.

  • embed_size (int) – Embedding size for each word.

  • lstm_hidden_size (int) – Hidden size for LSTM cells.

  • lstm_num_layers (int) – Number of LSTM layers.

  • share_input_output_embeddings (bool) – Whether to share the input embeddings with the output logits.

  • expected_length (Optional[float]) – Expected average sentence length used to scale the training loss down by 1. / expected_length. This constant term is used so that the total loss over all the words in a sentence can be scaled down to per word cross entropy values by a constant factor instead of dividing by number of words which can vary across batches. Defaults to no scaling.

Return type:

Model

Returns:

Model.

Toy Regression

Toy regression models.

fedjax.models.toy_regression.create_regression_model()[source]

Creates toy regression model.

Matches the model used in:

Communication-Efficient Agnostic Federated Averaging

Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh https://arxiv.org/abs/2104.02748

Return type:

Model

Returns:

Model

fedjax.training

FedJAX training utilities.

Federated experiment

fedjax.training.run_federated_experiment(algorithm, init_state, client_sampler, config, periodic_eval_fn_map=None, final_eval_fn_map=None)[source]

Runs the training loop of a federated algorithm experiment.

Parameters:
  • algorithm (FederatedAlgorithm) – Federated algorithm to use.

  • init_state (Any) – Initial server state.

  • client_sampler (ClientSampler) – Sampler for training clients.

  • config (FederatedExperimentConfig) – FederatedExperimentConfig configurations.

  • periodic_eval_fn_map (Optional[Mapping[str, Any]]) – Mapping of name to evaluation functions that are run repeatedly over multiple federated training rounds. The frequency is defined in FederatedExperimentConfig.eval_frequency.

  • final_eval_fn_map (Optional[Mapping[str, EvaluationFn]]) – Mapping of name to evaluation functions that are run at the very end of federated training. Typically, full test evaluation functions will be set here.

Return type:

Any

Returns:

Final state of the input federated algorithm after training.

class fedjax.training.FederatedExperimentConfig(root_dir: str, num_rounds: int, checkpoint_frequency: int = 0, num_checkpoints_to_keep: int = 1, eval_frequency: int = 0)[source]

Common configurations of a federated experiment.

Attributes:

  • root_dir: Root directory for experiment outputs (e.g. metrics).

  • num_rounds: Number of federated training rounds.

  • checkpoint_frequency: Checkpoint frequency in rounds. If <= 0, no checkpointing is done.

  • num_checkpoints_to_keep: Maximum number of checkpoints to keep.

  • eval_frequency: Evaluation frequency in rounds. If <= 0, no evaluation is done.
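
For example, a minimal end-to-end sketch (the task, round counts, directory, and hyperparameters are illustrative; the evaluation function is documented below):

import jax
import fedjax

train_fd, test_fd, model = fedjax.training.get_task('EMNIST_CONV')
algorithm = fedjax.algorithms.fed_avg.federated_averaging(
    grad_fn=fedjax.model_grad(model),
    client_optimizer=fedjax.optimizers.sgd(learning_rate=0.1),
    server_optimizer=fedjax.optimizers.sgd(learning_rate=1.0),
    client_batch_hparams=fedjax.ShuffleRepeatBatchHParams(batch_size=20))
config = fedjax.training.FederatedExperimentConfig(
    root_dir='/tmp/emnist_fed_avg',
    num_rounds=100,
    checkpoint_frequency=25,
    eval_frequency=25)
# Periodic evaluation on a sample of test clients.
periodic_eval_fn_map = {
    'test_sample': fedjax.training.ModelSampleClientsEvaluationFn(
        fedjax.client_samplers.UniformGetClientSampler(
            test_fd, num_clients=10, seed=1),
        model,
        fedjax.PaddedBatchHParams(batch_size=64)),
}
final_state = fedjax.training.run_federated_experiment(
    algorithm=algorithm,
    init_state=algorithm.init(model.init(jax.random.PRNGKey(0))),
    client_sampler=fedjax.client_samplers.UniformGetClientSampler(
        train_fd, num_clients=10, seed=0),
    config=config,
    periodic_eval_fn_map=periodic_eval_fn_map)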

class fedjax.training.EvaluationFn[source]

Evaluation function that is only fed state at every call.

Typically used for full evaluation or evaluation on sampled clients from a test set.

abstract __call__(state, round_num)[source]

Runs final evaluation.

Return type:

Mapping[str, Array]

class fedjax.training.ModelFullEvaluationFn(fd, model, batch_hparams)[source]

Bases: EvaluationFn

Evaluation on an entire federated dataset using the centralized model.

__call__(state, round_num)[source]

Runs final evaluation.

Return type:

Mapping[str, Array]

__init__(fd, model, batch_hparams)[source]
class fedjax.training.ModelSampleClientsEvaluationFn(client_sampler, model, batch_hparams)[source]

Bases: EvaluationFn

Evaluation on sampled clients using the centralized model.

The state to be evaluated must contain a params field.

__call__(state, round_num)[source]

Runs final evaluation.

Return type:

Mapping[str, Array]

__init__(client_sampler, model, batch_hparams)[source]
class fedjax.training.TrainClientsEvaluationFn[source]

Evaluation function that is fed training clients at every call.

Typically used for evaluation on the training clients used in a step.

abstract __call__(state, round_num, train_clients)[source]

Runs evaluation.

Return type:

Mapping[str, Array]

class fedjax.training.ModelTrainClientsEvaluationFn(model, batch_hparams)[source]

Bases: TrainClientsEvaluationFn

Evaluation on training clients using the centralized model.

The state to be evaluated must contain a params field.

__call__(state, round_num, train_clients)[source]

Runs evaluation.

Return type:

Mapping[str, Array]

__init__(model, batch_hparams)[source]
fedjax.training.set_tf_cpu_only()[source]

Restricts TensorFlow device visibility to only CPU.

TensorFlow is only used for data loading, so we prevent it from allocating GPU/TPU memory.

fedjax.training.load_latest_checkpoint(root_dir)[source]

Loads latest checkpoint and round number.

Return type:

Optional[Tuple[Any, int]]

fedjax.training.save_checkpoint(root_dir, state, round_num=0, keep=1)[source]

Saves checkpoint and cleans up old checkpoints.
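
For example, a short sketch pairing the two (assuming state is a server state produced by a federated algorithm and the directory is writable):

import fedjax

fedjax.training.save_checkpoint('/tmp/emnist_fed_avg', state, round_num=10, keep=2)
latest = fedjax.training.load_latest_checkpoint('/tmp/emnist_fed_avg')
if latest is not None:
  state, round_num = latest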

class fedjax.training.Logger(root_dir=None)[source]

Class to encapsulate tf.summary.SummaryWriter logging logic.

__init__(root_dir=None)[source]

Initializes summary writers and log directory.

log(writer_name, metric_name, metric_value, round_num)[source]

Records metric using specified summary writer.

Logs at INFO verbosity. Also, if root_dir is set and metric_value is:

  • a scalar value, convertible to a float32 Tensor, writes scalar summary

  • a vector, convertible to a float32 Tensor, writes histogram summary

Parameters:
  • writer_name (str) – Name of summary writer.

  • metric_name (str) – Name of metric to log.

  • metric_value (Any) – Value of metric to log.

  • round_num (int) – Round number to log.
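
For example (the directory and metric values are illustrative):

import fedjax

logger = fedjax.training.Logger(root_dir='/tmp/emnist_fed_avg')
# Writes to the 'train' summary writer and logs at INFO verbosity.
logger.log('train', 'loss', 0.25, round_num=1)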

Tasks

Registry of standard tasks.

Each task is represented as a (train federated data, test federated data, model) tuple.

fedjax.training.ALL_TASKS = ('EMNIST_CONV', 'EMNIST_LOGISTIC', 'EMNIST_DENSE', 'SHAKESPEARE_CHARACTER', 'STACKOVERFLOW_WORD', 'CIFAR100_LOGISTIC')
fedjax.training.get_task(name, mode='sqlite', cache_dir=None)[source]

Gets a standard task.

Parameters:
  • name (str) – Name of the task to get. Must be one of fedjax.training.ALL_TASKS.

  • mode (str) – ‘sqlite’.

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData, Model]

Returns:

(train federated data, test federated data, model) tuple.
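
For example (the task name must be one of ALL_TASKS above):

import fedjax

# Each registered task resolves to a (train, test, model) triple.
train_fd, test_fd, model = fedjax.training.get_task('SHAKESPEARE_CHARACTER')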

Structured flags

Structured flags commonly used in experiment binaries.

Structured flags are often used to construct complex structures via multiple simple flags (e.g. an optimizer can be created by controlling learning rate and other hyper parameters).

class fedjax.training.structured_flags.BatchHParamsFlags(name=None, default_batch_size=128)[source]

Constructs BatchHParams from flags.

class fedjax.training.structured_flags.FederatedExperimentConfigFlags(name=None)[source]

Constructs FederatedExperimentConfig from flags.

class fedjax.training.structured_flags.NamedFlags(name)[source]

A group of flags with an optional named prefix.

class fedjax.training.structured_flags.OptimizerFlags(name=None, default_optimizer='sgd')[source]

Constructs a fedjax.Optimizer from flags.

get()[source]

Gets the specified optimizer.

Return type:

Optimizer

class fedjax.training.structured_flags.PaddedBatchHParamsFlags(name=None, default_batch_size=128)[source]

Constructs PaddedBatchHParams from flags.

class fedjax.training.structured_flags.ShuffleRepeatBatchHParamsFlags(name=None, default_batch_size=128)[source]

Constructs ShuffleRepeatBatchHParams from flags.

class fedjax.training.structured_flags.TaskFlags(name=None)[source]

Constructs a standard task tuple from flags.
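
For example, a rough sketch of the intended pattern in an experiment binary (the prefix name is illustrative; the exact flags each class registers are defined by the class itself):

from absl import app
import fedjax

# Constructing the structured flag registers its underlying absl flags,
# prefixed with the given name.
server_optimizer_flags = fedjax.training.structured_flags.OptimizerFlags('server')

def main(_):
  # After absl parses the command line, get() builds the fedjax.Optimizer.
  server_optimizer = server_optimizer_flags.get()
  del server_optimizer  # In practice, passed into a federated algorithm.

if __name__ == '__main__':
  app.run(main)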