fedjax.algorithms
Federated learning algorithm implementations.
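- fedjax.algorithms.agnostic_fed_avg: AgnosticFedAvg implementation.
- fedjax.algorithms.fed_avg: Federated averaging implementation using fedjax.core.
- fedjax.algorithms.hyp_cluster: Federated hypothesis-based clustering implementation.
- fedjax.algorithms.mime: Mime implementation.
- fedjax.algorithms.mime_lite: Mime Lite implementation.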
AgnosticFedAvg
AgnosticFedAvg implementation.
- Communication-Efficient Agnostic Federated Averaging
Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh https://arxiv.org/abs/2104.02748
- class fedjax.algorithms.agnostic_fed_avg.ServerState(params, opt_state, domain_weights, domain_window)[source]
State of server for AgnosticFedAvg passed between rounds.
- params
A pytree representing the server model parameters.
- Type:
Any
- opt_state
A pytree representing the server optimizer state.
- Type:
Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- domain_window
Sliding window keeping track of domain number of examples for the last window size rounds.
- Type:
List[jax.Array]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
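The domain_window attribute is described as a sliding window over the last window-size rounds. As a rough illustration of that bookkeeping (not the library's code), the sketch below uses a plain Python deque. The helper names make_domain_window, push_round, and window_totals are hypothetical, and pre-filling the window with ones is an assumption based on the documented default for init_domain_window.

```python
from collections import deque


def make_domain_window(num_domains, window_size):
    """Builds a window of `window_size` rounds, each pre-filled with ones (assumed default)."""
    return deque(
        [[1.0] * num_domains for _ in range(window_size)], maxlen=window_size
    )


def push_round(window, domain_num_examples):
    """Appends this round's per-domain example counts; the oldest round is evicted."""
    window.append(list(domain_num_examples))
    return window


def window_totals(window):
    """Sums per-domain example counts over every round still in the window."""
    return [sum(counts) for counts in zip(*window)]


# Two domains, window of two rounds. After three pushes only the last two remain.
window = make_domain_window(num_domains=2, window_size=2)
push_round(window, [10.0, 30.0])
push_round(window, [20.0, 5.0])
push_round(window, [4.0, 6.0])
totals = window_totals(window)
```

With domain_window_size=1 the window holds only the most recent round, so the totals are simply that round's counts.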
- fedjax.algorithms.agnostic_fed_avg.agnostic_federated_averaging(per_example_loss, client_optimizer, server_optimizer, client_batch_hparams, domain_batch_hparams, init_domain_weights, domain_learning_rate, domain_algorithm='eg', domain_window_size=1, init_domain_window=None, regularizer=None)[source]
Builds agnostic federated averaging.
Agnostic federated averaging requires input fedjax.core.client_datasets.ClientDataset examples to contain a feature named “domain_id”, which stores the integer domain id in [0, num_domains). For example, for Stack Overflow, each example post can be either a question or an answer, so there are two possible domain ids (question = 0; answer = 1).
- Parameters:
per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the domain metrics computation and gradient descent training.
client_optimizer (Optimizer) – Optimizer for local client training.
server_optimizer (Optimizer) – Optimizer for server update.
client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for client dataset for training.
domain_batch_hparams (PaddedBatchHParams) – Hyperparameters for client dataset domain metrics calculation.
init_domain_weights (Sequence[float]) – Initial weights per domain that must sum to 1.
domain_learning_rate (float) – Learning rate for domain weight update.
domain_algorithm (str) – Algorithm used to update domain weights each round. One of ‘eg’, ‘none’.
domain_window_size (int) – Size of sliding window keeping track of number of examples per domain over multiple rounds.
init_domain_window (Optional[Sequence[float]]) – Initial values for domain window. Defaults to ones.
regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.
- Return type:
FederatedAlgorithm
- Returns:
FederatedAlgorithm.
- Raises:
ValueError – If init_domain_weights does not sum to 1 or if init_domain_weights and init_domain_window have unequal lengths.
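To make the domain weight arguments more concrete, here is a minimal stdlib-only sketch of the two documented checks and of one common form of an exponentiated gradient ('eg') step. It is not the library's implementation: the function names, the tolerance, the multiplicative update formula, and the reading of 'none' as "leave the weights unchanged" are all assumptions made for illustration.

```python
import math


def validate_domain_inputs(init_domain_weights, init_domain_window=None, tol=1e-6):
    """Mirrors the documented ValueError conditions (tolerance is an assumption)."""
    if abs(sum(init_domain_weights) - 1.0) > tol:
        raise ValueError("init_domain_weights must sum to 1.")
    if init_domain_window is not None and len(init_domain_window) != len(
        init_domain_weights
    ):
        raise ValueError(
            "init_domain_weights and init_domain_window must have equal lengths."
        )


def update_domain_weights(
    domain_weights, domain_losses, domain_learning_rate, domain_algorithm="eg"
):
    """One assumed update step: 'eg' is multiplicative, 'none' keeps weights fixed."""
    if domain_algorithm == "none":
        return list(domain_weights)
    if domain_algorithm != "eg":
        raise ValueError("domain_algorithm must be one of 'eg', 'none'.")
    # Scale each weight by exp(lr * loss), then renormalize so the weights sum to 1.
    scaled = [
        w * math.exp(domain_learning_rate * loss)
        for w, loss in zip(domain_weights, domain_losses)
    ]
    total = sum(scaled)
    return [s / total for s in scaled]


validate_domain_inputs([0.5, 0.5], [1.0, 1.0])
new_weights = update_domain_weights(
    [0.5, 0.5], domain_losses=[2.0, 1.0], domain_learning_rate=0.1
)
```

Under this form of the update, the domain with the larger loss ends up with the larger weight, which is the behavior an agnostic (worst-case over domains) objective calls for.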
FedAvg
Federated averaging implementation using fedjax.core.
This is the more performant implementation, matching what is used in fedjax.algorithms.fed_avg. The key difference between this and the basic version is the use of fedjax.core.for_each_client.
- Communication-Efficient Learning of Deep Networks from Decentralized Data
H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017. https://arxiv.org/abs/1602.05629
- Adaptive Federated Optimization
Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. ICLR 2021. https://arxiv.org/abs/2003.00295
- class fedjax.algorithms.fed_avg.ServerState(params, opt_state)[source]
State of server passed between rounds.
- params
A pytree representing the server model parameters.
- Type:
Any
- opt_state
A pytree representing the server optimizer state.
- Type:
Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
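The replace(**updates) method follows the usual immutable-update pattern: the original state is left untouched and a modified copy is returned. The sketch below shows that pattern with a frozen standard-library dataclass. ToyServerState is a hypothetical stand-in, not the library's ServerState class.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ToyServerState:
    """Hypothetical stand-in with the same two fields as the documented state."""

    params: Any
    opt_state: Any

    def replace(self, **updates):
        # Returns a new object; fields not named in `updates` are carried over.
        return dataclasses.replace(self, **updates)


state = ToyServerState(params={"w": 1.0}, opt_state=None)
new_state = state.replace(params={"w": 2.0})
```

Here state.params is still {"w": 1.0} after the call, while new_state carries the updated params and the original opt_state.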
- fedjax.algorithms.fed_avg.federated_averaging(grad_fn, client_optimizer, server_optimizer, client_batch_hparams)[source]
Builds federated averaging.
- Parameters:
grad_fn (Callable[[Any, Mapping[str, Array], Array], Any]) – A function from (params, batch_example, rng) to gradients. This can be created with fedjax.core.model.model_grad().
client_optimizer (Optimizer) – Optimizer for local client training.
server_optimizer (Optimizer) – Optimizer for server update.
client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for training.
- Return type:
FederatedAlgorithm
- Returns:
FederatedAlgorithm
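As a rough picture of what one round of federated averaging computes, the sketch below works on flat lists of floats instead of pytrees. It assumes each client reports a delta equal to the server params minus its locally trained params, that deltas are weighted by the client's number of examples, and that the server optimizer is plain SGD. These are assumptions for illustration, and the function names are hypothetical.

```python
def weighted_mean_delta(client_deltas, client_weights):
    """Averages client deltas, weighting each client by its number of examples."""
    total = sum(client_weights)
    dim = len(client_deltas[0])
    return [
        sum(w * delta[i] for delta, w in zip(client_deltas, client_weights)) / total
        for i in range(dim)
    ]


def server_sgd_step(server_params, mean_delta, server_learning_rate=1.0):
    """Treats the averaged delta as a pseudo-gradient for a plain SGD server step."""
    return [p - server_learning_rate * g for p, g in zip(server_params, mean_delta)]


server_params = [1.0, 2.0]
# Params after local training on two clients, with 1 and 3 examples respectively.
client_params = [[0.0, 2.0], [1.0, 0.0]]
num_examples = [1.0, 3.0]

# Delta convention assumed here: server params minus client params.
client_deltas = [
    [s - c for s, c in zip(server_params, params)] for params in client_params
]
mean_delta = weighted_mean_delta(client_deltas, num_examples)
new_server_params = server_sgd_step(server_params, mean_delta)
```

With a server learning rate of 1.0 this reduces to the example-weighted average of the client params. Swapping in an adaptive server optimizer gives the variants studied in the Adaptive Federated Optimization paper cited above.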
HypCluster
Federated hypothesis-based clustering implementation.
- Three Approaches for Personalization with Applications to Federated Learning
Yishay Mansour, Mehryar Mohri, Jae Ro, Ananda Theertha Suresh https://arxiv.org/abs/2002.10619
- class fedjax.algorithms.hyp_cluster.ServerState(cluster_params, opt_states)[source]
State of server for HypCluster passed between rounds.
- cluster_params
A list of pytrees representing the server model parameters per cluster.
- Type:
List[Any]
- opt_states
A list of pytrees representing the server optimizer state per cluster.
- Type:
List[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.hyp_cluster.hyp_cluster(per_example_loss, client_optimizer, server_optimizer, maximization_batch_hparams, expectation_batch_hparams, regularizer=None)[source]
Federated hypothesis-based clustering algorithm.
- Parameters:
per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch, rng) to a vector of per example loss values.
client_optimizer (Optimizer) – Client side optimizer.
server_optimizer (Optimizer) – Server side optimizer.
maximization_batch_hparams (PaddedBatchHParams) – Batching hyperparameters for the maximization step.
expectation_batch_hparams (ShuffleRepeatBatchHParams) – Batching hyperparameters for the expectation step.
regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer.
- Return type:
FederatedAlgorithm
- Returns:
A FederatedAlgorithm. A notable difference from other common FederatedAlgorithms such as fed_avg is that init() takes a list of cluster_params, which can be obtained using random_init(), ModelKMeansInitializer, or kmeans_init() in this module.
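One way to picture the cluster selection at the heart of hypothesis-based clustering: every cluster's parameters are evaluated on a client's data, and the client is assigned to the cluster with the lowest loss. The sketch below assumes that this selection is what the maximization step does (an inference from its evaluation-style batching hyperparameters, not something stated here). The function name and the input layout are hypothetical.

```python
def assign_clients_to_clusters(client_cluster_losses):
    """Maps each client id to the index of its lowest-loss cluster.

    client_cluster_losses: dict of client_id -> list of average losses,
    one entry per cluster. Ties go to the lowest cluster index.
    """
    assignment = {}
    for client_id, losses in client_cluster_losses.items():
        assignment[client_id] = min(range(len(losses)), key=losses.__getitem__)
    return assignment


losses = {
    b"client_a": [0.9, 0.2],  # cluster 1 fits this client better
    b"client_b": [0.1, 0.5],  # cluster 0 fits this client better
}
assignment = assign_clients_to_clusters(losses)
```

After assignment, each cluster's parameters would be trained only on the clients assigned to it, which is why the server state keeps one set of params and one optimizer state per cluster.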
- fedjax.algorithms.hyp_cluster.kmeans_init(num_clusters, init_params, clients, trainer, train_batch_hparams, evaluator, eval_batch_hparams, rng)[source]
Initializes cluster params for HypCluster using a k-means++ variant.
See ModelKMeansInitializer for a more convenient initializer from a Model.
Given a set of input clients, we train parameters for each client. The initial cluster parameters are chosen out of this set of client parameters. At a high level, the cluster parameters are selected by choosing the client parameters that are the “furthest away” from each other (greatest difference in loss).
- Parameters:
num_clusters (int) – Desired number of cluster centers.
init_params (Any) – Initial params for training each client.
clients (Sequence[Tuple[bytes, ClientDataset, Array]]) – Clients to train for generating candidate cluster center params.
trainer (ClientParamsTrainer) – Client trainer for carrying out client training.
train_batch_hparams (ShuffleRepeatBatchHParams) – Batching hyperparameters for client training.
evaluator (AverageLossEvaluator) – Average loss evaluator for evaluating clients on chosen cluster center params.
eval_batch_hparams (PaddedBatchHParams) – Batching hyperparameters for client evaluation.
rng (Array) – RNG used for choosing the initial cluster center.
- Return type:
List[Any]
- Returns:
Cluster center params.
- Raises:
ValueError – If some input arguments are invalid.
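The “furthest away” selection can be illustrated with a deterministic farthest-first simplification. This is not the library's procedure: a k-means++ variant would normally involve random choices driven by rng, and the sketch below replaces them with a fixed first index and a greedy argmax. The loss matrix layout and the function name are hypothetical.

```python
def farthest_first_centers(loss_matrix, num_clusters, first_index=0):
    """Greedily picks candidates whose data is worst served by the chosen centers.

    loss_matrix[i][j] is the loss of candidate i's trained params on client j's
    data. Returns the indices of the chosen candidates, in selection order.
    """
    num_candidates = len(loss_matrix)
    if not 0 < num_clusters <= num_candidates:
        raise ValueError("num_clusters must be between 1 and the number of candidates.")
    chosen = [first_index]
    while len(chosen) < num_clusters:
        # For each client, the best loss achievable with the centers chosen so far.
        best_loss = [
            min(loss_matrix[c][j] for c in chosen) for j in range(num_candidates)
        ]
        remaining = [j for j in range(num_candidates) if j not in chosen]
        # The next center is the client that the current centers fit worst.
        chosen.append(max(remaining, key=best_loss.__getitem__))
    return chosen


loss_matrix = [
    [0.1, 0.2, 0.9],
    [0.2, 0.1, 0.8],
    [0.9, 0.8, 0.1],
]
centers = farthest_first_centers(loss_matrix, num_clusters=2)
```

In this toy matrix, candidates 0 and 1 fit each other's data well while candidate 2 does not, so starting from candidate 0 the second center chosen is candidate 2.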
- fedjax.algorithms.hyp_cluster.random_init(num_clusters, init, rng)[source]
Randomly initializes cluster params.
- Return type:
List[Any]
- class fedjax.algorithms.hyp_cluster.ModelKMeansInitializer(model, client_optimizer, regularizer=None)[source]
Initializes cluster params for HypCluster using a k-means++ variant.
This is a thin wrapper for initializing from a Model. See kmeans_init() for a more general version of the initializer, and details of the initialization algorithm.
- class fedjax.algorithms.hyp_cluster.HypClusterEvaluator(model, regularizer=None)[source]
Evaluates cluster params on a Model.
- __init__(model, regularizer=None)[source]
Initializes some reusable components.
Because we need to make multiple passes over each client during evaluation, the number of clients that can be evaluated at once is limited. Therefore multiple calls to
evaluate_clients()
are needed to evaluate a large federated dataset. We factor out some reusable components so that the same computation can be jit compiled.
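Since only a limited number of clients can be evaluated per call, a large dataset has to be processed in chunks. The sketch below shows one generic way to do that chunking in plain Python. The signature of evaluate_clients() is not given here, so the evaluate_chunk callable and the toy evaluator are hypothetical placeholders rather than the library's API.

```python
import itertools


def chunked(iterable, chunk_size):
    """Yields successive lists of at most `chunk_size` items."""
    iterator = iter(iterable)
    while True:
        chunk = list(itertools.islice(iterator, chunk_size))
        if not chunk:
            return
        yield chunk


def evaluate_in_chunks(clients, chunk_size, evaluate_chunk):
    """Calls `evaluate_chunk` once per chunk and merges the per-client results."""
    results = {}
    for chunk in chunked(clients, chunk_size):
        results.update(evaluate_chunk(chunk))
    return results


def toy_evaluate_chunk(chunk):
    """Placeholder evaluator: reports each client's number of examples."""
    return {client_id: float(len(data)) for client_id, data in chunk}


clients = [(b"a", [1, 2]), (b"b", [3]), (b"c", [4, 5, 6])]
results = evaluate_in_chunks(clients, chunk_size=2, evaluate_chunk=toy_evaluate_chunk)
```

Because the same evaluation function is called for every chunk, any jit-compiled computation inside it can be reused across calls, which matches the motivation given above for factoring out reusable components.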
Mime
Mime implementation.
- Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning
Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606
- class fedjax.algorithms.mime.ServerState(params, opt_state)[source]
State of server passed between rounds.
- params
A pytree representing the server model parameters.
- Type:
Any
- opt_state
A pytree representing the base optimizer state.
- Type:
Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.mime.mime(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None)[source]
Builds mime.
- Parameters:
per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.
base_optimizer (Optimizer) – Base optimizer to mimic.
client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for training.
grads_batch_hparams (PaddedBatchHParams) – Hyperparameters for batching client dataset for server gradient computation.
server_learning_rate (float) – Server learning rate.
regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.
- Return type:
FederatedAlgorithm
- Returns:
FederatedAlgorithm
MimeLite
Mime Lite implementation.
- Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning
Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606
Reuses fedjax.algorithms.mime.ServerState
- fedjax.algorithms.mime_lite.mime_lite(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None, client_delta_clip_norm=None)[source]
Builds mime lite.
- Parameters:
per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.
base_optimizer (Optimizer) – Base optimizer to mimic.
client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for training.
grads_batch_hparams (PaddedBatchHParams) – Hyperparameters for batching client dataset for server gradient computation.
server_learning_rate (float) – Server learning rate.
regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.
client_delta_clip_norm (Optional[float]) – Maximum allowed global norm per client update. Defaults to no clipping.
- Return type:
FederatedAlgorithm
- Returns:
FederatedAlgorithm
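The client_delta_clip_norm argument bounds the global norm of each client update. A minimal sketch of global-norm clipping, written with plain dicts of float lists in place of pytrees, might look as follows. The function name and the data layout are hypothetical, and the exact clipping rule used by the library is assumed here to be the standard rescale-if-too-large form.

```python
import math


def clip_by_global_norm(delta, max_norm=None):
    """Rescales `delta` so that its global L2 norm is at most `max_norm`.

    delta: dict of parameter name -> list of floats.
    max_norm=None means no clipping, matching the documented default.
    """
    if max_norm is None:
        return {name: list(values) for name, values in delta.items()}
    # Global norm: L2 norm over every entry of every parameter.
    global_norm = math.sqrt(sum(x * x for values in delta.values() for x in values))
    scale = min(1.0, max_norm / global_norm) if global_norm > 0 else 1.0
    return {name: [x * scale for x in values] for name, values in delta.items()}


delta = {"w": [3.0, 0.0], "b": [4.0]}  # global norm is 5
clipped = clip_by_global_norm(delta, max_norm=1.0)
unclipped = clip_by_global_norm(delta)
```

All parameters share one scale factor, so the direction of the update is preserved and only its magnitude shrinks. Updates whose norm is already below the limit pass through unchanged.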