# Working with models in FedJAX

In this chapter, we will learn about `fedjax.Model`. This notebook assumes you have already finished the “Datasets” chapter. We first give an overview of centralized training and evaluation with `fedjax.Model`, and then describe how to add new neural architectures and specify additional evaluation metrics.

```
# Uncomment these to install fedjax.
# !pip install fedjax
# !pip install --upgrade git+https://github.com/google/fedjax.git
```

```
import itertools
import jax
import jax.numpy as jnp
from jax.example_libraries import stax
import fedjax
```

## Centralized training & evaluation with `fedjax.Model`

Most federated learning algorithms are built upon common components from standard centralized learning. `fedjax.Model` holds these common components. In centralized learning, we are mostly concerned with two tasks:

- Training: We want to optimize our model parameters on the training dataset.
- Evaluation: We want to know the values of evaluation metrics (e.g. accuracy) of the current model parameters on a test dataset.

Let’s first see how we can carry out these two tasks on the EMNIST dataset with `fedjax.Model`.

```
# Load train/test splits of the EMNIST dataset.
train, test = fedjax.datasets.emnist.load_data()
# As a start, let's simply use a logistic regression model.
model = fedjax.models.emnist.create_logistic_model()
```

### Random initialization, the JAX way

To start training, we need some randomly initialized parameters. In JAX, pseudo-random number generation works slightly differently from NumPy: the random state is passed around explicitly as a key instead of being held globally. For now, it is sufficient to know that we call `jax.random.PRNGKey()` to seed the random number generator. JAX has a detailed introduction on this topic, if you are interested.

To create the initial model parameters, we simply call `fedjax.Model.init()` with a `PRNGKey`.

```
params_rng = jax.random.PRNGKey(0)
params = model.init(params_rng)
```

```
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
```

Here are our initial model parameters. With the same `PRNGKey`, we will always get the same random initialization. There are 2 parameters in our model, the weights `w` and the bias `b`. They are organized into a `FlatMapping`, but in general any PyTree can be used to store model parameters.

```
params
```

```
FlatMapping({
'linear': FlatMapping({
'b': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0.], dtype=float32),
'w': DeviceArray([[-0.04067196, 0.02348138, -0.0214883 , ..., 0.01055492,
-0.06988288, -0.02952586],
[-0.03985253, -0.03804361, 0.01401524, ..., 0.02281437,
-0.01771905, 0.06676884],
[ 0.00098182, -0.00844628, 0.01303554, ..., -0.05299249,
0.01777634, -0.0006488 ],
...,
[-0.05691862, 0.05192501, 0.01588603, ..., 0.0157204 ,
-0.01854135, 0.00297953],
[ 0.01680706, 0.05579231, 0.0459589 , ..., 0.01990358,
-0.01944044, -0.01710149],
[-0.00880739, 0.04229043, 0.00998938, ..., -0.00633441,
-0.04824542, 0.01395545]], dtype=float32),
}),
})
```
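The determinism noted above is easy to check in plain JAX. The sketch below is not FedJAX code: `toy_init` and its parameter names are hypothetical, chosen only to mimic a small linear model whose parameters live in a nested dict (one kind of PyTree).

```python
import jax
import jax.numpy as jnp


def toy_init(rng, num_inputs=4, num_classes=3):
  """Hypothetical initializer returning a nested-dict PyTree of parameters."""
  w = 0.05 * jax.random.normal(rng, (num_inputs, num_classes))
  b = jnp.zeros((num_classes,))
  return {'linear': {'w': w, 'b': b}}


key = jax.random.PRNGKey(0)
params_a = toy_init(key)
params_b = toy_init(key)  # Same key, so the same "random" values.

# Splitting a key produces new keys that yield different values.
key_1, key_2 = jax.random.split(key)
params_c = toy_init(key_1)
params_d = toy_init(key_2)

same = bool(jnp.array_equal(params_a['linear']['w'], params_b['linear']['w']))
different = not bool(
    jnp.array_equal(params_c['linear']['w'], params_d['linear']['w']))

# PyTree utilities work on any nesting of dicts, lists, and tuples.
shapes = jax.tree_util.tree_map(lambda x: x.shape, params_a)
print(same, different, shapes)
```

Because nothing is hidden in global state, reproducing an initialization only requires reusing the key.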

### Evaluating model parameters

Before we start training, let’s first see how our initial parameters fare on the train and test sets. Unsurprisingly, they do not do very well. We evaluate using `fedjax.evaluate_model()`, which takes a model, its parameters, and batched datasets. As noted in the dataset tutorial, we batch using `fedjax.padded_batch_federated_data()` for efficiency. `fedjax.padded_batch_federated_data()` is very similar to `fedjax.ClientDataset.padded_batch()` but operates over the entire federated dataset.

```
# We select the first 16 batches using itertools.islice.
batched_test_data = list(itertools.islice(
    fedjax.padded_batch_federated_data(test, batch_size=128), 16))
batched_train_data = list(itertools.islice(
    fedjax.padded_batch_federated_data(train, batch_size=128), 16))
print('eval_test', fedjax.evaluate_model(model, params, batched_test_data))
print('eval_train', fedjax.evaluate_model(model, params, batched_train_data))
```

```
eval_test {'accuracy': DeviceArray(0.01757812, dtype=float32), 'loss': DeviceArray(4.1253214, dtype=float32)}
eval_train {'accuracy': DeviceArray(0.02490234, dtype=float32), 'loss': DeviceArray(4.116228, dtype=float32)}
```

How does our model know what evaluation metrics to report? It is simply specified in the `eval_metrics` field. We will discuss evaluation metrics in more detail later.

```
model.eval_metrics
```

```
{'accuracy': Accuracy(target_key='y', pred_key=None),
'loss': CrossEntropyLoss(target_key='y', pred_key=None)}
```

Since `fedjax.evaluate_model()` simply takes a stream of batches, we can also use it to evaluate multiple clients.

```
for client_id, dataset in itertools.islice(test.clients(), 4):
  print(
      client_id,
      fedjax.evaluate_model(model, params,
                            dataset.padded_batch(batch_size=128)))
```

```
b'002d084c082b8586:f0185_23' {'accuracy': DeviceArray(0.05, dtype=float32), 'loss': DeviceArray(4.1247168, dtype=float32)}
b'005fdad281234bc0:f0151_02' {'accuracy': DeviceArray(0.09375, dtype=float32), 'loss': DeviceArray(4.093891, dtype=float32)}
b'014c177da5b15a39:f1565_04' {'accuracy': DeviceArray(0., dtype=float32), 'loss': DeviceArray(4.127692, dtype=float32)}
b'0156df0c34a25944:f3772_10' {'accuracy': DeviceArray(0.05263158, dtype=float32), 'loss': DeviceArray(4.1521378, dtype=float32)}
```
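Evaluating over a stream of batches means a metric has to be accumulated across batches rather than averaged batch by batch, and padded examples must not count. The sketch below illustrates that idea in plain JAX. It is not how `fedjax.evaluate_model()` is implemented, and the `'mask'` and `'scores'` keys and the `toy_accuracy` helper are assumptions made purely for illustration.

```python
import jax.numpy as jnp


def toy_accuracy(batches):
  """Accumulates correct/total counts over batches, skipping padded rows."""
  correct = 0.0
  total = 0.0
  for batch in batches:
    pred = jnp.argmax(batch['scores'], axis=-1)
    hit = (pred == batch['y']).astype(jnp.float32)
    correct += float(jnp.sum(hit * batch['mask']))
    total += float(jnp.sum(batch['mask']))
  return correct / total


batches = [
    # A full batch of 2 real examples; both are predicted correctly.
    {'scores': jnp.array([[2.0, 0.0], [0.0, 1.0]]),
     'y': jnp.array([0, 1]),
     'mask': jnp.array([1.0, 1.0])},
    # 1 real example (predicted wrongly) followed by 1 padded row.
    {'scores': jnp.array([[3.0, 0.0], [0.0, 0.0]]),
     'y': jnp.array([1, 0]),
     'mask': jnp.array([1.0, 0.0])},
]
accuracy = toy_accuracy(batches)
print(accuracy)
```

Two of the three real examples are correct, so the result is 2/3; the padded row is ignored even though its prediction happens to match its placeholder label.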

### The training objective

To train our model, we need two things: the objective function to minimize and an optimizer.

`fedjax.Model` contains two functions that can be used to arrive at the training objective:

- `apply_for_train(params, batch_example, rng)` takes the current model parameters, a batch of examples, and a `PRNGKey`, and returns some output.
- `train_loss(batch_example, train_output)` translates the output of `apply_for_train()` into a vector of per-example loss values.

In our example model, `apply_for_train()` produces a score for each class and `train_loss()` is simply the cross entropy loss. `apply_for_train()` in this case does not make use of a `PRNGKey`, so we can pass `None` instead for convenience. A different `apply_for_train()` might actually make use of the `PRNGKey`, for tasks such as dropout.

```
# train_batches() returns an infinite stream of shuffled batches of examples.
def train_batches():
  return fedjax.shuffle_repeat_batch_federated_data(
      train,
      batch_size=8,
      client_buffer_size=16,
      example_buffer_size=1024,
      seed=0)

# We obtain the first batch by using the `next` function.
example = next(train_batches())
output = model.apply_for_train(params, example, None)
per_example_loss = model.train_loss(example, output)
output.shape, per_example_loss
```

```
((8, 62), DeviceArray([4.0337796, 4.046219 , 3.9447758, 3.933005 , 4.116893 ,
4.209843 , 4.060939 , 4.19899 ], dtype=float32))
```

Note that `output` holds the per-example predictions and has shape (8, 62), where 8 is the batch size and 62 is the number of classes. Alternatively, we can use `fedjax.model_per_example_loss()` to get a function that gives us the same result. `fedjax.model_per_example_loss()` is a convenience function that does exactly what we just did.

```
per_example_loss_fn = fedjax.model_per_example_loss(model)
per_example_loss_fn(params, example, None)
```

```
DeviceArray([4.0337796, 4.046219 , 3.9447758, 3.933005 , 4.116893 ,
4.209843 , 4.060939 , 4.19899 ], dtype=float32)
```
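For intuition about the numbers above, here is a minimal sketch of a per-example cross entropy loss in plain JAX. It assumes the model output is a matrix of unnormalized class scores; `toy_cross_entropy` is a hypothetical helper, not the FedJAX implementation. With all-zero scores every class is equally likely, so each per-example loss is log(62) ≈ 4.127, close to the values printed for the freshly initialized model.

```python
import jax
import jax.numpy as jnp


def toy_cross_entropy(labels, scores):
  """Returns one loss value per example (a vector, not a scalar)."""
  log_probs = jax.nn.log_softmax(scores, axis=-1)
  one_hot = jax.nn.one_hot(labels, scores.shape[-1])
  return -jnp.sum(log_probs * one_hot, axis=-1)


labels = jnp.array([3, 17, 42, 61])
scores = jnp.zeros((4, 62))  # Uninformative scores for 62 classes.
losses = toy_cross_entropy(labels, scores)
print(losses.shape, losses)
```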

The training objective is a scalar, so why does `train_loss()` return a vector of per-example loss values? First of all, the training objective in most cases is just the average of the per-example loss values, so arriving at the final training objective isn’t hard. Moreover, in certain algorithms, we not only use the train loss over a single batch of examples for a stochastic training step, but also need to estimate the average train loss over an entire (client) dataset. Having the per-example loss values there is instrumental in obtaining the correct estimate when the batch sizes may vary.

```
def train_objective(params, example):
  return jnp.mean(per_example_loss_fn(params, example, None))

train_objective(params, example)
```

```
DeviceArray(4.0680556, dtype=float32)
```
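The point about varying batch sizes can be made concrete with a little arithmetic. In this sketch (made-up loss values, plain JAX), averaging the per-batch means gives a different answer from averaging over all examples, which is why keeping per-example values matters.

```python
import jax.numpy as jnp

# Made-up per-example losses for two batches of different sizes.
batch_losses = [jnp.array([1.0, 1.0, 1.0, 1.0]), jnp.array([6.0])]

# Naive estimate: average the per-batch means (overweights the small batch).
mean_of_means = sum(
    float(jnp.mean(losses)) for losses in batch_losses) / len(batch_losses)

# Correct estimate: sum all per-example losses, divide by the example count.
total_loss = sum(float(jnp.sum(losses)) for losses in batch_losses)
num_examples = sum(losses.shape[0] for losses in batch_losses)
dataset_mean = total_loss / num_examples

print(mean_of_means, dataset_mean)  # 3.5 2.0
```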

### Optimizers

With the training objective at hand, we just need an optimizer to find some good model parameters that minimize it.

There are many optimizer implementations for JAX, and FedJAX doesn’t force one choice over any other. Instead, FedJAX provides a simple `fedjax.optimizers.Optimizer` interface so a new optimizer implementation can be wrapped. For convenience, FedJAX provides some common optimizers wrapped from optax.

```
optimizer = fedjax.optimizers.adam(1e-3)
```

An optimizer is simply a pair of two functions:

- `init(params)` returns the initial optimizer state, such as initial values for accumulators of gradients.
- `apply(grads, opt_state, params)` applies the gradients to update the current optimizer state and model parameters.

Instead of modifying `opt_state` or `params`, `apply()` returns a new pair of optimizer state and model parameters. In JAX, it is common to express computations in this stateless, mutation-free style, often referred to as functional programming or pure functions. The purity of functions is crucial to many features in JAX, so it is always good practice to write functions that do not modify their inputs. You have probably also noticed that all the functions of `fedjax.Model` we have seen so far do not modify the model object itself (for example, `init()` returns model parameters instead of setting some attribute of `model`; `apply_for_train()` takes model parameters as an input argument, instead of getting them from `model`). FedJAX does this to keep all functions pure.
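To make the two-function interface concrete, here is a sketch of plain gradient descent written in the same shape. `sgd_init` and `sgd_apply` are hypothetical names and are not wrapped in `fedjax.optimizers.Optimizer`; using a step counter as the optimizer state is just an assumption for illustration.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1


def sgd_init(params):
  """Plain SGD needs no accumulators; we only track a step count."""
  del params  # Unused.
  return 0


def sgd_apply(grads, opt_state, params):
  """Returns a NEW (opt_state, params) pair; the inputs are left untouched."""
  new_params = jax.tree_util.tree_map(
      lambda p, g: p - LEARNING_RATE * g, params, grads)
  return opt_state + 1, new_params


params = {'w': jnp.array([1.0, -2.0]), 'b': jnp.array(0.5)}
grads = {'w': jnp.array([10.0, 10.0]), 'b': jnp.array(-5.0)}

opt_state = sgd_init(params)
new_opt_state, new_params = sgd_apply(grads, opt_state, params)
print(new_opt_state, new_params)
print(params)  # The original parameters are unchanged.
```

Since `sgd_apply` builds new values rather than writing into its arguments, the caller decides whether to keep the old parameters, the new ones, or both.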

However, in the top-level training loop, it is fine to mutate state since we are not in a function that may be transformed by JAX. Let’s run our first training step, which should result in a slight decrease in the objective on the same batch of examples.

To obtain the gradients, we use `jax.grad()`, which returns the gradient function. More details about `jax.grad()` can be found in the JAX documentation.

```
opt_state = optimizer.init(params)
grads = jax.grad(train_objective)(params, example)
opt_state, params = optimizer.apply(grads, opt_state, params)
train_objective(params, example)
```

```
DeviceArray(4.0080366, dtype=float32)
```
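As a small standalone illustration of `jax.grad()`, the sketch below differentiates a toy squared-error objective (made up for this example) with respect to a dict of parameters. By default the gradient is taken with respect to the first argument and has the same PyTree structure as that argument.

```python
import jax
import jax.numpy as jnp


def toy_objective(params, x, y):
  """Mean squared error of a linear model; returns a scalar."""
  pred = x @ params['w'] + params['b']
  return jnp.mean((pred - y) ** 2)


params = {'w': jnp.array([1.0, 1.0]), 'b': jnp.array(0.0)}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([3.0, 5.0])

# jax.grad returns a function with the same signature as toy_objective.
grads = jax.grad(toy_objective)(params, x, y)
print(grads)
```

Here the predictions are `[3, 7]`, so only the second example contributes an error, giving a gradient of `[6, 8]` for `w` and `2` for `b`.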

Instead of using `jax.grad()` directly, FedJAX also provides a convenient `fedjax.model_grad()`, which computes the gradient of a model with respect to the averaged `fedjax.model_per_example_loss()`.

```
model_grads = fedjax.model_grad(model)(params, example, None)
# Apply the freshly computed gradients (not the stale `grads` from above).
opt_state, params = optimizer.apply(model_grads, opt_state, params)
train_objective(params, example)
```

```
DeviceArray(3.9482572, dtype=float32)
```

Let’s wrap everything into a single JIT-compiled function, train for a few more steps, and evaluate again.

```
@jax.jit
def train_step(example, opt_state, params):
  grads = jax.grad(train_objective)(params, example)
  return optimizer.apply(grads, opt_state, params)

for example in itertools.islice(train_batches(), 5000):
  opt_state, params = train_step(example, opt_state, params)

print('eval_test', fedjax.evaluate_model(model, params, batched_test_data))
print('eval_train', fedjax.evaluate_model(model, params, batched_train_data))
```

```
eval_test {'accuracy': DeviceArray(0.6152344, dtype=float32), 'loss': DeviceArray(1.5562292, dtype=float32)}
eval_train {'accuracy': DeviceArray(0.59765625, dtype=float32), 'loss': DeviceArray(1.6278805, dtype=float32)}
```
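The same pattern, a pure JIT-compiled step function driven by an ordinary Python loop, works outside FedJAX as well. Here is a minimal sketch on a made-up linear regression problem, using plain gradient descent instead of `fedjax.optimizers`; the data, learning rate, and step count are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


def objective(params, batch):
  pred = batch['x'] @ params['w']
  return jnp.mean((pred - batch['y']) ** 2)


@jax.jit
def train_step(batch, params):
  # Pure function: computes gradients and returns updated parameters.
  grads = jax.grad(objective)(params, batch)
  return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


# Synthetic data generated from known weights.
true_w = jnp.array([2.0, -1.0])
x = jax.random.normal(jax.random.PRNGKey(0), (64, 2))
batch = {'x': x, 'y': x @ true_w}

params = {'w': jnp.zeros(2)}
initial_loss = float(objective(params, batch))
# Rebinding `params` at the top level is fine; only train_step must be pure.
for _ in range(200):
  params = train_step(batch, params)
final_loss = float(objective(params, batch))
print(initial_loss, final_loss)
```

The first call to `train_step` triggers compilation; later calls with inputs of the same shapes reuse the compiled code.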

## Building a custom model

`fedjax.Model` was designed with customization in mind. In this section, we will discuss how each part of a `fedjax.Model` can be customized, starting with the training loss.

### Training loss

Because `train_loss()` is separate from `apply_for_train()`, it is easy to switch to a different loss function.

```
def hinge_loss(example, output):
  label = example['y']
  num_classes = output.shape[-1]
  mask = jax.nn.one_hot(label, num_classes)
  label_score = jnp.sum(output * mask, axis=-1)
  best_score = jnp.max(output + 1 - mask, axis=-1)
  return best_score - label_score

hinge_model = model.replace(train_loss=hinge_loss)
fedjax.model_per_example_loss(hinge_model)(params, example, None)
```

```
DeviceArray([4.306656 , 0. , 0. , 0.4375435 , 0.96986485,
0. , 0.3052401 , 1.3918507 ], dtype=float32)
```
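To see what this loss computes, the sketch below copies `hinge_loss` and applies it to two hand-made score vectors (the scores are invented for illustration). Every class other than the true one gets a margin of 1 added, so the loss is zero only when the true class beats all others by at least that margin.

```python
import jax
import jax.numpy as jnp


def hinge_loss(example, output):
  label = example['y']
  num_classes = output.shape[-1]
  mask = jax.nn.one_hot(label, num_classes)
  label_score = jnp.sum(output * mask, axis=-1)
  # Add a margin of 1 to every class except the true one.
  best_score = jnp.max(output + 1 - mask, axis=-1)
  return best_score - label_score


scores = jnp.array([[2.0, 0.5, 0.1],
                    [2.0, 0.5, 0.1]])
example = {'y': jnp.array([0, 1])}
losses = hinge_loss(example, scores)
print(losses)
```

For the first row the true class 0 scores 2.0, more than 1 above the runner-up, so the loss is 0. For the second row the true class 1 scores 0.5 while class 0 reaches 2.0 + 1, giving a loss of 2.5.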

### Evaluation metrics

We have already seen that the `eval_metrics` field of a `fedjax.Model` tells the model what metrics to evaluate. `eval_metrics` is a mapping from metric names to `fedjax.metrics.Metric` objects. A `fedjax.metrics.Metric` object tells us how to calculate a metric’s value from multiple batches of examples. Like `fedjax.Model`, a `fedjax.metrics.Metric` is stateless.

To customize the metrics to evaluate on, or what names to give to each, simply specify a different mapping.

```
only_accuracy = model.replace(
    eval_metrics={'accuracy': fedjax.metrics.Accuracy()})
fedjax.evaluate_model(only_accuracy, params, batched_test_data)
```

```
{'accuracy': DeviceArray(0.6152344, dtype=float32)}
```

There are already some concrete `Metric`s in `fedjax.metrics`. It is also easy to implement a new one. You can read more about how to implement a `Metric` in its own introduction.

The part of `fedjax.Model` that is directly relevant to evaluation is `apply_for_eval()`. The relation between `apply_for_eval()` and an evaluation metric is similar to that between `apply_for_train()` and `train_loss()`: `apply_for_eval(params, example)` takes the model parameters and a batch of examples (notice there is no randomness in evaluation so we don’t need a `PRNGKey`), and produces some prediction that evaluation metrics can consume. In our example, the outputs from `apply_for_eval()` and `apply_for_train()` are identical, but they don’t have to be.

```
jnp.all(
model.apply_for_train(params, example, None) == model.apply_for_eval(
params, example))
```

```
DeviceArray(True, dtype=bool)
```

What `apply_for_eval()` needs to produce really just depends on what evaluation `fedjax.metrics.Metric`s will be used. In our case, we are using `fedjax.metrics.Accuracy` and `fedjax.metrics.CrossEntropyLoss`. They are similar in their requirements on the inputs:

- They both need to know the true label from the `example`, using a `target_key` that defaults to `"y"`.
- They both need to know the predicted scores from `apply_for_eval()`, customizable as `pred_key`. If `pred_key` is None, `apply_for_eval()` should return just a vector of per-class scores; otherwise `pred_key` can be a string key, and `apply_for_eval()` should return a mapping (e.g. `dict`) that maps the key to a vector of per-class scores.

```
fedjax.metrics.Accuracy()
```

```
Accuracy(target_key='y', pred_key=None)
```
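The role of `target_key` and `pred_key` can be illustrated with a small stand-in. The helper below is hypothetical and only mimics the lookup convention described above; it is not FedJAX's `Accuracy`, and it ignores details such as padding and aggregation across batches.

```python
import jax.numpy as jnp


def toy_batch_accuracy(example, prediction, target_key='y', pred_key=None):
  """Hypothetical helper mimicking the target_key/pred_key convention."""
  target = example[target_key]
  # With pred_key=None the prediction is the score array itself;
  # otherwise it is a mapping and pred_key selects the scores.
  scores = prediction if pred_key is None else prediction[pred_key]
  hits = (jnp.argmax(scores, axis=-1) == target).astype(jnp.float32)
  return float(jnp.mean(hits))


example = {'y': jnp.array([0, 1, 1])}
scores = jnp.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])

# The prediction is passed as a bare array of scores.
acc_array = toy_batch_accuracy(example, scores)
# The prediction is a mapping; 'logits' is an arbitrary key for this example.
acc_mapping = toy_batch_accuracy(
    example, {'logits': scores, 'extra': None}, pred_key='logits')
print(acc_array, acc_mapping)
```

Returning a mapping from `apply_for_eval()` is useful when a model produces several outputs and different metrics need different ones.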

### Neural network architectures

We have now covered all five parts of a `fedjax.Model`, namely `init()`, `apply_for_train()`, `apply_for_eval()`, `train_loss()`, and `eval_metrics`. `train_loss()` and `eval_metrics` are easy to customize since they are mostly agnostic to the actual neural network architecture of the model. `init()`, `apply_for_train()`, and `apply_for_eval()`, on the other hand, are closely related to one another, since all three depend on the architecture.

In principle, as long as these three functions meet the interface we have seen so far, they can be used to build a custom model. Let’s try to build a model that uses a multi-layer perceptron and cross entropy loss.

```
def cross_entropy_loss(example, output):
  label = example['y']
  num_classes = output.shape[-1]
  mask = jax.nn.one_hot(label, num_classes)
  return -jnp.sum(jax.nn.log_softmax(output) * mask, axis=-1)

def mlp_model(num_input_units, num_units, num_classes):

  def mlp_init(rng):
    w0_rng, w1_rng = jax.random.split(rng)
    w0 = jax.random.uniform(w0_rng, [num_input_units, num_units])
    b0 = jnp.zeros([num_units])
    w1 = jax.random.uniform(w1_rng, [num_units, num_classes])
    b1 = jnp.zeros([num_classes])
    return w0, b0, w1, b1

  def mlp_apply(params, batch, rng=None):
    w0, b0, w1, b1 = params
    x = batch['x']
    batch_size = x.shape[0]
    h = jax.nn.relu(x.reshape([batch_size, -1]) @ w0 + b0)
    return h @ w1 + b1

  return fedjax.Model(
      init=mlp_init,
      apply_for_train=mlp_apply,
      apply_for_eval=mlp_apply,
      train_loss=cross_entropy_loss,
      eval_metrics={'accuracy': fedjax.metrics.Accuracy()})

# There are 28*28 input pixels, and 62 classes in EMNIST.
mlp = mlp_model(28 * 28, 128, 62)

@jax.jit
def mlp_train_step(example, opt_state, params):

  @jax.grad
  def grad_fn(params, example):
    return jnp.mean(fedjax.model_per_example_loss(mlp)(params, example, None))

  grads = grad_fn(params, example)
  return optimizer.apply(grads, opt_state, params)

params = mlp.init(jax.random.PRNGKey(0))
opt_state = optimizer.init(params)
print('eval_test before training:',
      fedjax.evaluate_model(mlp, params, batched_test_data))
for example in itertools.islice(train_batches(), 5000):
  opt_state, params = mlp_train_step(example, opt_state, params)
print('eval_test after training:',
      fedjax.evaluate_model(mlp, params, batched_test_data))
```

```
eval_test before training: {'accuracy': DeviceArray(0.05078125, dtype=float32)}
eval_test after training: {'accuracy': DeviceArray(0.4951172, dtype=float32)}
```
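Before wiring custom functions into a `fedjax.Model`, it can help to check them in isolation. The sketch below repeats the `mlp_init`/`mlp_apply` logic as standalone functions and runs a forward pass on a fake all-zero batch (the batch contents and the smaller hidden width are made up) to confirm that the output has one score per class.

```python
import jax
import jax.numpy as jnp


def mlp_init(rng, num_input_units=28 * 28, num_units=16, num_classes=62):
  w0_rng, w1_rng = jax.random.split(rng)
  w0 = jax.random.uniform(w0_rng, [num_input_units, num_units])
  b0 = jnp.zeros([num_units])
  w1 = jax.random.uniform(w1_rng, [num_units, num_classes])
  b1 = jnp.zeros([num_classes])
  return w0, b0, w1, b1


def mlp_apply(params, batch, rng=None):
  w0, b0, w1, b1 = params
  x = batch['x']
  # Flatten each 28x28 image into a vector before the first layer.
  h = jax.nn.relu(x.reshape([x.shape[0], -1]) @ w0 + b0)
  return h @ w1 + b1


params = mlp_init(jax.random.PRNGKey(0))
fake_batch = {
    'x': jnp.zeros([8, 28, 28]),
    'y': jnp.zeros([8], dtype=jnp.int32),
}
scores = mlp_apply(params, fake_batch)
print(scores.shape)  # (8, 62)
```

With all-zero inputs and zero biases the scores are all zero, which makes this a convenient smoke test for shapes rather than a check of model quality.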

While writing custom neural network architectures from scratch is possible, most of the time it is much more convenient to use a neural network library such as Haiku or `jax.example_libraries.stax`. The two functions `fedjax.create_model_from_haiku` and `fedjax.create_model_from_stax` can convert a neural network expressed in the respective framework into a `fedjax.Model`. Let’s build a convolutional network using `jax.example_libraries.stax` this time.

```
def stax_cnn_model(input_shape, num_classes):
  stax_init, stax_apply = stax.serial(
      stax.Conv(
          out_chan=64, filter_shape=(3, 3), strides=(1, 1), padding='SAME'),
      stax.Relu,
      stax.Flatten,
      stax.Dense(256),
      stax.Relu,
      stax.Dense(num_classes),
  )
  return fedjax.create_model_from_stax(
      stax_init=stax_init,
      stax_apply=stax_apply,
      sample_shape=input_shape,
      train_loss=cross_entropy_loss,
      eval_metrics={'accuracy': fedjax.metrics.Accuracy()})

stax_cnn = stax_cnn_model([-1, 28, 28, 1], 62)

@jax.jit
def stax_cnn_train_step(example, opt_state, params):

  @jax.grad
  def grad_fn(params, example):
    return jnp.mean(
        fedjax.model_per_example_loss(stax_cnn)(params, example, None))

  grads = grad_fn(params, example)
  return optimizer.apply(grads, opt_state, params)

params = stax_cnn.init(jax.random.PRNGKey(0))
opt_state = optimizer.init(params)
print('eval_test before training:',
      fedjax.evaluate_model(stax_cnn, params, batched_test_data))
for example in itertools.islice(train_batches(), 1000):
  opt_state, params = stax_cnn_train_step(example, opt_state, params)
print('eval_test after training:',
      fedjax.evaluate_model(stax_cnn, params, batched_test_data))
```

```
eval_test before training: {'accuracy': DeviceArray(0.03076172, dtype=float32)}
eval_test after training: {'accuracy': DeviceArray(0.72558594, dtype=float32)}
```
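For reference, this is how the underlying stax pieces behave on their own, before `fedjax.create_model_from_stax` wraps them. It is a minimal sketch with a deliberately tiny network and a fake all-zero batch: `stax.serial` returns an `(init_fun, apply_fun)` pair, and `init_fun(rng, input_shape)` returns the output shape along with the parameters.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

stax_init, stax_apply = stax.serial(
    stax.Conv(out_chan=4, filter_shape=(3, 3), strides=(1, 1), padding='SAME'),
    stax.Relu,
    stax.Flatten,
    stax.Dense(62),
)

# Inputs are NHWC: (batch, height, width, channels).
input_shape = (2, 28, 28, 1)
output_shape, params = stax_init(jax.random.PRNGKey(0), input_shape)

fake_images = jnp.zeros(input_shape)
scores = stax_apply(params, fake_images)
print(output_shape, scores.shape)
```

The parameters do not depend on the batch dimension, so the same `params` can be applied to batches of any size.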

## Recap

In this chapter, we have covered the following:

- Components of `fedjax.Model`: `init()`, `apply_for_train()`, `apply_for_eval()`, `train_loss()`, and `eval_metrics`.
- Optimizers in `fedjax.optimizers`.
- Standard centralized learning with a `fedjax.Model`.
- Specifying evaluation metrics in `eval_metrics`.
- Building a custom `fedjax.Model`.