fedjax.algorithms

Federated learning algorithm implementations.

fedjax.algorithms.agnostic_fed_avg

AgnosticFedAvg implementation.

fedjax.algorithms.fed_avg

Federated averaging implementation using fedjax.core.

fedjax.algorithms.hyp_cluster

Federated hypothesis-based clustering implementation.

fedjax.algorithms.mime

Mime implementation.

fedjax.algorithms.mime_lite

Mime Lite implementation.

AgnosticFedAvg

AgnosticFedAvg implementation.

Communication-Efficient Agnostic Federated Averaging

Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh https://arxiv.org/abs/2104.02748

class fedjax.algorithms.agnostic_fed_avg.ServerState(params, opt_state, domain_weights, domain_window)

State of server for AgnosticFedAvg passed between rounds.

params

A pytree representing the server model parameters.

Type

Any

opt_state

A pytree representing the server optimizer state.

Type

Union[jax._src.numpy.lax_numpy.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

domain_weights

Per-domain weights used in the weighted average.

Type

jax._src.numpy.lax_numpy.ndarray

domain_window

Sliding window tracking the number of examples per domain over the last domain_window_size rounds.

Type

List[jax._src.numpy.lax_numpy.ndarray]

replace(**updates)

Returns a new object replacing the specified fields with new values.
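The domain_window attribute is only described as per-domain example counts for the most recent rounds. The following is a minimal plain-Python sketch of how such a fixed-size window could be maintained. It assumes (this page does not confirm it) that the window starts as domain_window_size copies of the initial values and that each round appends one vector of counts. The helpers make_domain_window, push_round_counts, and windowed_totals are hypothetical; fedjax itself stores the window as a list of JAX arrays.

```python
from collections import deque
from typing import Deque, List, Optional, Sequence


def make_domain_window(window_size: int, num_domains: int,
                       init: Optional[Sequence[float]] = None) -> Deque[List[float]]:
    """Hypothetical helper: a window holding `window_size` per-domain count vectors.

    Assumption: the window starts filled with copies of `init` (ones by default,
    mirroring the documented default of init_domain_window).
    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')
    init = [1.0] * num_domains if init is None else [float(v) for v in init]
    if len(init) != num_domains:
        raise ValueError('init must have one entry per domain')
    return deque([list(init) for _ in range(window_size)], maxlen=window_size)


def push_round_counts(window: Deque[List[float]], counts: Sequence[float]) -> None:
    """Appends this round's per-domain counts; deque(maxlen=...) drops the oldest round."""
    window.append([float(c) for c in counts])


def windowed_totals(window: Deque[List[float]]) -> List[float]:
    """Total number of examples per domain over the rounds currently in the window."""
    return [sum(column) for column in zip(*window)]


window = make_domain_window(window_size=2, num_domains=2)
push_round_counts(window, [3, 5])
print(windowed_totals(window))  # [4.0, 6.0]: one initial [1, 1] row plus [3, 5]
```

Using a deque with maxlen keeps the "last N rounds" behavior automatic: appending a new round silently evicts the oldest one.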

fedjax.algorithms.agnostic_fed_avg.agnostic_federated_averaging(per_example_loss, client_optimizer, server_optimizer, client_batch_hparams, domain_batch_hparams, init_domain_weights, domain_learning_rate, domain_algorithm='eg', domain_window_size=1, init_domain_window=None, regularizer=None)

Builds agnostic federated averaging.

Agnostic federated averaging requires input fedjax.core.client_datasets.ClientDataset examples to contain a feature named “domain_id”, which stores the integer domain id in [0, num_domains). For example, for Stack Overflow, each example post can be either a question or an answer, so there are two possible domain ids (question = 0; answer = 1).
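To make the domain_id requirement concrete, here is a plain-Python sketch of the kind of per-domain bookkeeping it enables: summing per-example losses and counting examples by domain. The function per_domain_loss_and_count is hypothetical, and the exact domain metrics fedjax derives from per_example_loss and domain_batch_hparams are not spelled out on this page.

```python
from typing import List, Sequence, Tuple


def per_domain_loss_and_count(losses: Sequence[float],
                              domain_ids: Sequence[int],
                              num_domains: int) -> Tuple[List[float], List[int]]:
    """Hypothetical helper: total loss and example count for each domain id."""
    if len(losses) != len(domain_ids):
        raise ValueError('losses and domain_ids must have the same length')
    loss_sums = [0.0] * num_domains
    counts = [0] * num_domains
    for loss, domain_id in zip(losses, domain_ids):
        # Domain ids must be integers in [0, num_domains).
        if not 0 <= domain_id < num_domains:
            raise ValueError(f'domain_id {domain_id} is outside [0, {num_domains})')
        loss_sums[domain_id] += loss
        counts[domain_id] += 1
    return loss_sums, counts


# Stack Overflow-style example: question = 0, answer = 1.
sums, counts = per_domain_loss_and_count([0.5, 1.0, 2.0], [0, 1, 1], num_domains=2)
print(sums, counts)  # [0.5, 3.0] [1, 2]
```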

Parameters
  • per_example_loss (Callable[[Any, Mapping[str, ndarray], ndarray], ndarray]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the domain metrics computation and gradient descent training.

  • client_optimizer (Optimizer) – Optimizer for local client training.

  • server_optimizer (Optimizer) – Optimizer for server update.

  • client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client datasets for training.

  • domain_batch_hparams (PaddedBatchHParams) – Hyperparameters for batching client datasets for the domain metrics computation.

  • init_domain_weights (Sequence[float]) – Initial weights per domain that must sum to 1.

  • domain_learning_rate (float) – Learning rate for domain weight update.

  • domain_algorithm (str) – Algorithm used to update domain weights each round. One of ‘eg’, ‘none’.

  • domain_window_size (int) – Size of the sliding window that tracks the number of examples per domain over multiple rounds.

  • init_domain_window (Optional[Sequence[float]]) – Initial values for domain window. Defaults to ones.

  • regularizer (Optional[Callable[[Any], ndarray]]) – Optional regularizer that only depends on params.

Return type

FederatedAlgorithm

Returns

FederatedAlgorithm.

Raises

ValueError – If init_domain_weights does not sum to 1 or if init_domain_weights and init_domain_window are unequal lengths.
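The 'eg' option refers to an exponentiated-gradient update of the domain weights. Below is a minimal sketch of the standard multiplicative form, assuming the per-domain losses for the round are already available: each weight is multiplied by exp(domain_learning_rate * loss) and then renormalized, so higher-loss domains gain weight and the weights keep summing to 1. It also mirrors the two documented ValueError checks. How fedjax normalizes the per-domain losses (for example, by the windowed example counts) is not stated here, and both function names are hypothetical.

```python
import math
from typing import List, Optional, Sequence


def check_init_domain_args(init_domain_weights: Sequence[float],
                           init_domain_window: Optional[Sequence[float]] = None,
                           atol: float = 1e-6) -> None:
    """Hypothetical helper mirroring the documented ValueError conditions."""
    if abs(sum(init_domain_weights) - 1.0) > atol:
        raise ValueError('init_domain_weights must sum to 1')
    if init_domain_window is not None and len(init_domain_window) != len(init_domain_weights):
        raise ValueError('init_domain_weights and init_domain_window must have equal lengths')


def update_domain_weights(weights: Sequence[float], domain_losses: Sequence[float],
                          learning_rate: float, algorithm: str = 'eg') -> List[float]:
    """One domain-weight step; 'none' keeps the weights fixed.

    Assumption: `domain_losses` are already-normalized per-domain losses.
    """
    if algorithm == 'none':
        return list(weights)
    if algorithm != 'eg':
        raise ValueError(f'unsupported domain_algorithm: {algorithm!r}')
    # Exponentiated gradient ascent: d/dw_k of sum_k w_k * L_k is L_k.
    scaled = [w * math.exp(learning_rate * loss) for w, loss in zip(weights, domain_losses)]
    total = sum(scaled)
    return [s / total for s in scaled]


check_init_domain_args([0.5, 0.5], [1.0, 1.0])
print(update_domain_weights([0.5, 0.5], [1.0, 0.0], learning_rate=1.0))
# roughly [0.731, 0.269]: the higher-loss domain 0 gains weight
```

The multiplicative form keeps every weight positive and the renormalization keeps them on the probability simplex, which is why the initial weights are required to sum to 1.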

FedAvg

Federated averaging implementation using fedjax.core.

This is the more performant implementation of federated averaging. The key difference from the basic version is its use of fedjax.core.for_each_client.

Communication-Efficient Learning of Deep Networks from Decentralized Data

H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017. https://arxiv.org/abs/1602.05629

Adaptive Federated Optimization

Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. ICLR 2021. https://arxiv.org/abs/2003.00295

class fedjax.algorithms.fed_avg.ServerState(params, opt_state)

State of server passed between rounds.

params

A pytree representing the server model parameters.

Type

Any

opt_state

A pytree representing the server optimizer state.

Type

Union[jax._src.numpy.lax_numpy.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

replace(**updates)

Returns a new object replacing the specified fields with new values.
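The replace method follows the immutable-update pattern: the existing state is never modified, and a copy with the given fields swapped out is returned. A minimal stand-in using the standard library's frozen dataclasses is below. It only imitates the documented behavior and is not fedjax's own ServerState class.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ServerState:
    """Stand-in mimicking fedjax.algorithms.fed_avg.ServerState (not the real class)."""
    params: Any
    opt_state: Any

    def replace(self, **updates: Any) -> 'ServerState':
        # dataclasses.replace builds a new instance; `self` is left unchanged.
        return dataclasses.replace(self, **updates)


state = ServerState(params={'w': 1.0}, opt_state=())
new_state = state.replace(params={'w': 0.5})
print(state.params, new_state.params)  # {'w': 1.0} {'w': 0.5}
```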

fedjax.algorithms.fed_avg.federated_averaging(grad_fn, client_optimizer, server_optimizer, client_batch_hparams)

Builds federated averaging.

Parameters
  • grad_fn (Callable[[Any, Mapping[str, ndarray], ndarray], Any]) – A function from (params, batch_example, rng) to gradients. This can be created with fedjax.core.model.model_grad().

  • client_optimizer (Optimizer) – Optimizer for local client training.

  • server_optimizer (Optimizer) – Optimizer for server update.

  • client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client datasets for training.

Return type

FederatedAlgorithm

Returns

FederatedAlgorithm
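A single FedAvg round, in the form described by McMahan et al. and generalized with a server optimizer by Reddi et al., can be sketched in plain Python on a scalar model. Each client runs local SGD from the server parameters using grad_fn, the clients' parameter deltas are averaged with weights equal to their number of examples, and the server applies the averaged delta as a pseudo-gradient with plain SGD. The scalar squared loss, the helper names, and the example-count weighting are assumptions for illustration, not a description of fedjax internals (which work on pytrees with Optimizer objects and for_each_client).

```python
from typing import Callable, Sequence

GradFn = Callable[[float, Sequence[float]], float]


def mean_squared_grad(w: float, batch: Sequence[float]) -> float:
    """Gradient of mean(0.5 * (w - x) ** 2) over the batch, w.r.t. the scalar w."""
    return sum(w - x for x in batch) / len(batch)


def client_update(params: float, batches: Sequence[Sequence[float]],
                  grad_fn: GradFn, client_lr: float) -> float:
    """Local SGD from the server params; returns delta = initial - final."""
    w = params
    for batch in batches:
        w -= client_lr * grad_fn(w, batch)
    return params - w


def fed_avg_round(params: float, clients: Sequence[Sequence[Sequence[float]]],
                  grad_fn: GradFn, client_lr: float, server_lr: float) -> float:
    """One round: example-weighted mean of client deltas, applied by server SGD."""
    weighted_sum, total_examples = 0.0, 0
    for batches in clients:
        # Assumption: each client is weighted by its number of examples.
        num_examples = sum(len(batch) for batch in batches)
        weighted_sum += num_examples * client_update(params, batches, grad_fn, client_lr)
        total_examples += num_examples
    mean_delta = weighted_sum / total_examples
    # The averaged delta acts as a pseudo-gradient for the server optimizer.
    return params - server_lr * mean_delta


# Two clients, each with a single batch: data {2, 2} and {4}.
new_params = fed_avg_round(0.0, [[[2.0, 2.0]], [[4.0]]], mean_squared_grad,
                           client_lr=1.0, server_lr=1.0)
print(new_params)  # about 2.667, the example-weighted mean of all client data
```

Swapping the final SGD line for Adam or Adagrad applied to mean_delta gives the adaptive variants studied by Reddi et al.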

HypCluster

Federated hypothesis-based clustering implementation.

Three Approaches for Personalization with Applications to Federated Learning

Yishay Mansour, Mehryar Mohri, Jae Ro, Ananda Theertha Suresh https://arxiv.org/abs/2002.10619

class fedjax.algorithms.hyp_cluster.ServerState(cluster_params, opt_states)

State of server for HypCluster passed between rounds.

cluster_params

A list of pytrees representing the server model parameters per cluster.

Type

List[Any]

opt_states

A list of pytrees representing the server optimizer state per cluster.

Type

List[Union[jax._src.numpy.lax_numpy.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]

replace(**updates)

Returns a new object replacing the specified fields with new values.

fedjax.algorithms.hyp_cluster.hyp_cluster(per_example_loss, client_optimizer, server_optimizer, maximization_batch_hparams, expectation_batch_hparams, regularizer=None)

Federated hypothesis-based clustering algorithm.

Parameters
  • per_example_loss (Callable[[Any, Mapping[str, ndarray], ndarray], ndarray]) – A function from (params, batch, rng) to a vector of per-example loss values.

  • client_optimizer (Optimizer) – Client side optimizer.

  • server_optimizer (Optimizer) – Server side optimizer.

  • maximization_batch_hparams (PaddedBatchHParams) – Batching hyperparameters for the maximization step.

  • expectation_batch_hparams (ShuffleRepeatBatchHParams) – Batching hyperparameters for the expectation step.

  • regularizer (Optional[Callable[[Any], ndarray]]) – Optional regularizer.

Return type

FederatedAlgorithm

Returns

A FederatedAlgorithm. A notable difference from other common FederatedAlgorithms such as fed_avg is that init() takes a list of cluster_params, which can be obtained using random_init(), ModelKMeansInitializer, or kmeans_init() in this module.
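One HypCluster round can be sketched as two steps: assign each client to the cluster whose parameters give the lowest average loss on its data, then update every cluster with a FedAvg-style average over only the clients assigned to it. The plain-Python version below does this for scalar cluster parameters with a squared loss and a single full-batch local step per client. These simplifications and all helper names are assumptions for illustration, not fedjax's implementation.

```python
from typing import List, Sequence, Tuple


def avg_loss(w: float, data: Sequence[float]) -> float:
    """Average of 0.5 * (w - x) ** 2 over a client's examples."""
    return sum(0.5 * (w - x) ** 2 for x in data) / len(data)


def assign_clusters(cluster_params: Sequence[float],
                    clients: Sequence[Sequence[float]]) -> List[int]:
    """Index of the lowest-loss cluster for each client."""
    return [min(range(len(cluster_params)), key=lambda k: avg_loss(cluster_params[k], data))
            for data in clients]


def hyp_cluster_round(cluster_params: Sequence[float], clients: Sequence[Sequence[float]],
                      client_lr: float, server_lr: float) -> Tuple[List[float], List[int]]:
    assignments = assign_clusters(cluster_params, clients)
    new_params = list(cluster_params)
    for k, w in enumerate(cluster_params):
        members = [data for data, a in zip(clients, assignments) if a == k]
        if not members:
            continue  # No client chose this cluster; keep its params this round.
        total = sum(len(data) for data in members)
        # Assumption: one full-batch SGD step per client, delta = client_lr * mean(w - x),
        # averaged with weights equal to each client's number of examples.
        mean_delta = sum(client_lr * sum(w - x for x in data) for data in members) / total
        new_params[k] = w - server_lr * mean_delta
    return new_params, assignments


params, assignments = hyp_cluster_round([0.0, 10.0], [[1.0, 1.0], [9.0], [11.0, 12.0]],
                                        client_lr=1.0, server_lr=1.0)
print(assignments)  # [0, 1, 1]
print(params)       # cluster 0 moves to 1.0, cluster 1 to about 10.667
```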

fedjax.algorithms.hyp_cluster.kmeans_init(num_clusters, init_params, clients, trainer, train_batch_hparams, evaluator, eval_batch_hparams, rng)

Initializes cluster params for HypCluster using a k-means++ variant.

See ModelKMeansInitializer for a more convenient initializer from a Model.

Given a set of input clients, we train parameters for each client. The initial cluster parameters are chosen out of this set of client parameters. At a high level, the cluster parameters are selected by choosing the client parameters that are the “furthest away” from each other (greatest difference in loss).

Parameters
  • num_clusters (int) – Desired number of cluster centers.

  • init_params (Any) – Initial params for training each client.

  • clients (Sequence[Tuple[bytes, ClientDataset, ndarray]]) – Clients to train for generating candidate cluster center params.

  • trainer (ClientParamsTrainer) – Client trainer for carrying out client training.

  • train_batch_hparams (ShuffleRepeatBatchHParams) – Batching hyperparameters for client training.

  • evaluator (AverageLossEvaluator) – Average loss evaluator for evaluating clients on the chosen cluster center params.

  • eval_batch_hparams (PaddedBatchHParams) – Batching hyperparameters for client evaluation.

  • rng (ndarray) – RNG used for choosing the initial cluster center.

Return type

List[Any]

Returns

Cluster center params.

Raises

ValueError – If some input arguments are invalid.
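The "furthest away" selection described above resembles greedy farthest-point seeding. A sketch, assuming a precomputed matrix where loss[i][j] is the average loss of client i's data under the parameters trained on client j: pick a random first center, then repeatedly add the not-yet-chosen client whose loss under its best current center is largest. Whether fedjax takes this argmax or samples in proportion to the loss, as classic k-means++ does, is not stated on this page, and select_center_indices is a hypothetical name.

```python
import random
from typing import List, Sequence


def select_center_indices(loss: Sequence[Sequence[float]], num_clusters: int,
                          rng: random.Random) -> List[int]:
    """Greedy farthest-point choice of client indices to use as cluster centers.

    Assumption: loss[i][j] is client i's average loss under client j's trained params.
    """
    num_clients = len(loss)
    if not 1 <= num_clusters <= num_clients:
        raise ValueError('num_clusters must be between 1 and the number of clients')
    centers = [rng.randrange(num_clients)]  # First center chosen at random.
    while len(centers) < num_clusters:
        candidates = [i for i in range(num_clients) if i not in centers]
        # A client's "distance" is its loss under the closest (lowest-loss) chosen center.
        farthest = max(candidates, key=lambda i: min(loss[i][c] for c in centers))
        centers.append(farthest)
    return centers


loss = [[0.1, 0.2, 5.0],
        [0.2, 0.1, 4.0],
        [5.0, 4.0, 0.1]]
print(select_center_indices(loss, num_clusters=2, rng=random.Random(0)))
# Client 2 is always chosen, alongside one of the similar clients 0 and 1.
```

The trained parameters of the selected clients would then serve as the initial cluster_params.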

fedjax.algorithms.hyp_cluster.random_init(num_clusters, init, rng)

Randomly initializes cluster params.

Return type

List[Any]

class fedjax.algorithms.hyp_cluster.ModelKMeansInitializer(model, client_optimizer, regularizer=None)

Initializes cluster params for HypCluster using a k-means++ variant.

This is a thin wrapper for initializing from a Model. See kmeans_init() for a more general version of the initializer and for details of the initialization algorithm.

__init__(model, client_optimizer, regularizer=None)

Initialize self. See help(type(self)) for accurate signature.

cluster_params(num_clusters, rng, clients, train_batch_hparams, eval_batch_hparams)

Return type

List[Any]

class fedjax.algorithms.hyp_cluster.HypClusterEvaluator(model, regularizer=None)

Evaluates cluster params on a Model.

__init__(model, regularizer=None)

Initializes some reusable components.

Because we need to make multiple passes over each client during evaluation, the number of clients that can be evaluated at once is limited. Therefore, multiple calls to evaluate_clients() are needed to evaluate a large federated dataset. We factor out some reusable components so that the same jit-compiled computation can be reused across these calls.

Parameters
  • model (Model) – Model being evaluated.

  • regularizer (Optional[Callable[[Any], ndarray]]) – Optional regularizer.

evaluate_clients(cluster_params, train_clients, test_clients, batch_hparams)

Evaluates each client on the cluster with the best average loss.

Return type

Iterator[Tuple[bytes, Dict[str, ndarray]]]
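A sketch of the evaluation rule, assuming the best cluster for a client is picked by its average loss on the client's train data and that metrics are then reported on the matching test data. Cluster parameters are scalars with a squared loss here, the metric names 'cluster_id' and 'loss' are hypothetical, and the generator yields (client_id, metrics) pairs in the same shape as the documented return type.

```python
from typing import Dict, Iterator, Sequence, Tuple

Client = Tuple[bytes, Sequence[float]]


def avg_loss(w: float, data: Sequence[float]) -> float:
    """Average of 0.5 * (w - x) ** 2 over the examples."""
    return sum(0.5 * (w - x) ** 2 for x in data) / len(data)


def evaluate_clients(cluster_params: Sequence[float], train_clients: Sequence[Client],
                     test_clients: Sequence[Client]) -> Iterator[Tuple[bytes, Dict[str, float]]]:
    for (client_id, train), (test_id, test) in zip(train_clients, test_clients):
        if client_id != test_id:
            raise ValueError('train_clients and test_clients must list the same clients in order')
        # Assumption: choose the cluster on train data, then evaluate it on test data.
        best = min(range(len(cluster_params)), key=lambda k: avg_loss(cluster_params[k], train))
        yield client_id, {'cluster_id': best, 'loss': avg_loss(cluster_params[best], test)}


results = dict(evaluate_clients([0.0, 10.0],
                                [(b'a', [1.0]), (b'b', [9.0])],
                                [(b'a', [2.0]), (b'b', [10.0])]))
print(results)  # {b'a': {'cluster_id': 0, 'loss': 2.0}, b'b': {'cluster_id': 1, 'loss': 0.0}}
```

Because it is a generator, clients are evaluated lazily, which matches the note above about splitting a large federated dataset across several calls.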

Mime

Mime implementation.

Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning

Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606

class fedjax.algorithms.mime.ServerState(params, opt_state)

State of server passed between rounds.

params

A pytree representing the server model parameters.

Type

Any

opt_state

A pytree representing the base optimizer state.

Type

Union[jax._src.numpy.lax_numpy.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

replace(**updates)

Returns a new object replacing the specified fields with new values.

fedjax.algorithms.mime.mime(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None)

Builds mime.

Parameters
  • per_example_loss (Callable[[Any, Mapping[str, ndarray], ndarray], ndarray]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.

  • base_optimizer (Optimizer) – Base optimizer to mimic.

  • client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client datasets for training.

  • grads_batch_hparams (PaddedBatchHParams) – Hyperparameters for batching client dataset for server gradient computation.

  • server_learning_rate (float) – Server learning rate.

  • regularizer (Optional[Callable[[Any], ndarray]]) – Optional regularizer that only depends on params.

Return type

FederatedAlgorithm

Returns

FederatedAlgorithm
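As described in the Mime paper, clients apply the base optimizer with its state frozen for the round and use an SVRG-style corrected gradient: the minibatch gradient at the local iterate, minus the same minibatch's gradient at the server parameters, plus the full-batch gradient at the server parameters. The sketch below uses SGD with momentum as the base optimizer on a scalar squared loss. The example-count weighting, the way server_learning_rate scales the averaged delta, and all helper names are assumptions; the real implementation batches client data with grads_batch_hparams for the server gradient computation.

```python
from typing import Sequence, Tuple

Batches = Sequence[Sequence[float]]


def grad(w: float, batch: Sequence[float]) -> float:
    """Gradient of mean(0.5 * (w - x) ** 2) over the batch."""
    return sum(w - x for x in batch) / len(batch)


def momentum_step(w: float, g: float, m: float, lr: float, beta: float) -> float:
    """SGD-with-momentum update of w that reads, but does not change, state m."""
    return w - lr * ((1 - beta) * g + beta * m)


def mime_client_delta(x: float, batches: Batches, full_grad: float, m: float,
                      lr: float, beta: float) -> float:
    y = x
    for batch in batches:
        # SVRG-style correction keeps local steps close to the global gradient.
        g = grad(y, batch) - grad(x, batch) + full_grad
        y = momentum_step(y, g, m, lr, beta)  # m stays frozen on the client.
    return x - y


def mime_round(x: float, m: float, clients: Sequence[Batches], lr: float,
               beta: float, server_lr: float) -> Tuple[float, float]:
    sizes = [sum(len(b) for b in batches) for batches in clients]
    total = sum(sizes)
    # Assumption: full-batch gradient at the server params, example-weighted over clients.
    full_grad = sum(n * grad(x, [e for b in batches for e in b])
                    for n, batches in zip(sizes, clients)) / total
    deltas = [mime_client_delta(x, batches, full_grad, m, lr, beta) for batches in clients]
    mean_delta = sum(n * d for n, d in zip(sizes, deltas)) / total
    new_m = (1 - beta) * full_grad + beta * m  # Only the server updates optimizer state.
    return x - server_lr * mean_delta, new_m


new_x, new_m = mime_round(0.0, 0.0, [[[2.0]], [[4.0]]], lr=1.0, beta=0.9, server_lr=1.0)
print(round(new_x, 6), round(new_m, 6))  # 0.3 -0.3
```

MimeLite, documented next, skips the correction term and steps with the plain minibatch gradient, which avoids computing gradients at the server parameters during local training.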

MimeLite

Mime Lite implementation.

Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning

Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606

Reuses fedjax.algorithms.mime.ServerState.

fedjax.algorithms.mime_lite.mime_lite(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None, client_delta_clip_norm=None)

Builds mime lite.

Parameters
  • per_example_loss (Callable[[Any, Mapping[str, ndarray], ndarray], ndarray]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.

  • base_optimizer (Optimizer) – Base optimizer to mimic.

  • client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client datasets for training.

  • grads_batch_hparams (PaddedBatchHParams) – Hyperparameters for batching client dataset for server gradient computation.

  • server_learning_rate (float) – Server learning rate.

  • regularizer (Optional[Callable[[Any], ndarray]]) – Optional regularizer that only depends on params.

  • client_delta_clip_norm (Optional[float]) – Maximum allowed global norm per client update. Defaults to no clipping.

Return type

FederatedAlgorithm

Returns

FederatedAlgorithm
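The client_delta_clip_norm parameter caps the global norm of each client's update. A minimal sketch of global-norm clipping over a dict of parameter lists (a stand-in for a pytree) follows. When the L2 norm over all entries exceeds the limit, every entry is scaled by the same factor so the norm equals the limit, and None disables clipping as documented. The helper name is hypothetical and this is not fedjax's code.

```python
import math
from typing import Dict, List, Optional

Delta = Dict[str, List[float]]


def clip_by_global_norm(delta: Delta, max_norm: Optional[float]) -> Delta:
    """Scales all entries by one factor so the global L2 norm is at most max_norm."""
    if max_norm is None:
        return delta  # Documented default: no clipping.
    # Global norm: sqrt of the sum of squares across every leaf of the update.
    global_norm = math.sqrt(sum(v * v for leaf in delta.values() for v in leaf))
    if global_norm <= max_norm:
        return delta
    scale = max_norm / global_norm
    return {name: [v * scale for v in leaf] for name, leaf in delta.items()}


print(clip_by_global_norm({'w': [3.0], 'b': [4.0]}, max_norm=1.0))
# about {'w': [0.6], 'b': [0.8]}: the norm shrinks from 5 to 1, direction is kept
```

Scaling every leaf by one shared factor, rather than clipping each leaf separately, preserves the direction of the whole update.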