Source code for fedjax.algorithms.fed_avg

# Copyright 2021 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Federated averaging implementation using fedjax.core.

This is the more performant implementation that matches what would be used in
:py:mod:`fedjax.algorithms.fed_avg`. The key difference between this and
the basic version is the use of :py:mod:`fedjax.core.for_each_client`.

Communication-Efficient Learning of Deep Networks from Decentralized Data
    H. Brendan McMahan, Eider Moore, Daniel Ramage,
    Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017.

Adaptive Federated Optimization
    Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. ICLR 2021.
"""

from typing import Any, Callable, Mapping, Sequence, Tuple

from fedjax.core import client_datasets
from fedjax.core import dataclasses
from fedjax.core import federated_algorithm
from fedjax.core import federated_data
from fedjax.core import for_each_client
from fedjax.core import optimizers
from fedjax.core import tree_util
from fedjax.core.typing import BatchExample
from fedjax.core.typing import Params
from fedjax.core.typing import PRNGKey

import jax

Grads = Params

def create_train_for_each_client(grad_fn, client_optimizer):
  """Builds client_init, client_step, client_final for for_each_client.

  Args:
    grad_fn: Callable invoked as ``grad_fn(params, batch, rng)`` whose result
      is passed to ``client_optimizer.apply`` as the gradients.
    client_optimizer: Optimizer used for local client training. Must provide
      ``init(params)`` and ``apply(grads, opt_state, params)``, the latter
      returning ``(opt_state, params)``.

  Returns:
    The result of ``for_each_client.for_each_client`` built from the three
    client functions defined here.
  """

  def client_init(server_params, client_rng):
    """Creates the initial per-client state from the server parameters."""
    opt_state = client_optimizer.init(server_params)
    client_step_state = {
        'params': server_params,
        'opt_state': opt_state,
        'rng': client_rng,
    }
    return client_step_state

  def client_step(client_step_state, batch):
    """Runs one local optimizer step on a single batch."""
    # Split so that the key consumed by grad_fn is never reused; the other half
    # is carried forward to the next step.
    rng, use_rng = jax.random.split(client_step_state['rng'])
    grads = grad_fn(client_step_state['params'], batch, use_rng)
    opt_state, params = client_optimizer.apply(grads,
                                               client_step_state['opt_state'],
                                               client_step_state['params'])
    next_client_step_state = {
        'params': params,
        'opt_state': opt_state,
        'rng': rng,
    }
    return next_client_step_state

  def client_final(server_params, client_step_state):
    """Returns the leaf-wise difference server_params - client params."""
    # tree_map accepts multiple trees; tree_multimap has been removed from JAX.
    # NOTE(review): the difference is taken as server minus client so that the
    # averaged delta can be handed to the server optimizer in the gradient
    # position -- confirm the sign convention against callers.
    delta_params = jax.tree_util.tree_map(lambda a, b: a - b,
                                          server_params,
                                          client_step_state['params'])
    return delta_params

  return for_each_client.for_each_client(client_init, client_step, client_final)

@dataclasses.dataclass
class ServerState:
  """Server-side state that is carried from one round to the next.

  Attributes:
    params: Pytree holding the server model parameters.
    opt_state: Pytree holding the state of the server optimizer.
  """
  params: Params
  opt_state: optimizers.OptState
def federated_averaging(
    grad_fn: Callable[[Params, BatchExample, PRNGKey], Grads],
    client_optimizer: optimizers.Optimizer,
    server_optimizer: optimizers.Optimizer,
    client_batch_hparams: client_datasets.ShuffleRepeatBatchHParams
) -> federated_algorithm.FederatedAlgorithm:
  """Constructs the federated averaging algorithm.

  Args:
    grad_fn: A function from (params, batch_example, rng) to gradients.
      This can be created with :func:`fedjax.core.model.model_grad`.
    client_optimizer: Optimizer used for local training on each client.
    server_optimizer: Optimizer used to apply the averaged client update on
      the server.
    client_batch_hparams: Hyperparameters used to batch each client dataset
      for training.

  Returns:
    FederatedAlgorithm
  """
  train_for_each_client = create_train_for_each_client(grad_fn,
                                                       client_optimizer)

  def init(params: Params) -> ServerState:
    """Wraps the initial params with a fresh server optimizer state."""
    return ServerState(params, server_optimizer.init(params))

  def server_update(server_state, mean_delta_params):
    """Feeds the averaged client delta to the server optimizer."""
    new_opt_state, new_params = server_optimizer.apply(
        mean_delta_params, server_state.opt_state, server_state.params)
    return ServerState(new_params, new_opt_state)

  def apply(
      server_state: ServerState,
      clients: Sequence[Tuple[federated_data.ClientId,
                              client_datasets.ClientDataset, PRNGKey]]
  ) -> Tuple[ServerState, Mapping[federated_data.ClientId, Any]]:
    """Runs one round: trains the clients, averages deltas, updates server.

    Args:
      server_state: Server state at the start of the round.
      clients: (client id, client dataset, client rng) triples taking part in
        the round.

    Returns:
      The updated server state and a mapping from client id to that client's
      diagnostics.
    """
    examples_per_client = {cid: len(dataset) for cid, dataset, _ in clients}
    batched_clients = [
        (cid, dataset.shuffle_repeat_batch(client_batch_hparams), client_rng)
        for cid, dataset, client_rng in clients
    ]
    diagnostics = {}
    # The example-weighted mean of the client updates is accumulated one client
    # at a time rather than by collecting every client output first, since the
    # outputs can be prohibitively large for big models.
    # NOTE(review): if `clients` is empty the total stays 0.; the outcome then
    # depends on tree_util.tree_inverse_weight -- confirm it handles that.
    weighted_delta_sum = tree_util.tree_zeros_like(server_state.params)
    total_examples = 0.
    client_outputs = train_for_each_client(server_state.params,
                                           batched_clients)
    for client_id, delta in client_outputs:
      weight = examples_per_client[client_id]
      weighted_delta = tree_util.tree_weight(delta, weight)
      weighted_delta_sum = tree_util.tree_add(weighted_delta_sum,
                                              weighted_delta)
      total_examples += weight
      # The l2 norm of each client update is recorded purely as an example
      # diagnostic; the algorithm itself does not need it.
      diagnostics[client_id] = {
          'delta_l2_norm': tree_util.tree_l2_norm(delta)
      }
    mean_delta = tree_util.tree_inverse_weight(weighted_delta_sum,
                                               total_examples)
    return server_update(server_state, mean_delta), diagnostics

  return federated_algorithm.FederatedAlgorithm(init, apply)