Working with models in FedJAX


In this chapter, we will learn about fedjax.Model. This notebook assumes you have already finished the “Datasets” chapter. We first give an overview of centralized training and evaluation with fedjax.Model, and then describe how to add new neural architectures and specify additional evaluation metrics.

# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+
import itertools

import jax
import jax.numpy as jnp
from jax.example_libraries import stax

import fedjax

Centralized training & evaluation with fedjax.Model

Most federated learning algorithms are built upon common components from standard centralized learning. fedjax.Model holds these common components. In centralized learning, we are mostly concerned with two tasks:

  • Training: We want to optimize our model parameters on the training dataset.

  • Evaluation: We want to know the values of evaluation metrics (e.g. accuracy) of the current model parameters on a test dataset.

Let’s first see how we can carry out these two tasks on the EMNIST dataset with fedjax.Model.

# Load train/test splits of the EMNIST dataset.
train, test = fedjax.datasets.emnist.load_data()

# As a start, let's simply use a logistic regression model.
model = fedjax.models.emnist.create_logistic_model()

Random initialization, the JAX way

To start training, we need some randomly initialized parameters. In JAX, pseudo random number generation works differently from libraries such as NumPy: instead of relying on a global random state, random functions take an explicit key. For now, it is sufficient to know that we call jax.random.PRNGKey() to create such a key from a seed. JAX has a detailed introduction on this topic, if you are interested.
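To make the role of the key concrete, here is a minimal sketch using a hypothetical `init_logistic` function (an illustration, not FedJAX's actual initializer for the logistic model). Because the key fully determines the random draw, calling the same initializer with the same key always produces identical parameters, while a different key produces different ones.

```python
import jax
import jax.numpy as jnp


def init_logistic(rng, num_inputs=784, num_classes=62):
  """Hypothetical initializer for a logistic regression model."""
  # All randomness comes from the explicit key; there is no hidden global state.
  w = 0.05 * jax.random.normal(rng, (num_inputs, num_classes))
  b = jnp.zeros((num_classes,))
  return {'linear': {'w': w, 'b': b}}


params_a = init_logistic(jax.random.PRNGKey(0))
params_b = init_logistic(jax.random.PRNGKey(0))  # Same seed as params_a.
params_c = init_logistic(jax.random.PRNGKey(1))  # Different seed.

# Same key gives identical weights; a different key gives different weights.
print(bool(jnp.array_equal(params_a['linear']['w'], params_b['linear']['w'])))
print(bool(jnp.array_equal(params_a['linear']['w'], params_c['linear']['w'])))
```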

To create the initial model parameters, we simply call fedjax.Model.init() with a PRNGKey.

params_rng = jax.random.PRNGKey(0)
params = model.init(params_rng)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Here are our initial model parameters. With the same PRNGKey, we will always get the same random initialization. There are two parameters in our model, the weights w and the bias b. They are organized into a FlatMapping, but in general any PyTree can be used to store model parameters.

FlatMapping({
  'linear': FlatMapping({
              'b': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                0., 0.], dtype=float32),
              'w': DeviceArray([[-0.04067196,  0.02348138, -0.0214883 , ...,  0.01055492,
                                 -0.06988288, -0.02952586],
                                [-0.03985253, -0.03804361,  0.01401524, ...,  0.02281437,
                                 -0.01771905,  0.06676884],
                                [ 0.00098182, -0.00844628,  0.01303554, ..., -0.05299249,
                                  0.01777634, -0.0006488 ],
                                ...,
                                [-0.05691862,  0.05192501,  0.01588603, ...,  0.0157204 ,
                                 -0.01854135,  0.00297953],
                                [ 0.01680706,  0.05579231,  0.0459589 , ...,  0.01990358,
                                 -0.01944044, -0.01710149],
                                [-0.00880739,  0.04229043,  0.00998938, ..., -0.00633441,
                                 -0.04824542,  0.01395545]], dtype=float32),
            }),
})
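Since model parameters are just a PyTree, generic tree utilities work on them regardless of how the model is built. The sketch below uses a hypothetical plain-dict PyTree with the same nesting and shapes as the logistic model's parameters (a 784×62 weight matrix and a 62-element bias) and inspects it with jax.tree_util.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters mirroring the structure printed above.
params = {
    'linear': {
        'b': jnp.zeros((62,)),
        'w': jnp.zeros((784, 62)),
    }
}

# tree_map applies a function to every leaf while preserving the nesting.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
print(shapes)  # {'linear': {'b': (62,), 'w': (784, 62)}}

# tree_leaves flattens the PyTree into a list of arrays.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(num_params)  # 48670, i.e. 784 * 62 weights + 62 biases
```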

Evaluating model parameters

Before we start training, let’s first see how our initial parameters fare on the train and test sets. Unsurprisingly, they do not do very well. We evaluate using fedjax.evaluate_model(), which takes a model, its parameters, and a stream of batched examples. As noted in the “Datasets” chapter, we batch using fedjax.padded_batch_federated_data() for efficiency. It is very similar to fedjax.ClientDataset.padded_batch() but operates over the entire federated dataset.

# We select the first 16 batches using itertools.islice.
batched_test_data = list(itertools.islice(
    fedjax.padded_batch_federated_data(test, batch_size=128), 16))
batched_train_data = list(itertools.islice(
    fedjax.padded_batch_federated_data(train, batch_size=128), 16))

print('eval_test', fedjax.evaluate_model(model, params, batched_test_data))
print('eval_train', fedjax.evaluate_model(model, params, batched_train_data))
eval_test {'accuracy': DeviceArray(0.01757812, dtype=float32), 'loss': DeviceArray(4.1253214, dtype=float32)}
eval_train {'accuracy': DeviceArray(0.02490234, dtype=float32), 'loss': DeviceArray(4.116228, dtype=float32)}
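Why pad at all? Jitted JAX functions are compiled per input shape, so a short final batch would otherwise trigger a separate compilation. A minimal sketch of the idea follows; `pad_and_batch` and its `'mask'` key are hypothetical illustrations, not FedJAX's implementation or its actual mask field name.

```python
import numpy as np


def pad_and_batch(x, batch_size):
  """Hypothetical helper: yields fixed-size batches plus a validity mask."""
  for start in range(0, len(x), batch_size):
    chunk = x[start:start + batch_size]
    num_real = len(chunk)
    # Pad the last (short) batch with zeros so every batch has the same shape.
    pad_width = [(0, batch_size - num_real)] + [(0, 0)] * (chunk.ndim - 1)
    padded = np.pad(chunk, pad_width)
    # True marks a real example, False marks padding to ignore in metrics.
    mask = np.arange(batch_size) < num_real
    yield {'x': padded, 'mask': mask}


examples = np.arange(10, dtype=np.float32).reshape(10, 1)
batches = list(pad_and_batch(examples, batch_size=4))
print([b['x'].shape for b in batches])          # [(4, 1), (4, 1), (4, 1)]
print([int(b['mask'].sum()) for b in batches])  # [4, 4, 2]
```

A metric then counts only the rows where the mask is True, so padding never affects the reported values.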

How does our model know what evaluation metrics to report? It is simply specified in the eval_metrics field. We will discuss evaluation metrics in more detail later.

{'accuracy': Accuracy(target_key='y', pred_key=None),
 'loss': CrossEntropyLoss(target_key='y', pred_key=None)}

Since fedjax.evaluate_model() simply takes a stream of batches, we can also use it to evaluate multiple clients.

for client_id, dataset in itertools.islice(test.clients(), 4):
  print(
      client_id,
      fedjax.evaluate_model(model, params,
                            dataset.padded_batch(batch_size=128)))
b'002d084c082b8586:f0185_23' {'accuracy': DeviceArray(0.05, dtype=float32), 'loss': DeviceArray(4.1247168, dtype=float32)}
b'005fdad281234bc0:f0151_02' {'accuracy': DeviceArray(0.09375, dtype=float32), 'loss': DeviceArray(4.093891, dtype=float32)}
b'014c177da5b15a39:f1565_04' {'accuracy': DeviceArray(0., dtype=float32), 'loss': DeviceArray(4.127692, dtype=float32)}
b'0156df0c34a25944:f3772_10' {'accuracy': DeviceArray(0.05263158, dtype=float32), 'loss': DeviceArray(4.1521378, dtype=float32)}

The training objective

To train our model, we need two things: the objective function to minimize and an optimizer.

fedjax.Model contains two functions that can be used to arrive at the training objective:

  • apply_for_train(params, batch_example, rng) takes the current model parameters, a batch of examples, and a PRNGKey, and returns some output.

  • train_loss(batch_example, train_output) translates the output of apply_for_train() into a vector of per-example loss values.

In our example model, apply_for_train() produces a score for each class and train_loss() is simply the cross entropy loss. apply_for_train() in this case does not make use of a PRNGKey, so we can pass None instead for convenience. A different apply_for_train() might actually make use of the PRNGKey, for tasks such as dropout.
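As an illustration of how an apply_for_train() might use its PRNGKey, here is a hypothetical `apply_for_train_with_dropout` for a linear model. It applies inverted dropout to the inputs when a key is given and skips dropout when the key is None. This is a sketch under those assumptions, not a FedJAX model.

```python
import jax
import jax.numpy as jnp


def apply_for_train_with_dropout(params, batch, rng, rate=0.5):
  """Hypothetical apply_for_train that uses its PRNGKey for input dropout."""
  x = batch['x']
  if rng is not None:
    # Keep each input with probability 1 - rate, rescaling survivors so the
    # expected value of each input is unchanged (inverted dropout).
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    x = jnp.where(keep, x / (1.0 - rate), 0.0)
  return x @ params['w'] + params['b']


params = {'w': jnp.ones((4, 3)), 'b': jnp.zeros((3,))}
batch = {'x': jnp.ones((2, 4))}

out_none = apply_for_train_with_dropout(params, batch, None)  # No dropout.
out_a = apply_for_train_with_dropout(params, batch, jax.random.PRNGKey(0))
out_b = apply_for_train_with_dropout(params, batch, jax.random.PRNGKey(0))
print(out_none)  # Every entry is 4.0: a sum of four ones.
```

Passing the same key twice reproduces the same dropout mask, which is why out_a and out_b are identical.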

# train_batches is an infinite stream of shuffled batches of examples.
def train_batches():
  return fedjax.shuffle_repeat_batch_federated_data(
      train,
      batch_size=8,
      client_buffer_size=16,
      example_buffer_size=1024,
      seed=0)

# We obtain the first batch by using the `next` function.
example = next(train_batches())
output = model.apply_for_train(params, example, None)
per_example_loss = model.train_loss(example, output)

output.shape, per_example_loss
((8, 62), DeviceArray([4.0337796, 4.046219 , 3.9447758, 3.933005 , 4.116893 ,
              4.209843 , 4.060939 , 4.19899  ], dtype=float32))

Note that the output contains per-example predictions and has shape (8, 62), where 8 is the batch size and 62 is the number of classes. Alternatively, we can use fedjax.model_per_example_loss() to get a function that gives us the same result; it is a convenience function that does exactly what we just did.

per_example_loss_fn = fedjax.model_per_example_loss(model)
per_example_loss_fn(params, example, None)
DeviceArray([4.0337796, 4.046219 , 3.9447758, 3.933005 , 4.116893 ,
             4.209843 , 4.060939 , 4.19899  ], dtype=float32)

The training objective is a scalar, so why does train_loss() return a vector of per-example loss values? First of all, the training objective in most cases is just the average of the per-example loss values, so arriving at the final training objective isn’t hard. Moreover, in certain algorithms, we not only use the train loss over a single batch of examples for a stochastic training step, but also need to estimate the average train loss over an entire (client) dataset. Having the per-example loss values there is instrumental in obtaining the correct estimate when the batch sizes may vary.

def train_objective(params, example):
  return jnp.mean(per_example_loss_fn(params, example, None))

train_objective(params, example)
DeviceArray(4.0680556, dtype=float32)
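A small worked example shows why per-example losses matter when batch sizes vary. The two hypothetical batches below hold 4 examples and 1 example. Averaging the two batch means gives the single example in the short batch as much weight as the four others combined, while summing per-example losses and dividing by the total example count recovers the true average.

```python
import numpy as np

# Hypothetical per-example losses from two batches of different sizes.
batch_a = np.array([1.0, 1.0, 1.0, 1.0])  # 4 examples
batch_b = np.array([3.0])                 # 1 example (e.g. a short last batch)

# Mean of batch means: each batch counts equally, regardless of its size.
mean_of_means = np.mean([batch_a.mean(), batch_b.mean()])

# Sum of per-example losses over the example count: each example counts equally.
true_mean = (batch_a.sum() + batch_b.sum()) / (batch_a.size + batch_b.size)

print(mean_of_means)  # 2.0
print(true_mean)      # 1.4, i.e. (4 * 1.0 + 3.0) / 5
```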

Optimizers

With the training objective at hand, we just need an optimizer to find some good model parameters that minimize it.

There are many optimizer implementations in JAX out there, but FedJAX doesn’t force one choice over any other. Instead, FedJAX provides a simple fedjax.optimizers.Optimizer interface so a new optimizer implementation can be wrapped. For convenience, FedJAX provides some common optimizers wrapped from optax.

optimizer = fedjax.optimizers.adam(1e-3)

An optimizer is simply a pair of functions:

  • init(params) returns the initial optimizer state, such as initial values for accumulators of gradients.

  • apply(grads, opt_state, params) applies the gradients to update the current optimizer state and model parameters.

Instead of modifying opt_state or params, apply() returns a new pair of optimizer state and model parameters. In JAX, it is common to express computations in this stateless, mutation-free style, often referred to as functional programming, or pure functions. The pureness of functions is crucial to many features in JAX, so it is always good practice to write functions that do not modify their inputs. You have probably also noticed that none of the functions of fedjax.Model we have seen so far modify the model object itself (for example, init() returns model parameters instead of setting some attribute of model; apply_for_train() takes model parameters as an input argument, instead of getting them from model). FedJAX does this to keep all functions pure.
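To illustrate the init/apply pair and the pure, non-mutating style, here is a sketch of SGD with momentum written from scratch. `momentum_sgd` and `SketchOptimizer` are hypothetical names chosen for this example; they are not the fedjax.optimizers API or its optax wrappers, though they follow the same init(params) and apply(grads, opt_state, params) shape described above.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SketchOptimizer(NamedTuple):
  """Hypothetical container mirroring the init/apply pair described above."""
  init: Callable
  apply: Callable


def momentum_sgd(learning_rate, momentum=0.9):
  def init(params):
    # One zero-initialized velocity buffer per parameter leaf.
    return jax.tree_util.tree_map(jnp.zeros_like, params)

  def apply(grads, opt_state, params):
    # Build new velocity and parameter PyTrees; the inputs are not modified.
    new_velocity = jax.tree_util.tree_map(
        lambda v, g: momentum * v + g, opt_state, grads)
    new_params = jax.tree_util.tree_map(
        lambda p, v: p - learning_rate * v, params, new_velocity)
    return new_velocity, new_params

  return SketchOptimizer(init, apply)


opt = momentum_sgd(learning_rate=0.1)
params = {'w': jnp.array([1.0, 2.0])}
grads = {'w': jnp.array([1.0, 1.0])}
opt_state = opt.init(params)
new_state, new_params = opt.apply(grads, opt_state, params)
print(new_params['w'])  # [0.9 1.9]
print(params['w'])      # [1. 2.], the original params are unchanged
```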

However, in the top level training loop, it is fine to mutate states since we are not in a function that may be transformed by JAX. Let’s run our first training step, which results in a slight decrease in the objective on the same batch of examples.

To obtain the gradients, we use jax.grad(), which returns the gradient function. More details about jax.grad() can be found in the JAX documentation.

opt_state = optimizer.init(params)
grads = jax.grad(train_objective)(params, example)
opt_state, params = optimizer.apply(grads, opt_state, params)
train_objective(params, example)
DeviceArray(4.0080366, dtype=float32)
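One property that lets optimizer.apply() consume the result directly: jax.grad() returns gradients with exactly the same PyTree structure as the parameters. The toy mean-squared-error objective below (an illustration, not part of FedJAX) uses dict parameters, and the gradients can be checked by hand. With w = 1 and b = 1 the predictions are [2, 4], the residuals are [2, 4], so dL/dw = mean(2 · r · x) = 14 and dL/db = mean(2 · r) = 6.

```python
import jax
import jax.numpy as jnp


def objective(params, x, y):
  """Mean squared error of a 1-D linear model; params is a dict PyTree."""
  pred = x * params['w'] + params['b']
  return jnp.mean((pred - y) ** 2)


params = {'w': jnp.array(1.0), 'b': jnp.array(1.0)}
x = jnp.array([1.0, 3.0])
y = jnp.array([0.0, 0.0])

# jax.grad differentiates w.r.t. the first argument (params) and returns a
# dict with the same keys: one gradient per parameter.
grads = jax.grad(objective)(params, x, y)
print(sorted(grads.keys()))  # ['b', 'w']
# grads['w'] is 14.0 and grads['b'] is 6.0.
```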

Instead of using jax.grad() directly, FedJAX also provides a convenient fedjax.model_grad() which computes the gradient of a model with respect to the averaged fedjax.model_per_example_loss().

model_grads = fedjax.model_grad(model)(params, example, None)
opt_state, params = optimizer.apply(model_grads, opt_state, params)
train_objective(params, example)
DeviceArray(3.9482572, dtype=float32)
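The relationship between apply_for_train(), train_loss(), fedjax.model_per_example_loss(), and fedjax.model_grad() can be sketched in a few lines. `ToyModel`, `per_example_loss_sketch`, and `model_grad_sketch` are hypothetical stand-ins written for this illustration; the real fedjax implementations may differ in details. The idea is to compose the two model functions into per-example losses, average them, and differentiate with jax.grad.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyModel(NamedTuple):
  """Hypothetical stand-in holding only the two training functions."""
  apply_for_train: Callable
  train_loss: Callable


def per_example_loss_sketch(model):
  # Compose apply_for_train() and train_loss() into one function.
  def fn(params, batch, rng):
    output = model.apply_for_train(params, batch, rng)
    return model.train_loss(batch, output)
  return fn


def model_grad_sketch(model):
  per_example = per_example_loss_sketch(model)
  # Average the per-example losses, then differentiate w.r.t. params.
  return jax.grad(lambda p, b, r: jnp.mean(per_example(p, b, r)))


toy = ToyModel(
    apply_for_train=lambda params, batch, rng: batch['x'] * params['w'],
    train_loss=lambda batch, output: (output - batch['y']) ** 2)

params = {'w': jnp.array(2.0)}
batch = {'x': jnp.array([1.0, 2.0]), 'y': jnp.array([1.0, 1.0])}
print(per_example_loss_sketch(toy)(params, batch, None))  # [1. 9.]
print(model_grad_sketch(toy)(params, batch, None)['w'])   # 7.0
```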

Let’s wrap everything into a single JIT compiled function and train a few more steps, and evaluate again.

@jax.jit
def train_step(example, opt_state, params):
  grads = jax.grad(train_objective)(params, example)
  return optimizer.apply(grads, opt_state, params)

for example in itertools.islice(train_batches(), 5000):
  opt_state, params = train_step(example, opt_state, params)

print('eval_test', fedjax.evaluate_model(model, params, batched_test_data))
print('eval_train', fedjax.evaluate_model(model, params, batched_train_data))
eval_test {'accuracy': DeviceArray(0.6152344, dtype=float32), 'loss': DeviceArray(1.5562292, dtype=float32)}
eval_train {'accuracy': DeviceArray(0.59765625, dtype=float32), 'loss': DeviceArray(1.6278805, dtype=float32)}
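The pure-function style is what makes a whole training step compilable with jax.jit. The self-contained toy below is an illustration on synthetic linear-regression data, not the EMNIST setup above. It jits a gradient-plus-update step; because every call sees inputs of the same shapes, the step is traced and compiled once and then reused inside the Python loop.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1


def loss(params, x, y):
  return jnp.mean((x @ params['w'] - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
  # Gradient computation and parameter update are compiled together by XLA.
  grads = jax.grad(loss)(params, x, y)
  return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g,
                                params, grads)


x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])  # Synthetic targets from known weights.
params = {'w': jnp.zeros(3)}

initial = float(loss(params, x, y))
for _ in range(200):
  params = sgd_step(params, x, y)  # Same shapes every call: compiled once.
final = float(loss(params, x, y))
print(final < initial)  # True
```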

Building a custom model

fedjax.Model was designed with customization in mind. In this section, we will discuss how each part of a fedjax.Model can be customized, starting with the training loss.

Training loss

Because train_loss() is separate from apply_for_train(), it is easy to switch to a different loss function.

def hinge_loss(example, output):
  label = example['y']
  num_classes = output.shape[-1]
  mask = jax.nn.one_hot(label, num_classes)
  label_score = jnp.sum(output * mask, axis=-1)
  best_score = jnp.max(output + 1 - mask, axis=-1)
  return best_score - label_score

hinge_model = model.replace(train_loss=hinge_loss)
fedjax.model_per_example_loss(hinge_model)(params, example, None)
DeviceArray([4.306656  , 0.        , 0.        , 0.4375435 , 0.96986485,
             0.        , 0.3052401 , 1.3918507 ], dtype=float32)
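To see what hinge_loss computes, it helps to check it on hand-sized inputs. For each example it takes the maximum over classes of (score + 1) for wrong classes and (score) for the true class, then subtracts the true-class score. The loss is therefore zero exactly when the true class beats every other class by a margin of at least 1. The scores and labels below are made up for illustration.

```python
import jax
import jax.numpy as jnp


def hinge_loss(example, output):
  # Same multi-class hinge loss as defined above.
  label = example['y']
  num_classes = output.shape[-1]
  mask = jax.nn.one_hot(label, num_classes)
  label_score = jnp.sum(output * mask, axis=-1)
  best_score = jnp.max(output + 1 - mask, axis=-1)
  return best_score - label_score


output = jnp.array([[2.0, 1.0, 0.0],   # true class 0 wins by a margin of 1
                    [2.0, 1.0, 0.0],   # true class 2 loses by 2
                    [2.0, 1.5, 0.0]])  # true class 1 loses by 0.5
example = {'y': jnp.array([0, 2, 1])}
print(hinge_loss(example, output))  # [0.  3.  1.5]
```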

Evaluation metrics

We have already seen that the eval_metrics field of a fedjax.Model tells the model what metrics to evaluate. eval_metrics is a mapping from metric names to fedjax.metrics.Metric objects. A fedjax.metrics.Metric object tells us how to calculate a metric’s value from multiple batches of examples. Like fedjax.Model, a fedjax.metrics.Metric is stateless.

To change which metrics are evaluated, or the names they are reported under, simply specify a different mapping.

only_accuracy = model.replace(
    eval_metrics={'accuracy': fedjax.metrics.Accuracy()})
fedjax.evaluate_model(only_accuracy, params, batched_test_data)
{'accuracy': DeviceArray(0.6152344, dtype=float32)}

There are already some concrete Metrics in fedjax.metrics. It is also easy to implement a new one. You can read more about how to implement a Metric in its own introduction.
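One way a stateless metric can aggregate over many batches is to return sufficient statistics (here, counts) per batch and merge them, instead of averaging per-batch results. `AccuracyStat` and `accuracy_on_batch` below are hypothetical names for this sketch only; they are not the fedjax.metrics.Metric interface, which is described in the Metric introduction.

```python
from typing import NamedTuple

import jax.numpy as jnp


class AccuracyStat(NamedTuple):
  """Hypothetical sufficient statistic: running counts, not a running mean."""
  correct: float
  total: float

  def merge(self, other):
    # Combining two stats returns a new stat; neither input is modified.
    return AccuracyStat(self.correct + other.correct, self.total + other.total)

  def result(self):
    return self.correct / self.total


def accuracy_on_batch(labels, scores):
  """Hypothetical stateless evaluation of one batch."""
  correct = jnp.sum(jnp.argmax(scores, axis=-1) == labels)
  return AccuracyStat(float(correct), float(labels.shape[0]))


stat1 = accuracy_on_batch(
    jnp.array([0, 1, 1, 0]),
    jnp.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.6, 0.4]]))  # 3 of 4
stat2 = accuracy_on_batch(jnp.array([1]), jnp.array([[0.1, 0.9]]))  # 1 of 1
print(stat1.merge(stat2).result())              # 0.8, i.e. 4 correct of 5
print((stat1.result() + stat2.result()) / 2)    # 0.875, biased toward stat2
```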

The bit of fedjax.Model that is directly relevant to evaluation is apply_for_eval(). The relation between apply_for_eval() and an evaluation metric is similar to that between apply_for_train() and train_loss(): apply_for_eval(params, example) takes the model parameters and a batch of examples (notice there is no randomness in evaluation so we don’t need a PRNGKey), and produces some prediction that evaluation metrics can consume. In our example, the outputs from apply_for_eval() and apply_for_train() are identical, but they don’t have to be.

jnp.all(
    model.apply_for_train(params, example, None) == model.apply_for_eval(
        params, example))
DeviceArray(True, dtype=bool)

What apply_for_eval() needs to produce depends on which fedjax.metrics.Metric objects are used for evaluation. In our case, we are using fedjax.metrics.Accuracy and fedjax.metrics.CrossEntropyLoss, which have similar requirements on their inputs:

  • They both need to know the true label from the example, using a target_key that defaults to "y".

  • They both need to know the predicted scores from apply_for_eval(), customizable as pred_key. If pred_key is None, apply_for_eval() should return just a vector of per-class scores; otherwise pred_key can be a string key, and apply_for_eval() should return a mapping (e.g. dict) that maps the key to a vector of per-class scores.

Accuracy(target_key='y', pred_key=None)
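The target_key and pred_key convention can be sketched with two small helpers. `select_scores` and `accuracy_sketch` are hypothetical illustrations of the lookup rule described above, not fedjax.metrics code. When pred_key is None the prediction is used as the score array directly; otherwise the scores are looked up in a mapping, which lets apply_for_eval() return extra outputs alongside them.

```python
import jax.numpy as jnp


def select_scores(prediction, pred_key=None):
  """Hypothetical helper mirroring the pred_key convention."""
  if pred_key is None:
    # apply_for_eval() returned the per-class scores directly.
    return prediction
  # apply_for_eval() returned a mapping; pull the scores out by key.
  return prediction[pred_key]


def accuracy_sketch(example, prediction, target_key='y', pred_key=None):
  scores = select_scores(prediction, pred_key)
  return jnp.mean(jnp.argmax(scores, axis=-1) == example[target_key])


example = {'y': jnp.array([1, 0])}
scores = jnp.array([[0.1, 0.9], [0.8, 0.2]])
print(accuracy_sketch(example, scores))  # 1.0
# The same scores wrapped in a dict with an unrelated extra output.
print(accuracy_sketch(example, {'logits': scores, 'embedding': jnp.zeros((2, 4))},
                      pred_key='logits'))  # 1.0
```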

Neural network architectures

We have now covered all five parts of a fedjax.Model, namely init(), apply_for_train(), apply_for_eval(), train_loss(), and eval_metrics. train_loss() and eval_metrics are easy to customize since they are mostly agnostic to the actual neural network architecture of the model. On the other hand, init(), apply_for_train(), and apply_for_eval() depend on the architecture and must agree with one another: the parameters returned by init() are exactly what the two apply functions consume.

In principle, as long as these three functions meet the interface we have seen so far, they can be used to build a custom model. Let’s try to build a model that uses a multi-layer perceptron and cross entropy loss.

def cross_entropy_loss(example, output):
  label = example['y']
  num_classes = output.shape[-1]
  mask = jax.nn.one_hot(label, num_classes)
  return -jnp.sum(jax.nn.log_softmax(output) * mask, axis=-1)

def mlp_model(num_input_units, num_units, num_classes):

  def mlp_init(rng):
    w0_rng, w1_rng = jax.random.split(rng)
    w0 = jax.random.uniform(w0_rng, [num_input_units, num_units])
    b0 = jnp.zeros([num_units])
    w1 = jax.random.uniform(w1_rng, [num_units, num_classes])
    b1 = jnp.zeros([num_classes])
    return w0, b0, w1, b1

  def mlp_apply(params, batch, rng=None):
    w0, b0, w1, b1 = params
    x = batch['x']
    batch_size = x.shape[0]
    h = jax.nn.relu(x.reshape([batch_size, -1]) @ w0 + b0)
    return h @ w1 + b1

  return fedjax.Model(
      init=mlp_init,
      apply_for_train=mlp_apply,
      apply_for_eval=mlp_apply,
      train_loss=cross_entropy_loss,
      eval_metrics={'accuracy': fedjax.metrics.Accuracy()})

# There are 28*28 input pixels, and 62 classes in EMNIST.
mlp = mlp_model(28 * 28, 128, 62)

@jax.jit
def mlp_train_step(example, opt_state, params):

  # jax.grad turns the scalar objective into a function returning gradients.
  @jax.grad
  def grad_fn(params, example):
    return jnp.mean(fedjax.model_per_example_loss(mlp)(params, example, None))

  grads = grad_fn(params, example)
  return optimizer.apply(grads, opt_state, params)

params = mlp.init(jax.random.PRNGKey(0))
opt_state = optimizer.init(params)
print('eval_test before training:',
      fedjax.evaluate_model(mlp, params, batched_test_data))
for example in itertools.islice(train_batches(), 5000):
  opt_state, params = mlp_train_step(example, opt_state, params)
print('eval_test after training:',
      fedjax.evaluate_model(mlp, params, batched_test_data))
eval_test before training: {'accuracy': DeviceArray(0.05078125, dtype=float32)}
eval_test after training: {'accuracy': DeviceArray(0.4951172, dtype=float32)}

While writing custom neural network architectures from scratch is possible, most of the time, it is much more convenient to use a neural network library such as Haiku or jax.example_libraries.stax. The two functions fedjax.create_model_from_haiku and fedjax.create_model_from_stax can convert a neural network expressed in the respective framework into a fedjax.Model. Let’s build a convolutional network using jax.example_libraries.stax this time.

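def stax_cnn_model(input_shape, num_classes):
  stax_init, stax_apply = stax.serial(
      stax.Conv(
          out_chan=64, filter_shape=(3, 3), strides=(1, 1), padding='SAME'),
      # Assumed completion of the elided layers: ReLU, flatten, dense output.
      stax.Relu, stax.Flatten, stax.Dense(num_classes))
  return fedjax.create_model_from_stax(
      stax_init=stax_init, stax_apply=stax_apply, sample_shape=input_shape,
      train_loss=cross_entropy_loss,
      eval_metrics={'accuracy': fedjax.metrics.Accuracy()})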

stax_cnn = stax_cnn_model([-1, 28, 28, 1], 62)

@jax.jit
def stax_cnn_train_step(example, opt_state, params):

  # jax.grad turns the scalar objective into a function returning gradients.
  @jax.grad
  def grad_fn(params, example):
    return jnp.mean(
        fedjax.model_per_example_loss(stax_cnn)(params, example, None))

  grads = grad_fn(params, example)
  return optimizer.apply(grads, opt_state, params)

params = stax_cnn.init(jax.random.PRNGKey(0))
opt_state = optimizer.init(params)
print('eval_test before training:',
      fedjax.evaluate_model(stax_cnn, params, batched_test_data))
for example in itertools.islice(train_batches(), 1000):
  opt_state, params = stax_cnn_train_step(example, opt_state, params)
print('eval_test after training:',
      fedjax.evaluate_model(stax_cnn, params, batched_test_data))
eval_test before training: {'accuracy': DeviceArray(0.03076172, dtype=float32)}
eval_test after training: {'accuracy': DeviceArray(0.72558594, dtype=float32)}
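To see what fedjax.create_model_from_stax() has to work with, it helps to look at the (init_fun, apply_fun) pair that stax.serial returns. Roughly, init_fun plays the role of init() and apply_fun the role of apply_for_train()/apply_for_eval(). The small network below is an arbitrary illustration, not the architecture trained above. Note that init_fun needs an input shape (with -1 for the batch dimension), which is presumably why create_model_from_stax takes a sample_shape.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# An arbitrary small network chosen for illustration.
init_fun, apply_fun = stax.serial(
    stax.Conv(out_chan=8, filter_shape=(3, 3), strides=(1, 1), padding='SAME'),
    stax.Relu,
    stax.Flatten,
    stax.Dense(62),
)

# init_fun takes a key and an input shape; it returns the output shape and
# the parameters, with -1 carried through as the batch dimension.
output_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 28, 28, 1))
print(output_shape)  # (-1, 62)

# apply_fun maps parameters and a concrete batch to per-class scores.
logits = apply_fun(params, jnp.zeros((5, 28, 28, 1)))
print(logits.shape)  # (5, 62)
```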

Recap

In this chapter, we have covered the following:

  • Components of fedjax.Model: init(), apply_for_train(), apply_for_eval(), train_loss(), and eval_metrics.

  • Optimizers in fedjax.optimizers.

  • Standard centralized learning with a fedjax.Model.

  • Specifying evaluation metrics in eval_metrics.

  • Building a custom fedjax.Model.