fedjax core

FedJAX API.

Subpackages

Federated algorithm

fedjax.FederatedAlgorithm

Container for all federated algorithms.

Federated data

fedjax.FederatedData

FederatedData interface for providing access to a federated dataset.

fedjax.SubsetFederatedData

A simple wrapper over a concrete FederatedData for restricting to a subset of client ids.

fedjax.SQLiteFederatedData

Federated dataset backed by SQLite.

fedjax.InMemoryFederatedData

A simple wrapper over a concrete fedjax.FederatedData for small in memory datasets.

fedjax.FederatedDataBuilder

FederatedDataBuilder interface.

fedjax.SQLiteFederatedDataBuilder

Builds SQLite files from a python dictionary containing an arbitrary mapping of client IDs to NumPy examples.

fedjax.ClientPreprocessor

A chain of preprocessing functions on all examples of a client dataset.

fedjax.shuffle_repeat_batch_federated_data

Shuffle-repeat-batch all client datasets in a federated dataset for training a centralized baseline.

fedjax.padded_batch_federated_data

Padded batch all client datasets, useful for evaluation on the entire federated dataset.

Client dataset

fedjax.ClientDataset

In memory client dataset backed by numpy ndarrays.

fedjax.BatchPreprocessor

A chain of preprocessing functions on batched examples.

fedjax.buffered_shuffle_batch_client_datasets

Shuffles and batches examples from multiple client datasets.

fedjax.padded_batch_client_datasets

Batches examples from multiple client datasets.

For each client

fedjax.for_each_client

Creates a function which maps over clients.

fedjax.for_each_client_backend

A context manager for switching to a given ForEachClientBackend in the current thread.

fedjax.set_for_each_client_backend

Sets the for_each_client backend for the current thread.

Model

fedjax.Model

Container class for models.

fedjax.create_model_from_haiku

Creates Model after applying defaults and haiku specific preprocessing.

fedjax.create_model_from_stax

Creates Model after applying defaults and stax specific preprocessing.

fedjax.evaluate_model

Evaluates model for multiple batches and returns final results.

fedjax.model_grad

A standard gradient function derived from a model and an optional regularizer.

fedjax.model_per_example_loss

Convenience function for constructing a per-example loss function from a model.

fedjax.evaluate_average_loss

Evaluates the average per example loss over multiple batches.

fedjax.ModelEvaluator

Evaluates model for each client dataset, either using global params, or per client params.

fedjax.AverageLossEvaluator

Evaluates average loss for each client dataset, either using global params, or per client params.

fedjax.grad

A standard gradient function derived from per-example loss and an optional regularizer.


class fedjax.FederatedAlgorithm(init, apply)[source]

Container for all federated algorithms.

FederatedAlgorithm defines the required methods that need to be implemented by all custom federated algorithms. Defining federated algorithms in this structure will allow implementations to work seamlessly with the convenience methods in the fedjax.training API, like checkpointing.

Example toy implementation:

# Federated algorithm that just counts total number of points across clients
# across rounds.

def count_federated_algorithm():

  def init(init_count):
    return {'count': init_count}

  def apply(state, clients):
    count = 0
    client_diagnostics = {}
    # Count sizes across clients.
    for client_id, client_dataset, _ in clients:
      # Summation across clients in one round.
      client_count = len(client_dataset)
      count += client_count
      client_diagnostics[client_id] = client_count
    # Summation across rounds.
    state = {'count': state['count'] + count}
    return state, client_diagnostics

  return FederatedAlgorithm(init, apply)

rng = jax.random.PRNGKey(0)  # Unused.
all_clients = [
    [
      (b'cid0', ClientDataset({'x': jnp.array([1, 2, 1, 2, 3, 4])}), rng),
      (b'cid1', ClientDataset({'x': jnp.array([1, 2, 3, 4, 5])}), rng),
      (b'cid2', ClientDataset({'x': jnp.array([1, 1, 2])}), rng),
    ],
    [
      (b'cid3', ClientDataset({'x': jnp.array([1, 2, 3, 4])}), rng),
      (b'cid4', ClientDataset({'x': jnp.array([1, 1, 2, 1, 2, 3])}), rng),
      (b'cid5', ClientDataset({'x': jnp.array([1, 2, 3, 4, 5, 6, 7])}), rng),
    ],
]
algorithm = count_federated_algorithm()
state = algorithm.init(0)
for round_num in range(2):
  state, client_diagnostics = algorithm.apply(state, all_clients[round_num])
  print(round_num, state)
  print(round_num, client_diagnostics)
# 0 {'count': 14}
# 0 {b'cid0': 6, b'cid1': 5, b'cid2': 3}
# 1 {'count': 31}
# 1 {b'cid3': 4, b'cid4': 6, b'cid5': 7}
init

Initializes the ServerState. Typically, the input to this method will be the initial model Params. This should only be run once at the beginning of training.

Type:

Callable[[…], Any]

apply

Completes one round of federated training given an input ServerState and a sequence of tuples of client identifier, client dataset, and client rng. The output will be a new, updated ServerState and accumulated per step results keyed by client identifier (e.g. train metrics).

Type:

Callable[[Any, Sequence[Tuple[bytes, fedjax.core.client_datasets.ClientDataset, jax.Array]]], Tuple[Any, Mapping[bytes, Any]]]


class fedjax.FederatedData[source]

FederatedData interface for providing access to a federated dataset.

A FederatedData object serves as a mapping from client ids to client datasets and client metadata.

Access methods with better I/O efficiency

For large federated datasets, it is not feasible to load all client datasets into memory at once (whereas loading a single client dataset is assumed to be feasible). Different implementations exist for different on-disk storage formats. Since sequential read is much faster than random read for most storage technologies, FederatedData provides the following access methods, listed from most to least I/O efficient (a minimal usage sketch follows the list),

  1. clients() and shuffled_clients() are sequential read friendly, and thus recommended whenever appropriate.

  2. get_clients() requires random read, but prefetching is possible. This should be preferred over get_client().

  3. get_client() is usually the slowest way of accessing client datasets, and is mostly intended for interactive exploration of a small number of clients.
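A minimal usage sketch of these access patterns, assuming fd is any concrete FederatedData (e.g. an SQLiteFederatedData) and that the client ids shown exist in it:

import itertools

# Sequential-read friendly: iterate over all clients in a deterministic order.
for client_id, client_dataset in fd.clients():
  print(client_id, len(client_dataset))

# Repeated buffered shuffling; the stream is infinite, so cap it explicitly.
for client_id, client_dataset in itertools.islice(
    fd.shuffled_clients(buffer_size=100, seed=0), 10):
  ...

# Random read with prefetching; clients are returned in the order requested.
for client_id, client_dataset in fd.get_clients([b'cid0', b'cid1']):
  ...

# Slowest; intended only for interactive exploration of a few clients.
client_dataset = fd.get_client(b'cid0')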

Preprocessing

ClientDataset produced by FederatedData can hold a BatchPreprocessor, customizable via preprocess_batch(). Additionally, another “client” level ClientPreprocessor, customizable via preprocess_client(), can be used to apply transformations on examples from the entire client dataset before a ClientDataset is constructed.
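A minimal sketch of registering preprocessing, assuming fd is a concrete FederatedData with a hypothetical 'x' feature; both methods return a new FederatedData and leave the original unchanged:

import numpy as np

# Client level: runs on all of a client's examples before the ClientDataset
# is constructed. The function also receives the client id.
fd = fd.preprocess_client(
    lambda client_id, examples: {**examples,
                                 'x': examples['x'].astype(np.float32)})

# Batch level: runs on every batch produced from a ClientDataset.
fd = fd.preprocess_batch(lambda batch: {**batch, 'x': batch['x'] / 255.0})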

abstract client_ids()[source]

Returns an iterator of client ids as bytes.

There is no requirement on the order of iteration.

Return type:

Iterator[bytes]

abstract client_size(client_id)[source]

Returns the number of examples in a client dataset.

Return type:

int

abstract client_sizes()[source]

Returns an iterator of all (client id, client size) pairs.

This is often more efficient than making multiple client_size() calls. There is no requirement on the order of iteration.

Return type:

Iterator[Tuple[bytes, int]]

abstract clients()[source]

Iterates over clients in a deterministic order.

Implementations can choose whatever order makes iteration efficient.

Return type:

Iterator[Tuple[bytes, ClientDataset]]

abstract get_client(client_id)[source]

Gets one single client dataset.

Prefer clients(), shuffled_clients(), or get_clients() when possible.

Parameters:

client_id (bytes) – Client id to load.

Return type:

ClientDataset

Returns:

The corresponding ClientDataset.

abstract get_clients(client_ids)[source]

Gets multiple clients in order with one call.

Clients are returned in the order of client_ids.

Parameters:

client_ids (Iterable[bytes]) – Client ids to load.

Return type:

Iterator[Tuple[bytes, ClientDataset]]

Returns:

Iterator.

abstract num_clients()[source]

Returns the number of clients.

If it is too expensive or otherwise impossible to obtain the result, an implementation may raise an exception.

Return type:

int

abstract preprocess_batch(fn)[source]

Registers a preprocessing function to be called after batching in ClientDatasets.

Return type:

FederatedData

abstract preprocess_client(fn)[source]

Registers a preprocessing function to be called on all examples of a client before passing them to construct a ClientDataset.

Return type:

FederatedData

abstract shuffled_clients(buffer_size, seed=None)[source]

Iterates over clients with a repeated buffered shuffling.

Shuffling should use a buffer size of at least buffer_size clients. The iteration should repeat forever, usually with a different order in each pass.

Parameters:
  • buffer_size (int) – Buffer size for shuffling.

  • seed (Optional[int]) – Optional random number generator seed.

Return type:

Iterator[Tuple[bytes, ClientDataset]]

Returns:

Iterator.

abstract slice(start=None, stop=None)[source]

Returns a new FederatedData restricted to client ids in the given range.

The returned FederatedData includes clients whose ids are,

  • Greater than or equal to start when start is not None;

  • Less than stop when stop is not None.

Parameters:
  • start (Optional[bytes]) – Start of client id range.

  • stop (Optional[bytes]) – Stop of client id range.

Return type:

FederatedData

Returns:

FederatedData.

class fedjax.SubsetFederatedData(base, client_ids, validate=True)[source]

Bases: FederatedData

A simple wrapper over a concrete FederatedData for restricting to a subset of client ids.

This is useful when we wish to create a smaller FederatedData out of arbitrary client ids, where slicing is not possible.

__init__(base, client_ids, validate=True)[source]

Initializes the subset federated dataset.

Parameters:
  • base (FederatedData) – Base concrete FederatedData.

  • client_ids (Iterable[bytes]) – Client ids to include in the subset. All client ids must be in base.client_ids(), otherwise the behavior of SubsetFederatedData is undefined when validate=False.

  • validate – Whether to validate client ids.
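A minimal sketch, assuming fd is a concrete FederatedData that contains the (hypothetical) client ids below:

small_fd = fedjax.SubsetFederatedData(fd, [b'cid0', b'cid1'])
print(small_fd.num_clients())
# 2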

class fedjax.SQLiteFederatedData(connection, parse_examples, start=None, stop=None, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]

Bases: FederatedData

Federated dataset backed by SQLite.

The SQLite database should contain a table named “federated_data” created with the following command:

CREATE TABLE federated_data (
  client_id BLOB NOT NULL PRIMARY KEY,
  data BLOB NOT NULL,
  num_examples INTEGER NOT NULL
);

where,

  • client_id is the bytes client id.

  • data is the serialized client dataset examples.

  • num_examples is the number of examples in the client dataset.

By default we use zlib compressed msgpack blobs for data (see decompress_and_deserialize()).

__init__(connection, parse_examples, start=None, stop=None, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]
static new(path, parse_examples=<function decompress_and_deserialize>)[source]

Opens a federated dataset stored as an SQLite3 database.

Parameters:
  • path (str) – Path to the SQLite database file.

  • parse_examples (Callable[[bytes], Mapping[str, ndarray]]) – Function for deserializing client dataset examples.

Return type:

SQLiteFederatedData

Returns:

An SQLiteFederatedData.
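A minimal sketch of opening and iterating an SQLite-backed federated dataset; the path is hypothetical:

fd = fedjax.SQLiteFederatedData.new('/tmp/sqlite_federated_data.sqlite')
for client_id, client_dataset in fd.clients():
  print(client_id, len(client_dataset))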

class fedjax.InMemoryFederatedData(client_to_data_mapping, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]

Bases: FederatedData

A simple wrapper over a concrete fedjax.FederatedData for small in memory datasets.

This is useful when we wish to create a smaller FederatedData that fits in memory. Here is a simple example to create a fedjax.InMemoryFederatedData,

client_a_data = {
    'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    'y': np.array([7, 8])
}
client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
client_to_data_mapping = {'a': client_a_data, 'b': client_b_data}

fedjax.InMemoryFederatedData(client_to_data_mapping)
Returns:

A fedjax.InMemoryFederatedData corresponding to client_to_data_mapping.

__init__(client_to_data_mapping, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))[source]

Initializes the in memory federated dataset.

Data of each client is a mapping from feature names to numpy arrays. For example, for EMNIST image classification, {'x': X, 'y': y}, where X is an array of shape (num_data_points, 28, 28) and y is a vector of shape (num_data_points,).

Parameters:
  • client_to_data_mapping (Mapping[bytes, Mapping[str, ndarray]]) – A mapping from client_id to data of each client.

  • preprocess_client (ClientPreprocessor) – federated_data.ClientPreprocessor to preprocess each client data.

  • preprocess_batch (BatchPreprocessor) – client_datasets.BatchPreprocessor to preprocess batch of data.

class fedjax.FederatedDataBuilder[source]

FederatedDataBuilder interface.

To be implemented as a context manager for building file formats from pairs of client IDs and client NumPy examples.

It is relevant to note that the add method below does not specify any raised exceptions. One could imagine some formats where add can fail in some way: out-of-order or duplicate inputs, remote files and network failures, individual entries too big for a format, etc. In order to address this we let implementations throw whatever they see relevant and fit to their particular use cases. The same is relevant when it comes to the __init__, __enter__, and __exit__ methods, where implementations are left with the responsibility of raising exceptions as they see fit to their particular use cases. For example, if an invalid file path is passed, or there were any issues finalizing the builder, etc.

Example of end behavior:

with FederatedDataBuilder(path) as builder:
  builder.add(b'k1', np.array([b'v1'], dtype=object))
  builder.add(b'k2', np.array([b'v2'], dtype=object))
abstract add_many(client_ids_examples)[source]

Bulk adds multiple client IDs and client NumPy examples pairs to file format.

Parameters:

client_ids_examples (Iterable[Tuple[bytes, Mapping[str, ndarray]]]) – Iterable of tuples of client id and examples.

class fedjax.SQLiteFederatedDataBuilder(path)[source]

Bases: FederatedDataBuilder

Builds SQLite files from a python dictionary containing an arbitrary mapping of client IDs to NumPy examples.

__init__(path)[source]

Initializes SQLiteBuilder by opening a connection and setting up the database with columns.

Parameters:

path (str) – Path of file to write to (e.g. /tmp/sqlite_federated_data.sqlite).
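A minimal sketch of writing a small in-memory mapping to an SQLite file; the path and feature names are hypothetical:

import numpy as np

client_to_data_mapping = {
    b'cid0': {'x': np.array([[1.0, 2.0], [3.0, 4.0]]), 'y': np.array([0, 1])},
    b'cid1': {'x': np.array([[5.0, 6.0]]), 'y': np.array([1])},
}
with fedjax.SQLiteFederatedDataBuilder('/tmp/sqlite_federated_data.sqlite') as builder:
  builder.add_many(client_to_data_mapping.items())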

class fedjax.ClientPreprocessor(fns=())[source]

A chain of preprocessing functions on all examples of a client dataset.

This is very similar to fedjax.BatchPreprocessor, with the main difference being that ClientPreprocessor also takes client_id as input.

See the discussion in fedjax.BatchPreprocessor regarding when to use which.

__call__(client_id, examples)[source]

Call self as a function.

Return type:

Mapping[str, ndarray]

__init__(fns=())[source]
append(fn)[source]

Creates a new ClientPreprocessor with fn added to the end.

Return type:

ClientPreprocessor
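A minimal sketch of a client level preprocessing chain; the 'x' feature and the client-id-dependent step are assumptions for illustration:

import numpy as np

preprocessor = fedjax.ClientPreprocessor([
    # Cast a feature for every client.
    lambda client_id, examples: {**examples,
                                 'x': examples['x'].astype(np.float32)},
])
# Append a step that uses the client id.
preprocessor = preprocessor.append(
    lambda client_id, examples: {
        **examples,
        'is_cid0': np.full(len(examples['x']), client_id == b'cid0')
    })
examples = preprocessor(b'cid0', {'x': np.array([1, 2, 3])})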

fedjax.shuffle_repeat_batch_federated_data(fd, batch_size, client_buffer_size, example_buffer_size, seed=None)[source]

Shuffle-repeat-batch all client datasets in a federated dataset for training a centralized baseline.

Shuffling is done using two levels of buffered shuffling, first at the client level, then at the example level.

This produces an infinite stream of batches. itertools.islice() can be used to cap the number of batches, if so desired.

Parameters:
  • fd (FederatedData) – Federated dataset.

  • batch_size (int) – Desired batch size.

  • client_buffer_size (int) – Buffer size for client level shuffling.

  • example_buffer_size (int) – Buffer size for example level shuffling.

  • seed (Optional[int]) – Optional RNG seed.

Yields:

Batches of preprocessed examples.

Return type:

Iterator[Mapping[str, ndarray]]
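A minimal sketch of training a centralized baseline, assuming fd is a concrete FederatedData:

import itertools

batches = fedjax.shuffle_repeat_batch_federated_data(
    fd,
    batch_size=32,
    client_buffer_size=100,
    example_buffer_size=1000,
    seed=0)
# The stream is infinite; cap it at a fixed number of training steps.
for batch in itertools.islice(batches, 1000):
  ...  # one centralized training step on `batch`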

fedjax.padded_batch_federated_data(fd, hparams=None, **kwargs)[source]

Padded batch all client datasets, useful for evaluation on the entire federated dataset.

Parameters:
  • fd (FederatedData) – Federated dataset.

  • hparams (Optional[PaddedBatchHParams]) – Batching hyperparameters, like those in ClientDataset.padded_batch().

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Yields:

Batches of preprocessed examples.

Return type:

Iterator[Mapping[str, ndarray]]
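A minimal sketch of evaluating over the entire federated dataset, assuming fd is a concrete FederatedData and model/params follow the fedjax.Model conventions documented below:

eval_batches = fedjax.padded_batch_federated_data(fd, batch_size=128)
print(fedjax.evaluate_model(model, params, eval_batches))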

class fedjax.RepeatableIterator(base)[source]

Repeats a base iterable after the end of the first pass of iteration.

Because this is a stateful object, it is not thread safe, and all usual caveats about accessing the same iterator at different locations apply. For example, if we make two map calls to the same RepeatableIterator, we must make sure we do not interleave next() calls on these. The following is safe because we finish iterating on m1 before starting to iterate on m2,

it = RepeatableIterator(range(4))
m1 = map(lambda x: x + 1, it)
m2 = map(lambda x: x * x, it)
# We finish iterating on m1 before starting to iterate on m2.
print(list(m1), list(m2))
# [1, 2, 3, 4] [0, 1, 4, 9]

Whereas interleaved access leads to confusing results,

it = RepeatableIterator(range(4))
m1 = map(lambda x: x + 1, it)
m2 = map(lambda x: x * x, it)
print(next(m1), next(m2))
# 1 1
print(next(m1), next(m2))
# 3 9
print(next(m1), next(m2))
# StopIteration!

In the first pass of iteration, values fetched from the base iterator will be copied onto an internal buffer (except for a few builtin containers where copying is unnecessary). When each pass of iteration finishes (i.e. when __next__() raises StopIteration), the iterator resets itself to the start of the buffer, thus allowing a subsequent pass of repeated iteration.

In most cases, if repeated iterations are required, it is sufficient to simply copy values from an iterator into a list. However, sometimes an iterator produces values via potentially expensive I/O operations (e.g. loading client datasets); in that case, RepeatableIterator can interleave I/O and JAX compute to decrease accelerator idle time.


Preprocessing and batching operations over client datasets.

Column based representation

The examples in a client dataset can be viewed as a table, where the rows are the individual examples, and the columns are the features (labels are viewed as a feature in this context).

We use a column based representation when loading a dataset into memory.

  • Each column is a NumPy array x of rank at least 1, where x[i, ...] is the value of this feature for the i-th example.

  • The complete set of examples is a dict-like object, from str feature names, to the corresponding column values.

Traditionally, a row based representation is used for representing the entire dataset, and a column based representation is used for a single batch. In the context of federated learning, an individual client dataset is small enough to easily fit into memory so the same representation is used for the entire dataset and a batch.

Preprocessor

Preprocessing on a batch of examples can be easily done via a chain of functions. A Preprocessor object holds the chain of functions, and applies the transformation on a batch of examples.

ClientDataset: examples + preprocessor

A ClientDataset is simply some examples in the column based representation, accompanied by a Preprocessor. Its batch() method produces batches of examples in a sequential order, suitable for evaluation. Its shuffle_repeat_batch() method adds shuffling and repeating, making it suitable for training.

class fedjax.ClientDataset(raw_examples, preprocessor=BatchPreprocessor(()))[source]

In memory client dataset backed by numpy ndarrays.

Custom preprocessing on batches can be added via a preprocessor. A ClientDataset is stored as the unpreprocessed raw_examples, along with its preprocessor.

This is only intended for efficient access to small datasets that fit in memory.
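A minimal construction sketch; the features and the preprocessing step are assumptions for illustration:

import numpy as np

dataset = fedjax.ClientDataset(
    {'x': np.arange(5), 'y': np.arange(5) % 2},
    fedjax.BatchPreprocessor([
        lambda batch: {**batch, 'x': batch['x'].astype(np.float32)}
    ]))
print(len(dataset))
# 5
for batch in dataset.padded_batch(batch_size=2):
  ...  # each batch also carries a bool feature keyed by EXAMPLE_MASK_KEY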

__getitem__(index)[source]

Returns a new ClientDataset with sliced raw examples.

Return type:

ClientDataset

__init__(raw_examples, preprocessor=BatchPreprocessor(()))[source]
__len__()[source]

Returns the number of raw examples in this dataset.

Return type:

int

all_examples()[source]

Returns the result of feeding all raw examples through the preprocessor.

This is mostly intended for interactive exploration of a small subset of a client dataset. For example, to see the first 4 examples in a client dataset,

dataset = ClientDataset(my_raw_examples, my_preprocessor)
dataset[:4].all_examples()
Return type:

Mapping[str, ndarray]

Returns:

Preprocessed examples from all the raw examples in this client dataset.

batch(hparams=None, **kwargs)[source]

Produces preprocessed batches in a fixed sequential order.

The final batch may contain fewer than batch_size examples. If used directly, that may result in a large number of JIT recompilations. Therefore we recommend using padded_batch() instead in most scenarios.

This function can be invoked in 2 ways:

  1. Using a hyperparams object. This is the recommended way in library code if you have to use batch (prefer padded_batch() if possible). Example:

    def a_library_function(client_dataset, hparams):
      for batch in client_dataset.batch(hparams):
        ...
    
  2. Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,

    client_dataset.batch(batch_size=2)
    # Overrides the default drop_remainder value.
    client_dataset.batch(hparams, drop_remainder=True)
    
Parameters:
  • hparams (Optional[BatchHParams]) – Batching hyperparameters.

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Return type:

Iterable[Mapping[str, ndarray]]

Returns:

An iterable object that can be iterated over multiple times.

padded_batch(hparams=None, **kwargs)[source]

Produces preprocessed padded batches in a fixed sequential order.

This function can be invoked in 2 ways:

  1. Using a hyperparams object. This is the recommended way in library code. Example:

    def a_library_function(client_dataset, hparams):
      for batch in client_dataset.padded_batch(hparams):
        ...
    
  2. Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,

    client_dataset.padded_batch(batch_size=2)
    # Overrides the default num_batch_size_buckets value.
    client_dataset.padded_batch(hparams, num_batch_size_buckets=2)
    

When the number of examples in the dataset is not a multiple of batch_size, the final batch may be smaller than batch_size. This may lead to a large number of JIT recompilations. This can be circumvented by padding the final batch to a small number of fixed sizes controlled by num_batch_size_buckets.

All batches contain an extra bool feature keyed by EXAMPLE_MASK_KEY. batch[EXAMPLE_MASK_KEY][i] tells us whether the i-th example in this batch is an actual example (batch[EXAMPLE_MASK_KEY][i] == True), or a padding example (batch[EXAMPLE_MASK_KEY][i] == False).

We repeatedly halve the batch size up to num_batch_size_buckets-1 times, until we find the smallest one that is also >= the size of the final batch. Therefore if batch_size < 2^num_batch_size_buckets, fewer bucket sizes will be actually used.
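For example, with batch_size=16 and num_batch_size_buckets=3, the candidate padded sizes are 16, 8, and 4; a final batch of 5 examples would be padded to size 8, with the 3 padding examples marked False under EXAMPLE_MASK_KEY.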

Parameters:
  • hparams (Optional[PaddedBatchHParams]) – Batching hyperparameters.

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Return type:

Iterable[Mapping[str, ndarray]]

Returns:

An iterable object that can be iterated over multiple times.

shuffle_repeat_batch(hparams=None, **kwargs)[source]

Produces preprocessed batches in a shuffled and repeated order.

This function can be invoked in 2 ways:

  1. Using a hyperparams object. This is the recommended way in library code. Example:

    def a_library_function(client_dataset, hparams):
      for batch in client_dataset.shuffle_repeat_batch(hparams):
        ...
    
  2. Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,

    client_dataset.shuffle_repeat_batch(batch_size=2)
    # Overrides the default num_epochs value.
    client_dataset.shuffle_repeat_batch(hparams, num_epochs=2)
    

Shuffling is done without replacement, therefore for a dataset of N examples, the first ceil(N/batch_size) batches are guaranteed to cover the entire dataset.

By default the iteration stops after the first epoch. The number of batches produced from the iteration can be controlled by the (num_epochs, num_steps, drop_remainder) combination:

  • If both num_epochs and num_steps are None, the shuffle-repeat process continues forever.

  • If num_epochs is set and num_steps is None, as few batches as needed to go over the dataset this many passes are produced. Further,

    • If drop_remainder is False (the default), the final batch is filled with additionally sampled examples to contain batch_size examples.

    • If drop_remainder is True, the final batch is dropped if it contains fewer than batch_size examples. This may result in examples being skipped when num_epochs=1.

  • If num_steps is set and num_epochs is None, exactly this many batches are produced. drop_remainder has no effect in this case.

  • If both num_epochs and num_steps are set, iteration stops as soon as either condition is met, i.e. the smaller of the two resulting numbers of batches is produced.

If reproducible iteration order is desired, a fixed seed can be used. When seed is None, repeated iteration over the same object may produce batches in a different order.

Unlike batch() or padded_batch(), batches from shuffle_repeat_batch() always contain exactly batch_size examples. Also unlike TensorFlow, that holds even when drop_remainder=False.
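A minimal sketch of how these settings interact, assuming a 5 example dataset with a single 'x' feature:

import numpy as np

dataset = fedjax.ClientDataset({'x': np.arange(5)})
# One pass over the data: ceil(5/2) = 3 batches, the last one filled with an
# additionally sampled example so every batch has exactly 2 examples.
batches = list(dataset.shuffle_repeat_batch(batch_size=2, num_epochs=1, seed=0))
# Exactly 4 batches, regardless of how many passes over the data that takes.
batches = list(dataset.shuffle_repeat_batch(batch_size=2, num_steps=4, seed=0))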

Parameters:
  • hparams (Optional[ShuffleRepeatBatchHParams]) – Batching hyperparameters.

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Return type:

Iterable[Mapping[str, ndarray]]

Returns:

An iterable object that can be iterated over multiple times.

class fedjax.BatchPreprocessor(fns=())[source]

A chain of preprocessing functions on batched examples.

BatchPreprocessor holds a chain of preprocessing functions, and applies them in order on batched examples. Each individual preprocessing function operates over multiple examples, instead of just 1 example. For example,

preprocessor = BatchPreprocessor([
  # Flattens `pixels`.
  lambda x: {**x, 'pixels': x['pixels'].reshape([-1, 28 * 28])},
  # Introduce `binary_label`.
  lambda x: {**x, 'binary_label': x['label'] % 2},
])
fake_emnist = {
  'pixels': np.random.uniform(size=(10, 28, 28)),
  'label': np.random.randint(10, size=(10,))
}
preprocessor(fake_emnist)
# Produces a dict of [10, 28*28] "pixels", [10,] "label" and "binary_label".

Given a BatchPreprocessor, a new BatchPreprocessor can be created with an additional preprocessing function appended to the chain,

# Continuing from the previous example.
new_preprocessor = preprocessor.append(
  lambda x: {**x, 'sum_pixels': np.sum(x['pixels'], axis=1)})
new_preprocessor(fake_emnist)
# Produces a dict of [10, 28*28] "pixels", [10,] "sum_pixels", "label" and
# "binary_label".

The main difference of this preprocessor and fedjax.ClientPreprocessor is that fedjax.ClientPreprocessor also takes client_id as input. Because of the identical representation between batched examples and all examples in a client dataset, certain preprocessing can be done with either BatchPreprocessor or ClientPreprocessor.

Examples of preprocessing possible at either the client dataset level, or the batch level

Such preprocessing is deterministic, and strictly per-example.

  • Casting a feature from int8 to float32.

  • Adding a new feature derived from existing features.

  • Removing a feature (although the better place to do so is at the dataset level).

A simple rule for deciding where to carry out the preprocessing in this case is the following,

  • Does this make batching cheaper (e.g. removing features)? If so, do it at the dataset level.

  • Otherwise, do it at the batch level.

Assuming preprocessing time is linear in the number of examples, preprocessing at the batch level has the benefit of evenly distributing host compute work, which may overlap better with asynchronous JAX compute work on GPU/TPU.

Examples of preprocessing only possible at the batch level

  • Data augmentation (e.g. random cropping).

  • Padding at the batch size dimension.

Examples of preprocessing only possible at the dataset level

  • Those that require knowing the client id.

  • Capping the number of examples.

  • Altering what it means to be an example: e.g. in certain language model setups, sentences are concatenated and then split into equal sized chunks.

__init__(fns=())[source]
append(fn)[source]

Creates a new BatchPreprocessor with fn added to the end.

Return type:

BatchPreprocessor

fedjax.buffered_shuffle_batch_client_datasets(datasets, batch_size, buffer_size, rng)[source]

Shuffles and batches examples from multiple client datasets.

This just makes 1 pass over the examples. To achieve repeated iterations, create an infinite shuffled stream of datasets first (e.g. using buffered_shuffle()).

Parameters:
  • datasets (Iterable[ClientDataset]) – ClientDatasets to be batched. All ClientDatasets must have the same Preprocessor object attached.

  • batch_size (int) – Desired batch size.

  • buffer_size (int) – Number of examples to buffer during shuffling.

  • rng (RandomState) – Source of randomness.

Yields:

Batches of examples. For a finite stream of datasets, the final batch might be smaller than batch_size.

Raises:
  • ValueError – If any 2 client datasets have different Preprocessors.

  • ValueError – If any 2 client datasets have different features.

Return type:

Iterator[Mapping[str, ndarray]]

fedjax.padded_batch_client_datasets(datasets, hparams=None, **kwargs)[source]

Batches examples from multiple client datasets.

This is useful when we want to evaluate on the combined dataset consisting of multiple client datasets. Unlike batching each client dataset individually, this reduces the number of batches that are smaller than batch_size.

This function can be invoked in 2 ways:

  1. Using a hyperparams object. This is the recommended way in library code. Example:

    def a_library_function(datasets, hparams):
      for batch in padded_batch_client_datasets(datasets, hparams):
        ...
    
  2. Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example,

    padded_batch_client_datasets(datasets, batch_size=2)
    # Overrides the default num_batch_size_buckets value.
    padded_batch_client_datasets(datasets, hparams, num_batch_size_buckets=2)
    
Parameters:
  • datasets (Iterable[ClientDataset]) – ClientDatasets to be batched. All ClientDatasets must have the same Preprocessor object attached.

  • hparams (Optional[PaddedBatchHParams]) – Batching hyperparams like those in ClientDataset.padded_batch().

  • **kwargs – Keyword arguments for constructing/overriding hparams.

Yields:

Batches of examples. The final batch might be padded. All batches contain a bool feature keyed by EXAMPLE_MASK_KEY.

Raises:
  • ValueError – If any 2 client datasets have different Preprocessors.

  • ValueError – If any 2 client datasets have different features.

Return type:

Iterator[Mapping[str, ndarray]]


fedjax.for_each_client(client_init, client_step, client_final=<function <lambda>>, with_step_result=False)[source]

Creates a function which maps over clients.

For example, for_each_client could be used to define how to run client updates for each client in a federated training round. Another common use case of for_each_client is to run evaluation per client for a given set of model parameters.

The underlying backend for for_each_client is customizable. For example, if multiple devices are available (e.g. TPU), a jax.pmap() based backend can be used to parallelize across devices. It’s also possible to manually specify which backend to use (for debugging).

The expected usage of for_each_client is as follows:

# Map over clients and count how many points are greater than `limit` for
# each client. Each client also has a different `start` that is specified
# via client input.

def client_init(shared_input, client_input):
  client_step_state = {
      'limit': shared_input['limit'],
      'count': client_input['start']
  }
  return client_step_state

def client_step(client_step_state, batch):
  num = jnp.sum(batch['x'] > client_step_state['limit'])
  client_step_state = {
      'limit': client_step_state['limit'],
      'count': client_step_state['count'] + num
  }
  return client_step_state

def client_final(shared_input, client_step_state):
  del shared_input  # Unused.
  return client_step_state['count']

# Three clients with different data and starting counts.
# clients = [(client_id, client_batches, client_input)]
clients = [
    (b'cid0',
    [{'x': jnp.array([1, 2, 3, 4])}, {'x': jnp.array([1, 2, 3])}],
    {'start': jnp.array(2)}),
    (b'cid1',
    [{'x': jnp.array([1, 2])}, {'x': jnp.array([1, 2, 3, 4, 5])}],
    {'start': jnp.array(0)}),
    (b'cid2',
    [{'x': jnp.array([1])}],
    {'start': jnp.array(1)}),
]
shared_input = {'limit': jnp.array(2)}

func = fedjax.for_each_client(client_init, client_step, client_final)
print(list(func(shared_input, clients)))
# [(b'cid0', 5), (b'cid1', 3), (b'cid2', 1)]

Here’s the same example with per step results.

# We'll also keep track of the `num` per step in our step results.

def client_step_with_result(client_step_state, batch):
  num = jnp.sum(batch['x'] > client_step_state['limit'])
  client_step_state = {
      'limit': client_step_state['limit'],
      'count': client_step_state['count'] + num
  }
  client_step_result = {'num': num}
  return client_step_state, client_step_result

func = fedjax.for_each_client(
    client_init, client_step_with_result, client_final, with_step_result=True)
print(list(func(shared_input, clients)))
# [
#   (b'cid0', 5, [{'num': 2}, {'num': 1}]),
#   (b'cid1', 3, [{'num': 0}, {'num': 3}]),
#   (b'cid2', 1, [{'num': 0}]),
# ]
Parameters:
  • client_init (Callable[[Any, Any], Any]) – Function that initializes the internal intermittent client step state from the shared input and per client input. The shared input contains information like the global model parameters that are shared across all clients. The per client input is per client information. The initialized internal client step state is fed as intermittent input and output from client_step and client_final. This client step state usually contains the model parameters and optimizer state for each client that are updated at each client_step. This will be run once for each client.

  • client_step (Union[Callable[[Any, Mapping[str, Array]], Tuple[Any, Any]], Callable[[Any, Mapping[str, Array]], Any]]) – Function that takes the client step state and a batch of examples as input and outputs a (possibly updated) client step state. Optionally, per step results can also be returned as the second element if with_step_result is True. Per step results are usually diagnostics like gradient norm. This will be run for each batch for each client.

  • client_final (Callable[[Any, Any], Any]) – Function that applies the final transformation on the internal client step state to the desired final client output. More meaningful transformations can be done here, like model update clipping. Defaults to just returning the client step state. This will be run once for each client.

  • with_step_result (bool) – Indicates whether client_step returns a pair where the first element is considered the client output and the second element is the client step result.

Returns:

A for each client function that takes shared_input and the per client inputs as tuple (client_id, batched_client_data, client_rng) to map over and returns the outputs per client as specified in client_final along with optional per client per step results.

fedjax.for_each_client_backend(backend)[source]

A context manager for switching to a given ForEachClientBackend in the current thread.

Example:

with for_each_client_backend('pmap'):
  # We will be using the pmap based for_each_client backend within this block.
  pass
# We will be using the default for_each_client backend from now on.
Parameters:

backend (Union[ForEachClientBackend, str, None]) – See set_for_each_client_backend().

Yields:

Nothing.

fedjax.set_for_each_client_backend(backend)[source]

Sets the for_each_client backend for the current thread.

Parameters:

backend (Union[ForEachClientBackend, str, None]) –

One of the following,

  • None: uses the default backend for the current environment.

  • ’debug’: uses the debugging backend.

  • ’jit’: uses the JIT backend.

  • ’pmap’: uses the pmap-based backend.

  • A concrete ForEachClientBackend object.
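For example, to use the debugging backend for everything that runs on the current thread, and later restore the environment default:

fedjax.set_for_each_client_backend('debug')
# ... run and inspect per-client work ...
fedjax.set_for_each_client_backend(None)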


class fedjax.Model(init, apply_for_train, apply_for_eval, train_loss, eval_metrics)[source]

Container class for models.

Model exists to provide easy access to predefined neural network models. It is meant to contain all the information needed for standard centralized training and evaluation. Non-standard training methods can be built upon the information available in Model along with any additional information (e.g. interpolation can be implemented as a composition of two models along with an interpolation weight).

Works for Haiku and jax.example_libraries.stax.

The expected usage of Model is as follows:

# Training.
step_size = 0.1
rng = jax.random.PRNGKey(0)
params = model.init(rng)

def loss(params, batch, rng):
  preds = model.apply_for_train(params, batch, rng)
  return jnp.sum(model.train_loss(batch, preds))

grad_fn = jax.grad(loss)
for batch in batches:
  rng, use_rng = jax.random.split(rng)
  grads = grad_fn(params, batch, use_rng)
  params = jax.tree_util.tree_map(lambda a, b: a - step_size * b,
                                       params, grads)

# Evaluation.
print(fedjax.evaluate_model(model, params, batches))
# Example output:
# {'loss': 2.3, 'accuracy': 0.2}

The following is an example using Model compositionally as a building block to implement model interpolation:

def interpolate(model_1, model_2, init_weight):

  @jax.jit
  def init(rng):
    rng_1, rng_2 = jax.random.split(rng)
    params_1 = model_1.init(rng_1)
    params_2 = model_2.init(rng_2)
    return params_1, params_2, init_weight

  @jax.jit
  def apply_for_train(params, input, rng):
    rng_1, rng_2 = jax.random.split(rng)
    params_1, params_2, weight = params
    return (model_1.apply_for_train(params_1, input, rng_1) * weight +
            model_2.apply_for_train(params_2, input, rng_2) * (1 - weight))

  @jax.jit
  def apply_for_eval(params, input):
    params_1, params_2, weight = params
    return (model_1.apply_for_eval(params_1, input) * weight +
            model_2.apply_for_eval(params_2, input) * (1 - weight))

  return fedjax.Model(init,
                      apply_for_train,
                      apply_for_eval,
                      model_1.train_loss,
                      model_1.eval_metrics)

model = interpolate(model_1, model_2, init_weight=0.5)
init

Initialization function that takes a seed PRNGKey and returns a PyTree of initialized parameters (i.e. model weights). These parameters will be passed as input into apply_for_train() and apply_for_eval(). Any trainable weights for a model that are modified in the training loop should be contained inside of these parameters.

Type:

Callable[[jax.Array], Any]

apply_for_train

Function that takes the parameters PyTree, batch of examples, and PRNGKey as inputs and outputs the model predictions for training that are then passed into train_loss(). This considers strategies such as dropout.

Type:

Callable[[Any, Mapping[str, jax.Array], jax.Array], jax.Array]

apply_for_eval

Function that usually takes the parameters PyTree and batch of examples as inputs and outputs the model predictions for evaluation that are then passed to eval_metrics. This is defined separately from apply_for_train() to avoid having to specify inputs like PRNGKey that are not used in evaluation.

Type:

Callable[[Any, Mapping[str, jax.Array]], Union[jax.Array, Mapping[str, jax.Array]]]

train_loss

Loss function for training that takes a batch of examples and the model output from apply_for_train() as input and outputs per example loss. This will typically be called inside a jax.grad() wrapped function to compute gradients.

Type:

Callable[[Mapping[str, jax.Array], jax.Array], jax.Array]

eval_metrics

Ordered mapping of evaluation metric names to Metric. These Metric objects are defined for single examples and will be used in evaluate_model().

Type:

Mapping[str, fedjax.core.metrics.Metric]

fedjax.create_model_from_haiku(transformed_forward_pass, sample_batch, train_loss, eval_metrics=None, train_kwargs=None, eval_kwargs=None)[source]

Creates Model after applying defaults and haiku specific preprocessing.

Parameters:
  • transformed_forward_pass (Transformed) – Transformed forward pass from hk.transform()

  • sample_batch (Mapping[str, Array]) – Example input used to determine model parameter shapes.

  • train_loss (Callable[[Mapping[str, Array], Array], Array]) – Loss function for training that outputs per example loss.

  • eval_metrics (Optional[Mapping[str, Metric]]) – Mapping of evaluation metric names to Metric. These metrics are defined for single examples and will be consumed in evaluate_model().

  • train_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for training.

  • eval_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for evaluation.

Return type:

Model

Returns:

Model
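A minimal sketch, assuming Haiku is available and that the fedjax.metrics helpers used below (unreduced_cross_entropy_loss and Accuracy) are taken from the broader FedJAX API; the feature names 'x' and 'y' are assumptions:

import haiku as hk
import numpy as np
import fedjax

def forward_pass(batch):
  # A single linear layer over the 'x' feature.
  return hk.Linear(10)(batch['x'])

model = fedjax.create_model_from_haiku(
    transformed_forward_pass=hk.transform(forward_pass),
    sample_batch={'x': np.ones((1, 2), dtype=np.float32)},
    train_loss=lambda batch, preds: fedjax.metrics.unreduced_cross_entropy_loss(
        batch['y'], preds),
    eval_metrics={'accuracy': fedjax.metrics.Accuracy()})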

fedjax.create_model_from_stax(stax_init, stax_apply, sample_shape, train_loss, eval_metrics=None, train_kwargs=None, eval_kwargs=None, input_key='x')[source]

Creates Model after applying defaults and stax specific preprocessing.

Parameters:
  • stax_init (Callable[…, Any]) – Initialization function returned from stax.serial().

  • stax_apply (Callable[…, Array]) – Model forward pass function returned from stax.serial().

  • sample_shape (Tuple[int, …]) – The expected shape of the input to the model.

  • train_loss (Callable[[Mapping[str, Array], Array], Array]) – Loss function for training that outputs per example loss.

  • eval_metrics (Optional[Mapping[str, Metric]]) – Mapping of evaluation metric names to Metric. These metrics are defined for single examples and will be consumed in evaluate_model().

  • train_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for training.

  • eval_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for evaluation.

  • input_key (str) – Key name for the input in batch mapping.

Return type:

Model

Returns:

Model

fedjax.evaluate_model(model, params, batches)[source]

Evaluates model for multiple batches and returns final results.

This is the recommended way to compute evaluation metrics for a given model.

Parameters:
  • model (Model) – Model container.

  • params (Any) – Pytree of model parameters to be evaluated.

  • batches (Iterable[Mapping[str, Array]]) – Multiple batches to compute and aggregate evaluation metrics over. Each batch can optionally contain a feature keyed by client_datasets.EXAMPLE_MASK_KEY (see ClientDataset.padded_batch()).

Return type:

Dict[str, Array]

Returns:

A dictionary of evaluation Metric results.

fedjax.model_grad(model, regularizer=None)[source]

A standard gradient function derived from a model and an optional regularizer.

The scalar loss function being differentiated is simply:

mean(model’s per-example loss) + regularizer term

The returned gradient function supports both unpadded batches, and padded batches with the mask feature keyed by client_datasets.EXAMPLE_MASK_KEY.

Parameters:
  • model (Model) – A Model.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer.

Return type:

Callable[[Any, Mapping[str, Array], Array], Any]

Returns:

A function from (params, batch_example, rng) to gradients.
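A minimal sketch of one gradient step, assuming model, params, batch, and rng follow the fedjax.Model conventions above:

import jax

grad_fn = fedjax.model_grad(model)
grads = grad_fn(params, batch, rng)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)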

fedjax.model_per_example_loss(model)[source]

Convenience function for constructing a per-example loss function from a model.

Parameters:

model (Model) – Model.

Return type:

Callable[[Any, Mapping[str, Array], Array], Array]

Returns:

A function from (params, batch_example, rng) to a vector of loss values for each example in the batch.

fedjax.evaluate_average_loss(params, batches, rng, per_example_loss, regularizer=None)[source]

Evaluates the average per example loss over multiple batches.

Parameters:
  • params (Any) – PyTree of model parameters to be evaluated.

  • batches (Iterable[Mapping[str, Array]]) – Multiple batches to compute and aggregate evaluation metrics over. Each batch can optionally contain a feature keyed by client_datasets.EXAMPLE_MASK_KEY (see ClientDataset.padded_batch()).

  • rng (Array) – Initial PRNGKey for making per_example_loss calls.

  • per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – Per example loss function.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer function.

Return type:

Array

Returns:

The average per example loss, plus the regularizer term when specified.

class fedjax.ModelEvaluator(model)[source]

Evaluates model for each client dataset, either using global params, or per client params.

To evaluate a Model on a single dataset, use evaluate_model() instead.

__init__(model)[source]
evaluate_global_params(params, clients)[source]

Evaluates batches from each client using global params.

Parameters:
  • params (Any) – Model params to evaluate.

  • clients (Iterable[Tuple[bytes, Iterable[Mapping[str, Array]]]]) – Client batches.

Yields:

Pairs of the client id and a dictionary of evaluation Metric results for each client.

Return type:

Iterator[Tuple[bytes, Dict[str, Array]]]

evaluate_per_client_params(clients)[source]

Evaluates batches from each client using per client params.

Parameters:

clients (Iterable[Tuple[bytes, Iterable[Mapping[str, Array]], Any]]) – Client batches and the per client params.

Yields:

Pairs of the client id and a dictionary of evaluation Metric results for each client.

Return type:

Iterator[Tuple[bytes, Dict[str, Array]]]
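A minimal sketch, assuming model and params follow the fedjax.Model conventions and each client’s batches come from ClientDataset.padded_batch(); the client ids and datasets are hypothetical:

evaluator = fedjax.ModelEvaluator(model)
clients = [
    (b'cid0', client_dataset_0.padded_batch(batch_size=4)),
    (b'cid1', client_dataset_1.padded_batch(batch_size=4)),
]
for client_id, metrics in evaluator.evaluate_global_params(params, clients):
  print(client_id, metrics)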

class fedjax.AverageLossEvaluator(per_example_loss, regularizer=None)[source]

Evaluates average loss for each client dataset, either using global params, or per client params.

The average loss is defined as the average per example loss, plus the regularizer term when specified. To evaluate average loss on a single dataset, use evaluate_average_loss() instead.

__init__(per_example_loss, regularizer=None)[source]
evaluate_global_params(params, clients)[source]

Evaluates batches from each client using global params.

Parameters:
  • params (Any) – Model params to evaluate.

  • clients (Iterable[Tuple[bytes, Iterable[Mapping[str, Array]], Array]]) – Client batches.

Yields:

Pairs of the client id and the client’s average loss.

Return type:

Iterator[Tuple[bytes, Array]]

evaluate_per_client_params(clients)[source]

Evaluates batches from each client using per client params.

Parameters:

clients (Iterable[Tuple[bytes, Iterable[Mapping[str, Array]], Array, Any]]) – Client batches and the per client params.

Yields:

Pairs of the client id and the client’s average loss.

Return type:

Iterator[Tuple[bytes, Array]]

fedjax.grad(per_example_loss, regularizer=None)[source]

A standard gradient function derived from per-example loss and an optional regularizer.

The scalar loss function being differentiated is simply:

mean(per-example loss) + regularizer term

The returned gradient function supports both unpadded batches, and padded batches with the mask feature keyed by client_datasets.EXAMPLE_MASK_KEY.

Parameters:
  • per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch.

  • regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.

Return type:

Callable[[Any, Mapping[str, Array], Array], Any]

Returns:

A function from (params, batch_example, rng) to gradients.