Federated learning algorithms


This tutorial introduces algorithms for federated learning in FedJAX. By completing this tutorial, we’ll learn how to write clear and efficient algorithms that follow best practices. This tutorial assumes that we have finished the tutorials on datasets and models.

To keep the code close to pseudo-code, we avoid using JAX primitives directly when writing algorithms, with the notable exceptions of the jax.random and jax.tree_util libraries. Since the lower-level functions described in the models tutorial, such as fedjax.optimizers and fedjax.model_grad, are already JIT compiled, the algorithms are still efficient.

# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
import jax
import jax.numpy as jnp
import numpy as np

import fedjax
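# We only use TensorFlow for datasets, so we restrict it to CPU only to avoid
# issues with certain ops not being available on GPU/TPU.
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
tf.config.set_visible_devices([], 'TPU')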


A federated algorithm trains a machine learning model over decentralized data distributed over several clients. At a high level, the server first randomly initializes the model parameters and other learning components. Then at each round the following happens:

  1. Client selection: The server selects a few clients at each round, typically at random.

  2. The server transmits the model parameters and other necessary components to the selected clients.

  3. Client update: The clients update the model parameters using a subroutine, which typically involves a few epochs of SGD on their local examples.

  4. The clients transmit the updates to the server.

  5. Server aggregation: The server combines the clients’ updates to produce new model parameters.
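The five steps above can be sketched as a small NumPy simulation. This is a toy, not FedJAX code: it assumes a scalar model w, a per-client loss 0.5 * mean((w - x)**2), a few full-batch gradient steps as the client update, and an example-weighted mean of the client deltas as the server aggregation. As in FedJAX, steps 2 and 4 are implicit because nothing is actually transmitted in a simulation. The names client_update and run_rounds are hypothetical.

```python
import numpy as np


def client_update(server_w, client_x, lr=0.1, num_steps=5):
    """Step 3: a few gradient steps on the loss 0.5 * mean((w - x) ** 2)."""
    w = server_w
    for _ in range(num_steps):
        grad = w - client_x.mean()
        w = w - lr * grad
    # Return the update as (initial - final) together with its weight.
    return server_w - w, len(client_x)


def run_rounds(client_data, num_rounds=20, clients_per_round=3, seed=0):
    rng = np.random.default_rng(seed)
    w = 0.0  # the server initializes the model parameter
    for _ in range(num_rounds):
        # Step 1: client selection, uniformly at random without replacement.
        chosen = rng.choice(len(client_data), size=clients_per_round,
                            replace=False)
        # Steps 2 and 4 (transmission) are implicit in a simulation.
        deltas, weights = zip(*(client_update(w, client_data[i])
                                for i in chosen))
        # Step 5: server aggregation as an example-weighted mean of deltas.
        w = w - np.average(deltas, weights=weights)
    return float(w)


# Clients hold different amounts of data with different means.
clients = [np.full(n, m) for n, m in [(10, 1.0), (30, 2.0), (20, 3.0), (5, 4.0)]]
print(run_rounds(clients))
```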

Pseudo-code for a common federated learning algorithm can be found in Algorithm 1 in Kairouz et al. (2020). Since FedJAX focuses on federated simulation and there is no actual transmission between clients and the server, we only focus on steps 1, 3, and 5, and ignore steps 2 and 4. Before describing each of these steps, we first show how to use the algorithms that are already implemented in FedJAX.

Federated algorithm overview

We implement federated learning algorithms using the fedjax.FederatedAlgorithm interface. The fedjax.FederatedAlgorithm interface has two functions, init and apply. Broadly, our implementation has three parts:

  1. ServerState: This contains all the information available at the server at any given round. It includes model parameters and can also include other parameters that are used during optimization. At every round, a subset of ServerState is passed to the clients for federated learning. ServerState is also used in checkpointing and evaluation. Hence it is crucial that all the parameters that are modified during the course of federated learning are stored as part of the ServerState. Do not store mutable parameters as part of fedjax.FederatedAlgorithm.

  2. init: Initializes the server state.

  3. apply: Takes the ServerState and a set of client_ids, corresponding datasets, and random keys and returns a new ServerState along with any information we need from the clients in the form of client_diagnostics.
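To make this division of labor concrete, here is a hypothetical NumPy-only sketch of the same init/apply pattern. ServerState, ToyFederatedAlgorithm, and toy_federated_mean are illustrative names, not FedJAX's classes, and the model is assumed to take one gradient step per client on 0.5 * ||w - x||^2 over a random minibatch. Note that everything that changes from round to round, including the round counter, lives in ServerState rather than on the algorithm object.

```python
from typing import Callable, Dict, NamedTuple, Tuple

import numpy as np


class ServerState(NamedTuple):
    """All state that changes between rounds, so it can be checkpointed."""
    params: np.ndarray
    round_num: int


class ToyFederatedAlgorithm(NamedTuple):
    """Hypothetical container for an init/apply pair (not FedJAX's class)."""
    init: Callable[[np.ndarray], ServerState]
    apply: Callable[..., Tuple[ServerState, Dict[str, dict]]]


def toy_federated_mean(client_lr=0.5, batch_size=4):
    """Hyperparameters are fixed here; nothing mutable is stored on the object."""

    def init(params):
        return ServerState(params=np.asarray(params, dtype=float), round_num=0)

    def apply(server_state, clients):
        deltas, weights, diagnostics = [], [], {}
        for client_id, client_x, client_seed in clients:
            rng = np.random.default_rng(client_seed)  # per-client randomness
            idx = rng.choice(len(client_x), size=min(batch_size, len(client_x)),
                             replace=False)
            # One gradient step on 0.5 * ||w - x||^2 over a random minibatch.
            grad = server_state.params - client_x[idx].mean(axis=0)
            delta = client_lr * grad  # initial params minus updated params
            deltas.append(delta)
            weights.append(len(client_x))
            diagnostics[client_id] = {'delta_l2_norm': float(np.linalg.norm(delta))}
        mean_delta = np.average(np.stack(deltas), axis=0, weights=weights)
        new_state = ServerState(params=server_state.params - mean_delta,
                                round_num=server_state.round_num + 1)
        return new_state, diagnostics

    return ToyFederatedAlgorithm(init=init, apply=apply)


alg = toy_federated_mean()
state = alg.init(np.zeros(2))
clients = [('a', np.ones((6, 2)), 1), ('b', 3.0 * np.ones((2, 2)), 2)]
state, client_diagnostics = alg.apply(state, clients)
print(state.round_num, state.params, client_diagnostics)
```

Because apply returns a fresh ServerState instead of mutating anything, calling it twice on the same input state gives the same result.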

We demonstrate fedjax.FederatedAlgorithm using Federated Averaging (FedAvg) and the EMNIST dataset. We first initialize the model, the datasets, and the federated algorithm.

train, test = fedjax.datasets.emnist.load_data(only_digits=False)
model = fedjax.models.emnist.create_conv_model(only_digits=False)

rng = jax.random.PRNGKey(0)
init_params = model.init(rng)
# The federated algorithm requires a gradient function, a client optimizer,
# a server optimizer, and hyperparameters for batching at the client level.
grad_fn = fedjax.model_grad(model)
client_optimizer = fedjax.optimizers.sgd(0.1)
server_optimizer = fedjax.optimizers.sgd(1.0)
batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=10)
fed_alg = fedjax.algorithms.fed_avg.federated_averaging(grad_fn,
                                                        client_optimizer,
                                                        server_optimizer,
                                                        batch_hparams)

Note that, as in the rest of the library, we pass only the necessary functions and parameters to the federated algorithm. Hence, to initialize the federated algorithm, we passed only grad_fn rather than the entire model. With this, we now initialize the server state.

init_server_state = fed_alg.init(init_params)

To run the federated algorithm, we pass the server state and the client data to the apply function. To this end, we pass the client data as a list of (client ID, client dataset, random key) tuples. Including client IDs and random keys has several advantages. First, client IDs let us track per-client diagnostics, which is helpful for debugging. Second, passing random keys makes execution deterministic and therefore repeatable. Furthermore, as we discuss later, it helps with fast implementations. We first put the data into this format and then run one round of federated learning.
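The key-per-client idea is independent of JAX. As a NumPy analogue of the jax.random.split loop below, one root seed can be spawned into independent child seeds with numpy.random.SeedSequence.spawn; rerunning with the same root seed then reproduces every client's randomness exactly. The helper names are hypothetical, and the random mask is only a stand-in for per-client randomness such as dropout.

```python
import numpy as np


def make_client_inputs(client_ids, client_datasets, root_seed):
    """Pairs every client with its own child seed derived from one root seed."""
    children = np.random.SeedSequence(root_seed).spawn(len(client_ids))
    return list(zip(client_ids, client_datasets, children))


def toy_client_computation(client_input):
    """Uses only the client's own seed, e.g. to draw a dropout-style mask."""
    client_id, data, seed_seq = client_input
    rng = np.random.default_rng(seed_seq)
    mask = rng.random(len(data)) > 0.5
    return client_id, float(np.sum(data * mask))


def run(root_seed):
    ids = ['client_a', 'client_b', 'client_c']
    datasets = [np.arange(5.0), np.arange(8.0), np.ones(3)]
    return [toy_client_computation(c)
            for c in make_client_inputs(ids, datasets, root_seed)]


# The same root seed reproduces the same per-client results.
print(run(0) == run(0))
```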

# Select 5 client_ids and their data
client_ids = list(train.client_ids())[:5]
clients_ids_and_data = list(train.get_clients(client_ids))

client_inputs = []
for i in range(5):
  rng, use_rng = jax.random.split(rng)
  client_id, client_data = clients_ids_and_data[i]
  client_inputs.append((client_id, client_data, use_rng))

updated_server_state, client_diagnostics = fed_alg.apply(init_server_state,
                                                         client_inputs)
# Prints the l2 norm of each client's model update (delta) as part of
# client_diagnostics.
print(client_diagnostics)
{b'002d084c082b8586:f0185_23': {'delta_l2_norm': DeviceArray(1.9278834, dtype=float32)}, b'005fdad281234bc0:f0151_02': {'delta_l2_norm': DeviceArray(1.8239512, dtype=float32)}, b'014c177da5b15a39:f1565_04': {'delta_l2_norm': DeviceArray(1.6514685, dtype=float32)}, b'0156df0c34a25944:f3772_10': {'delta_l2_norm': DeviceArray(1.5863262, dtype=float32)}, b'01725f8a648ceeb6:f3408_47': {'delta_l2_norm': DeviceArray(1.613201, dtype=float32)}}

As we see above, the client diagnostics report delta_l2_norm, the l2 norm of each client's model update (the difference between the initial and the locally updated parameters), which can be useful for debugging.
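For intuition, delta_l2_norm is a single global l2 norm taken over every parameter array of the update at once. Below is a small NumPy sketch, assuming parameters are nested dicts of arrays and that the delta is defined as initial minus final parameters; tree_sub and tree_l2_norm are hypothetical helpers that loosely mirror what fedjax.tree_util.tree_l2_norm computes.

```python
import numpy as np


def tree_sub(a, b):
    """Leaf-wise a - b for nested dicts of arrays with identical structure."""
    if isinstance(a, dict):
        return {k: tree_sub(a[k], b[k]) for k in a}
    return np.asarray(a, dtype=float) - np.asarray(b, dtype=float)


def tree_l2_norm(tree):
    """Square root of the sum of squares of every element of every leaf."""
    if isinstance(tree, dict):
        return float(np.sqrt(sum(tree_l2_norm(v) ** 2 for v in tree.values())))
    return float(np.linalg.norm(np.ravel(tree)))


init_params = {'linear': {'w': np.array([3.0, 0.0]), 'b': np.array([4.0])}}
final_params = {'linear': {'w': np.zeros(2), 'b': np.zeros(1)}}
delta = tree_sub(init_params, final_params)
print(tree_l2_norm(delta))  # 5.0, since sqrt(3**2 + 4**2) = 5
```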

Writing federated algorithms

With this background on how to use existing implementations, we are now going to describe how to write your own federated algorithms in FedJAX. As discussed above, this involves three steps:

  1. Client selection

  2. Client update

  3. Server aggregation

Client selection

At each round of federated learning, typically clients are sampled uniformly at random. This can be done using numpy as follows.

all_client_ids = list(train.client_ids())
print("Total number of client ids: ", len(all_client_ids))

sampled_client_ids = np.random.choice(all_client_ids, size=2, replace=False)
print("Sampled client ids: ", sampled_client_ids)
Total number of client ids:  3400
Sampled client ids:  [b'3abf53413107a36f:f3901_45' b'7b15a2a43e2c5097:f0452_37']

However, the above code is not ideal for the following reasons:

  1. For reproducibility, it is desirable to have a fixed seed just for sampling clients.

  2. Across rounds, different clients need to be sampled.

  3. For I/O efficiency reasons, it might be better to do an approximately uniform sampling, where clients whose data is stored together are sampled together.

  4. Federated algorithms typically require additional randomness, for batching or dropout, that needs to be sent to clients.

To incorporate these features, FedJAX provides a few client samplers.

  1. fedjax.client_samplers.UniformShuffledClientSampler

  2. fedjax.client_samplers.UniformGetClientSampler

fedjax.client_samplers.UniformShuffledClientSampler is preferred for efficiency reasons, but if we need to sample clients truly uniformly at random, fedjax.client_samplers.UniformGetClientSampler can be used. Both samplers have a sample function that returns a list of (client_id, client_data, client_rng) tuples.
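Requirements 1, 2, and 4 from the list above need only a little bookkeeping. The hypothetical ToyUniformSampler below (not a FedJAX class) is a sketch under simple assumptions: all client IDs fit in memory, and each round's randomness is derived from the pair (seed, round_num), so any round can be replayed. It returns (client_id, client_seed) pairs; a real sampler would also return each client's dataset, and the FedJAX samplers above return client_rng rather than an integer seed.

```python
import numpy as np


class ToyUniformSampler:
    """Hypothetical sampler: fixed seed, fresh clients each round, per-client seeds."""

    def __init__(self, client_ids, num_clients, seed, start_round_num=0):
        self.client_ids = list(client_ids)
        self.num_clients = num_clients
        self.seed = seed
        self.round_num = start_round_num

    def sample(self):
        # Randomness depends only on (seed, round_num): rounds are replayable,
        # and each round draws an independent sample of clients.
        rng = np.random.default_rng([self.seed, self.round_num])
        idx = rng.choice(len(self.client_ids), size=self.num_clients,
                         replace=False)
        client_seeds = rng.integers(0, 2**31 - 1, size=self.num_clients)
        self.round_num += 1
        return [(self.client_ids[i], int(s)) for i, s in zip(idx, client_seeds)]


ids = [f'client_{i}' for i in range(10)]
sampler = ToyUniformSampler(ids, num_clients=2, seed=1)
for round_num in range(3):
    print(round_num, sampler.sample())
```

Keying the generator on the round number, rather than reusing one long-lived generator, is what makes it possible to resume from a checkpoint at round k and draw exactly the clients the original run would have drawn.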

efficient_sampler = fedjax.client_samplers.UniformShuffledClientSampler(
    train.shuffled_clients(buffer_size=100), num_clients=2)
print("Sampling from the efficient sampler.")
for round in range(3):
  sampled_clients_with_data = efficient_sampler.sample()
  for client_id, client_data, client_rng in sampled_clients_with_data:
    print(round, client_id)

perfect_uniform_sampler = fedjax.client_samplers.UniformGetClientSampler(
    train, num_clients=2, seed=1)
print("Sampling from the perfect uniform sampler.")
for round in range(3):
  sampled_clients_with_data = perfect_uniform_sampler.sample()
  for client_id, client_data, client_rng in sampled_clients_with_data:
    print(round, client_id)
Sampling from the efficient sampler.
0 b'1ec29e39b10521aa:f3848_28'
0 b'0156df0c34a25944:f3772_10'
1 b'0e9e55d97351b8a6:f1424_45'
1 b'1cb4c2b15c501e91:f0936_34'
2 b'1ef43f1933de69ae:f2086_25'
2 b'0a544c86e9731fc1:f0629_39'
Sampling from the perfect uniform sampler.
0 b'0eb0b8eb6ab0fcd6:f0174_06'
0 b'2b766db84ee57ba6:f2380_76'
1 b'b268b64fea71f3a7:f0766_01'
1 b'786d98ea95e5a59b:f1729_47'
2 b'35b6c7951d56d353:f3574_30'
2 b'98355b0cd97cdbb8:f3990_02'

Client update

After selecting the clients, the next step is to run a model update on each client. Typically this is done by running a few epochs of SGD. We pass only the parts of the algorithm that are necessary for the client update.

The client update typically requires a set of parameters from the server (init_params in this example), the client dataset, and a source of randomness (rng). The randomness can be used for dropout or other model update steps. Finally, instead of passing the entire model to the client update, since our code only depends on the gradient function, we pass grad_fn to client_update.

def client_update(init_params, client_dataset, client_rng, grad_fn):
  opt_state = client_optimizer.init(init_params)
  params = init_params
  for batch in client_dataset.shuffle_repeat_batch(batch_size=10):
    client_rng, use_rng = jax.random.split(client_rng)
    grads = grad_fn(params, batch, use_rng)
    opt_state, params = client_optimizer.apply(grads, opt_state, params)
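  # delta = init_params - params, computed leaf by leaf. (tree_multimap has
  # been removed from recent JAX versions; tree_map accepts several trees.)
  delta_params = jax.tree_util.tree_map(lambda a, b: a - b,
                                        init_params, params)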
  return delta_params, len(client_dataset)

client_sampler = fedjax.client_samplers.UniformGetClientSampler(
    train, num_clients=2, seed=1)
sampled_clients_with_data = client_sampler.sample()
for client_id, client_data, client_rng in sampled_clients_with_data:
  delta_params, num_samples = client_update(init_params, client_data,
                                            client_rng, grad_fn)
  print(client_id, num_samples, delta_params.keys())
b'0eb0b8eb6ab0fcd6:f0174_06' 348 KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])
b'2b766db84ee57ba6:f2380_76' 152 KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])

Server aggregation

The outputs of the clients are typically aggregated by computing the weighted mean of the updates, where the weight is the number of client examples. This can be easily done by using the fedjax.tree_util.tree_mean function.
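Conceptually, the weighted mean takes (update, weight) pairs and returns sum(weight * update) / sum(weight), separately for each parameter array. A minimal NumPy sketch for flat dicts of arrays follows; weighted_tree_mean is a hypothetical helper that illustrates what fedjax.tree_util.tree_mean computes in this example, not how it is implemented.

```python
import numpy as np


def weighted_tree_mean(updates_and_weights):
    """sum_i w_i * u_i / sum_i w_i, computed separately for every dict key."""
    total_weight = 0.0
    acc = None
    for update, weight in updates_and_weights:
        scaled = {k: weight * np.asarray(v, dtype=float) for k, v in update.items()}
        acc = scaled if acc is None else {k: acc[k] + scaled[k] for k in acc}
        total_weight += weight
    if not total_weight:
        raise ValueError('total weight must be positive')
    return {k: v / total_weight for k, v in acc.items()}


client_updates = [({'w': np.array([1.0, 1.0]), 'b': np.array([0.0])}, 3),
                  ({'w': np.array([5.0, 5.0]), 'b': np.array([4.0])}, 1)]
print(weighted_tree_mean(client_updates))  # w -> [2., 2.], b -> [1.]
```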

sampled_clients_with_data = client_sampler.sample()
client_updates = []
for client_id, client_data, client_rng in sampled_clients_with_data:
  delta_params, num_samples = client_update(init_params, client_data,
                                            client_rng, grad_fn)
  client_updates.append((delta_params, num_samples))
updated_output = fedjax.tree_util.tree_mean(client_updates)
print(updated_output.keys())
KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])

Combining the above steps gives the FedAvg algorithm, which can be found in the example FedJAX implementation of FedAvg.
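One detail ties these pieces together: the server optimizer treats the averaged delta as if it were a gradient. With plain SGD and a server learning rate of 1.0, as in the server_optimizer created earlier, params - 1.0 * mean(delta) is exactly the example-weighted average of the clients' final parameters. The short NumPy check below illustrates that identity with made-up numbers; it is not FedJAX code.

```python
import numpy as np

server_params = np.array([1.0, -2.0])
client_params = [np.array([2.0, 0.0]), np.array([0.0, 4.0])]  # after local SGD
num_examples = [3, 1]

# Client deltas follow the convention used above: initial minus final.
deltas = [server_params - p for p in client_params]
mean_delta = np.average(np.stack(deltas), axis=0, weights=num_examples)

# Server step: plain SGD on the "pseudo-gradient" mean_delta.
server_lr = 1.0
new_params = server_params - server_lr * mean_delta
print(new_params)  # [1.5 1. ]
```

Other server learning rates or optimizers (for example, momentum) move the server parameters differently along the same averaged direction.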

Efficient implementation

The above implementation is efficient enough for running on a single machine. However, JAX provides primitives such as jax.pmap and jax.vmap for efficient parallelization across multiple accelerators. FedJAX supports them in federated learning by distributing client computation across several accelerators.

To take advantage of the faster implementation, we need to write client_update in a specific format, which consists of three functions:

  1. client_init

  2. client_step

  3. client_final


client_init takes the inputs from the server and outputs a client_step_state, which is passed between client steps. It is desirable for client_step_state to be a dictionary. In this example, it just copies the parameters, the optimizer state, and the current state of the client randomness.

We can think of the inputs from the server as “shared inputs” that are shared across all clients and the client_step_state as client-specific inputs that are separate per client.

def client_init(server_params, client_rng):
  opt_state = client_optimizer.init(server_params)
  client_step_state = {
      'params': server_params,
      'opt_state': opt_state,
      'rng': client_rng,
  }
  return client_step_state


client_step takes the current client_step_state and a batch of examples and updates the client_step_state. In this example, we run one step of SGD using the batch of examples and update client_step_state to reflect the new parameters, optimization state, and randomness.

def client_step(client_step_state, batch):
  rng, use_rng = jax.random.split(client_step_state['rng'])
  grads = grad_fn(client_step_state['params'], batch, use_rng)
  opt_state, params = client_optimizer.apply(grads,
                                             client_step_state['opt_state'],
                                             client_step_state['params'])
  next_client_step_state = {
      'params': params,
      'opt_state': opt_state,
      'rng': rng,
  }
  return next_client_step_state


client_final takes the final client_step_state and returns the desired output. In this example, client_final computes the difference between the initial parameters and the final updated parameters.

def client_final(server_params, client_step_state):
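  # delta = server_params - final client params, computed leaf by leaf.
  delta_params = jax.tree_util.tree_map(lambda a, b: a - b,
                                        server_params,
                                        client_step_state['params'])
  return delta_params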


Once we have these three functions, we can combine them into a client_update function using fedjax.for_each_client, which returns a function that runs the client updates. Sample usage:

for_each_client_update = fedjax.for_each_client(client_init,
                                                client_step,
                                                client_final)

client_sampler = fedjax.client_samplers.UniformGetClientSampler(
    train, num_clients=2, seed=1)
sampled_clients_with_data = client_sampler.sample()
batched_clients_data = [
    (cid, cds.shuffle_repeat_batch(batch_size=10), crng)
    for cid, cds, crng in sampled_clients_with_data
]
for client_id, delta_params in for_each_client_update(init_params,
                                                      batched_clients_data):
  print(client_id, delta_params.keys())
b'0eb0b8eb6ab0fcd6:f0174_06' KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])
b'2b766db84ee57ba6:f2380_76' KeysOnlyKeysView(['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1'])

Note that for_each_client_update requires the client data to be already batched. This is necessary for performance gains while using multiple accelerators. Furthermore, the batch size needs to be the same across all clients.
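One way to see why equal batch sizes matter: vectorized or parallel execution (as with jax.vmap or jax.pmap) processes one batch per client in lockstep, which requires stacking those batches into a single array of shape (num_clients, batch_size, ...). The NumPy sketch below illustrates the idea for an assumed toy model with loss 0.5 * ||w - x||^2; stack_client_batches and lockstep_sgd_step are hypothetical names, and the sketch does not describe fedjax.for_each_client's internals.

```python
import numpy as np


def stack_client_batches(per_client_batches):
    """Stacks one batch per client into shape (num_clients, batch_size, dim)."""
    shapes = {b.shape for b in per_client_batches}
    if len(shapes) != 1:
        raise ValueError(f'batch shapes differ across clients: {sorted(shapes)}')
    return np.stack(per_client_batches)


def lockstep_sgd_step(params, stacked_batches, lr=0.1):
    """One SGD step for every client at once on the loss 0.5 * ||w - x||^2.

    params: (num_clients, dim); stacked_batches: (num_clients, batch_size, dim).
    """
    grads = params - stacked_batches.mean(axis=1)
    return params - lr * grads


batches = [np.ones((10, 3)), 2.0 * np.ones((10, 3))]
params = np.zeros((2, 3))
params = lockstep_sgd_step(params, stack_client_batches(batches))
print(params)  # client 0 -> 0.1, client 1 -> 0.2 in every coordinate
```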

By default, fedjax.for_each_client uses the standard JIT backend. To enable parallelism with TPUs, or for debugging, we can change the backend with fedjax.set_for_each_client_backend(backend), where backend is 'pmap' or 'debug', respectively.

fedjax.for_each_client can also return additional per-step results, which can be used for debugging. This requires changing client_step so that it returns the next client_step_state together with a step result, and passing with_step_result=True.

def client_step_with_log(client_step_state, batch):
  rng, use_rng = jax.random.split(client_step_state['rng'])
  grads = grad_fn(client_step_state['params'], batch, use_rng)
  opt_state, params = client_optimizer.apply(grads,
                                             client_step_state['opt_state'],
                                             client_step_state['params'])
  next_client_step_state = {
      'params': params,
      'opt_state': opt_state,
      'rng': rng,
  }
  grad_norm = fedjax.tree_util.tree_l2_norm(grads)
  return next_client_step_state, grad_norm

for_each_client_update = fedjax.for_each_client(
    client_init, client_step_with_log, client_final, with_step_result=True)

for client_id, delta_params, grad_norms in for_each_client_update(
    init_params, batched_clients_data):

  print(client_id, list(delta_params.keys()))
  print(client_id, np.array(grad_norms))
b'0eb0b8eb6ab0fcd6:f0174_06' ['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1']
b'0eb0b8eb6ab0fcd6:f0174_06' [4.599997  3.9414525 4.8312078 4.450683  5.971922  2.4694602 2.6183944
 2.1996734 2.7145145 2.9750984 3.0633514 2.4050198 2.612233  2.672571
 3.1303792 3.236007  3.3968801 2.986587  2.5775976 2.8625555 3.1062818
 4.4250994 2.7431202 3.2192783 2.7670481 3.6075711 3.7296255 5.190155
 3.4366677 4.5394745 3.2277424 3.1362765 2.8626535 3.7905648 3.5686817]
b'2b766db84ee57ba6:f2380_76' ['conv_dropout_module/conv2_d', 'conv_dropout_module/conv2_d_1', 'conv_dropout_module/linear', 'conv_dropout_module/linear_1']
b'2b766db84ee57ba6:f2380_76' [4.3958793 4.4705005 4.84529   4.4623365 5.3809257 5.4818144 3.0325508
 4.1576953 8.666909  2.067883  2.1935403 2.5095372 2.2202325 2.8493588
 3.2463503 3.6446893]
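Conceptually, both calls above run client_init once per client, fold client_step over that client's batches, and finish with client_final. The sequential pure-Python sketch below spells out that data flow, including the with_step_result=True variant. toy_for_each_client is a hypothetical name, the scalar model with loss 0.5 * (w - x)^2 and its learning rate are assumptions, step results are collected in a plain list, and the sketch makes no attempt to run clients in parallel on accelerators.

```python
import numpy as np


def toy_for_each_client(client_init, client_step, client_final,
                        with_step_result=False):
    """Hypothetical sequential stand-in for the for_each_client data flow."""

    def run(shared_input, clients):
        for client_id, client_batches, client_input in clients:
            state = client_init(shared_input, client_input)
            step_results = []
            for batch in client_batches:
                if with_step_result:
                    state, step_result = client_step(state, batch)
                    step_results.append(step_result)
                else:
                    state = client_step(state, batch)
            output = client_final(shared_input, state)
            if with_step_result:
                yield client_id, output, step_results
            else:
                yield client_id, output

    return run


LR = 0.5  # assumed client learning rate for the toy scalar model


def toy_init(server_w, client_seed):
    # The seed is carried along like 'rng' in client_step_state above.
    return {'w': server_w, 'seed': client_seed}


def toy_step_with_log(state, batch):
    grad = state['w'] - np.mean(batch)  # gradient of 0.5 * mean((w - x)**2)
    next_state = {'w': state['w'] - LR * grad, 'seed': state['seed']}
    return next_state, float(abs(grad))


def toy_step(state, batch):
    return toy_step_with_log(state, batch)[0]


def toy_final(server_w, state):
    return float(server_w - state['w'])  # delta = initial - final


clients = [('a', [np.array([2.0, 2.0]), np.array([2.0])], 0)]
for_each = toy_for_each_client(toy_init, toy_step, toy_final)
print(list(for_each(0.0, clients)))  # [('a', -1.5)]
```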


In this tutorial, we have covered the following:

  • Using existing algorithms in fedjax.algorithms.

  • Writing new algorithms using fedjax.FederatedAlgorithm.

  • Efficient implementation using fedjax.for_each_client in the presence of accelerators.