fedjax.algorithms
Federated learning algorithm implementations.
- agnostic_fed_avg: AgnosticFedAvg implementation.
- fed_avg: Federated averaging implementation using fedjax.core.
- hyp_cluster: Federated hypothesis-based clustering implementation.
- mime: Mime implementation.
- mime_lite: Mime Lite implementation.
AgnosticFedAvg
AgnosticFedAvg implementation.
- Communication-Efficient Agnostic Federated Averaging
Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh https://arxiv.org/abs/2104.02748
- class fedjax.algorithms.agnostic_fed_avg.ServerState(params, opt_state, domain_weights, domain_window)[source]
State of server for AgnosticFedAvg passed between rounds.
- params
A pytree representing the server model parameters.
- Type
Any
- opt_state
A pytree representing the server optimizer state.
- Type
Union[jax._src.basearray.Array, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- domain_weights
Weights per domain applied to the client weights in the weighted average.
- Type
jax._src.basearray.Array
- domain_window
Sliding window tracking the number of examples per domain over the last domain_window_size rounds.
- Type
List[jax._src.basearray.Array]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.agnostic_fed_avg.agnostic_federated_averaging(per_example_loss, client_optimizer, server_optimizer, client_batch_hparams, domain_batch_hparams, init_domain_weights, domain_learning_rate, domain_algorithm='eg', domain_window_size=1, init_domain_window=None, regularizer=None)[source]
Builds agnostic federated averaging.
Agnostic federated averaging requires input fedjax.core.client_datasets.ClientDataset examples to contain a feature named "domain_id", which stores the integer domain id in [0, num_domains). For example, for Stack Overflow, each example post can be either a question or an answer, so there are two possible domain ids (question = 0; answer = 1).
- Parameters
  - per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the domain metrics computation and gradient descent training.
  - client_optimizer (Optimizer) – Optimizer for local client training.
  - server_optimizer (Optimizer) – Optimizer for server update.
  - client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for client dataset for training.
  - domain_batch_hparams (PaddedBatchHParams) – Hyperparameters for client dataset domain metrics calculation.
  - init_domain_weights (Sequence[float]) – Initial weights per domain that must sum to 1.
  - domain_learning_rate (float) – Learning rate for domain weight update.
  - domain_algorithm (str) – Algorithm used to update domain weights each round. One of 'eg', 'none'.
  - domain_window_size (int) – Size of sliding window keeping track of number of examples per domain over multiple rounds.
  - init_domain_window (Optional[Sequence[float]]) – Initial values for domain window. Defaults to ones.
  - regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.
- Return type
FederatedAlgorithm
- Returns
FederatedAlgorithm.
- Raises
ValueError – If init_domain_weights does not sum to 1 or if init_domain_weights and init_domain_window are unequal lengths.
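A minimal usage sketch follows; it is not taken from the fedjax documentation. The model constructor (fedjax.models.emnist.create_conv_model), the helpers fedjax.model_per_example_loss and fedjax.optimizers.sgd, and all hyperparameter values are illustrative assumptions, and the client datasets passed to apply() must carry the "domain_id" feature described above.

```python
import jax
import fedjax

# Assumed model; the client data passed to apply() must include a "domain_id" feature.
model = fedjax.models.emnist.create_conv_model(only_digits=False)

algorithm = fedjax.algorithms.agnostic_fed_avg.agnostic_federated_averaging(
    per_example_loss=fedjax.model_per_example_loss(model),
    client_optimizer=fedjax.optimizers.sgd(learning_rate=0.1),
    server_optimizer=fedjax.optimizers.sgd(learning_rate=1.0),
    client_batch_hparams=fedjax.ShuffleRepeatBatchHParams(batch_size=20),
    domain_batch_hparams=fedjax.PaddedBatchHParams(batch_size=128),
    init_domain_weights=[0.5, 0.5],   # two domains; must sum to 1
    domain_learning_rate=0.01,
    domain_algorithm='eg',
    domain_window_size=100,
)

# FederatedAlgorithm usage: init once, then apply once per round with a
# sequence of (client_id, client_dataset, client_rng) tuples.
server_state = algorithm.init(model.init(jax.random.PRNGKey(0)))
# server_state, client_diagnostics = algorithm.apply(server_state, clients)
```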
FedAvg
Federated averaging implementation using fedjax.core.
This is the more performant implementation that matches what would be used in the fedjax.algorithms.fed_avg. The key difference between this and the basic version is the use of fedjax.core.for_each_client.
- Communication-Efficient Learning of Deep Networks from Decentralized Data
H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017. https://arxiv.org/abs/1602.05629
- Adaptive Federated Optimization
Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. ICLR 2021. https://arxiv.org/abs/2003.00295
- class fedjax.algorithms.fed_avg.ServerState(params, opt_state)[source]
State of server passed between rounds.
- params
A pytree representing the server model parameters.
- Type
Any
- opt_state
A pytree representing the server optimizer state.
- Type
Union[jax._src.basearray.Array, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.fed_avg.federated_averaging(grad_fn, client_optimizer, server_optimizer, client_batch_hparams)[source]
Builds federated averaging.
- Parameters
  - grad_fn (Callable[[Any, Mapping[str, Array], Array], Any]) – A function from (params, batch_example, rng) to gradients. This can be created with fedjax.core.model.model_grad().
  - client_optimizer (Optimizer) – Optimizer for local client training.
  - server_optimizer (Optimizer) – Optimizer for server update.
  - client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for training.
- Return type
FederatedAlgorithm
- Returns
FederatedAlgorithm
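As a point of reference, here is a hedged end-to-end sketch of building and stepping federated averaging. The EMNIST model constructor, fedjax.model_grad, the optimizer helpers, and the hyperparameter values are illustrative assumptions rather than part of this function's contract.

```python
import jax
import fedjax

# Assumed model; fedjax.model_grad wraps it into a (params, batch_example, rng) -> grads function.
model = fedjax.models.emnist.create_conv_model(only_digits=False)
grad_fn = fedjax.model_grad(model)

algorithm = fedjax.algorithms.fed_avg.federated_averaging(
    grad_fn=grad_fn,
    client_optimizer=fedjax.optimizers.sgd(learning_rate=0.1),
    server_optimizer=fedjax.optimizers.adam(learning_rate=1e-3),
    client_batch_hparams=fedjax.ShuffleRepeatBatchHParams(batch_size=20),
)

server_state = algorithm.init(model.init(jax.random.PRNGKey(0)))
# One round of training over a sequence of (client_id, client_dataset, client_rng) tuples:
# server_state, client_diagnostics = algorithm.apply(server_state, clients)
```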
HypCluster
Federated hypothesis-based clustering implementation.
- Three Approaches for Personalization with Applications to Federated Learning
Yishay Mansour, Mehryar Mohri, Jae Ro, Ananda Theertha Suresh https://arxiv.org/abs/2002.10619
- class fedjax.algorithms.hyp_cluster.ServerState(cluster_params, opt_states)[source]
State of server for HypCluster passed between rounds.
- cluster_params
A list of pytrees representing the server model parameters per cluster.
- Type
List[Any]
- opt_states
A list of pytrees representing the server optimizer state per cluster.
- Type
List[Union[jax._src.basearray.Array, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.hyp_cluster.hyp_cluster(per_example_loss, client_optimizer, server_optimizer, maximization_batch_hparams, expectation_batch_hparams, regularizer=None)[source]
Federated hypothesis-based clustering algorithm.
- Parameters
  - per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch, rng) to a vector of per example loss values.
  - client_optimizer (Optimizer) – Client side optimizer.
  - server_optimizer (Optimizer) – Server side optimizer.
  - maximization_batch_hparams (PaddedBatchHParams) – Batching hyperparameters for the maximization step.
  - expectation_batch_hparams (ShuffleRepeatBatchHParams) – Batching hyperparameters for the expectation step.
  - regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer.
- Return type
FederatedAlgorithm
- Returns
A FederatedAlgorithm. A notable difference from other common FederatedAlgorithms such as fed_avg is that init() takes a list of cluster_params, which can be obtained using random_init(), ModelKMeansInitializer, or kmeans_init() in this module.
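The sketch below shows one hedged way to wire this up, using random_init() (documented below) to produce the list of cluster params that init() expects. The EMNIST model constructor, the optimizer helpers, the hyperparameter values, and the assumption that random_init's init argument is a params-initializing function such as model.init are illustrative.

```python
import jax
import fedjax
from fedjax.algorithms import hyp_cluster

model = fedjax.models.emnist.create_conv_model(only_digits=False)  # assumed model

algorithm = hyp_cluster.hyp_cluster(
    per_example_loss=fedjax.model_per_example_loss(model),
    client_optimizer=fedjax.optimizers.sgd(learning_rate=0.1),
    server_optimizer=fedjax.optimizers.sgd(learning_rate=1.0),
    maximization_batch_hparams=fedjax.PaddedBatchHParams(batch_size=128),
    expectation_batch_hparams=fedjax.ShuffleRepeatBatchHParams(batch_size=20),
)

# Unlike fed_avg, init() takes a list of cluster params (one pytree per cluster).
cluster_params = hyp_cluster.random_init(
    num_clusters=3, init=model.init, rng=jax.random.PRNGKey(0))
server_state = algorithm.init(cluster_params)
# server_state, client_diagnostics = algorithm.apply(server_state, clients)
```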
- fedjax.algorithms.hyp_cluster.kmeans_init(num_clusters, init_params, clients, trainer, train_batch_hparams, evaluator, eval_batch_hparams, rng)[source]
Initializes cluster params for HypCluster using a k-means++ variant.
See ModelKMeansInitializer for a more convenient initializer from a Model.
Given a set of input clients, we train parameters for each client. The initial cluster parameters are chosen out of this set of client parameters. At a high level, the cluster parameters are selected by choosing the client parameters that are the "furthest away" from each other (greatest difference in loss).
- Parameters
  - num_clusters (int) – Desired number of cluster centers.
  - init_params (Any) – Initial params for training each client.
  - clients (Sequence[Tuple[bytes, ClientDataset, Array]]) – Clients to train for generating candidate cluster center params.
  - trainer (ClientParamsTrainer) – Client trainer for carrying out client training.
  - train_batch_hparams (ShuffleRepeatBatchHParams) – Batching hyperparameters for client training.
  - evaluator (AverageLossEvaluator) – Average loss evaluator for evaluating clients on chosen cluster center params.
  - eval_batch_hparams (PaddedBatchHParams) – Batching hyperparameters for client evaluation.
  - rng (Array) – RNG used for choosing the initial cluster center.
- Return type
List[Any]
- Returns
Cluster center params.
- Raises
ValueError – If some input arguments are invalid.
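To make the selection idea above concrete, here is a small self-contained sketch of a k-means++-style pick over precomputed losses. It is not the fedjax implementation: it assumes a loss matrix where entry (i, j) is client j's average loss under the params trained on client i, and it greedily promotes the params of the worst-served client to a new center.

```python
import numpy as np

def greedy_center_selection(losses, num_clusters, rng):
    """Hedged sketch of k-means++-style center selection, not the fedjax code.

    losses[i, j] is client j's average loss when evaluated with the params
    trained on client i, so row i is a candidate cluster center.
    """
    centers = [int(rng.integers(losses.shape[0]))]  # first center chosen at random
    while len(centers) < num_clusters:
        # Best loss each client can achieve under the centers picked so far.
        best_loss = np.min(losses[centers, :], axis=0)
        best_loss[centers] = -np.inf  # never re-pick an existing center
        # The client that is "furthest away" in loss becomes the next center.
        centers.append(int(np.argmax(best_loss)))
    return centers

# 4 candidate params x 4 clients; pick 2 centers.
losses = np.array([[0.1, 0.9, 0.8, 0.2],
                   [0.9, 0.1, 0.2, 0.8],
                   [0.8, 0.2, 0.1, 0.9],
                   [0.2, 0.8, 0.9, 0.1]])
print(greedy_center_selection(losses, num_clusters=2, rng=np.random.default_rng(0)))
```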
- fedjax.algorithms.hyp_cluster.random_init(num_clusters, init, rng)[source]
Randomly initializes cluster params.
- Return type
List[Any]
- class fedjax.algorithms.hyp_cluster.ModelKMeansInitializer(model, client_optimizer, regularizer=None)[source]
Initializes cluster params for HypCluster using a k-means++ variant.
This is a thin wrapper for initializing from a Model. See kmeans_init() for a more general version of the initializer, and details of the initialization algorithm.
- class fedjax.algorithms.hyp_cluster.HypClusterEvaluator(model, regularizer=None)[source]
Evaluates cluster params on a Model.
- __init__(model, regularizer=None)[source]
Initializes some reusable components.
Because we need to make multiple passes over each client during evaluation, the number of clients that can be evaluated at once is limited. Therefore multiple calls to evaluate_clients() are needed to evaluate a large federated dataset. We factor out some reusable components so that the same computation can be jit compiled.
- Parameters
  - model (Model) – Model being evaluated.
  - regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer.
Mime
Mime implementation.
- Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning
Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606
- class fedjax.algorithms.mime.ServerState(params, opt_state)[source]
State of server passed between rounds.
- params
A pytree representing the server model parameters.
- Type
Any
- opt_state
A pytree representing the base optimizer state.
- Type
Union[jax._src.basearray.Array, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- fedjax.algorithms.mime.mime(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None)[source]
Builds mime.
- Parameters
  - per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.
  - base_optimizer (Optimizer) – Base optimizer to mimic.
  - client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for training.
  - grads_batch_hparams (PaddedBatchHParams) – Hyperparameters for batching client dataset for server gradient computation.
  - server_learning_rate (float) – Server learning rate.
  - regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.
- Return type
FederatedAlgorithm
- Returns
FederatedAlgorithm
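A hedged construction sketch, analogous to the one for federated averaging above; the model constructor, helper functions, and hyperparameter values are illustrative assumptions.

```python
import jax
import fedjax

model = fedjax.models.emnist.create_conv_model(only_digits=False)  # assumed model

algorithm = fedjax.algorithms.mime.mime(
    per_example_loss=fedjax.model_per_example_loss(model),
    base_optimizer=fedjax.optimizers.sgd(learning_rate=0.1),  # centralized optimizer to mimic
    client_batch_hparams=fedjax.ShuffleRepeatBatchHParams(batch_size=20),
    grads_batch_hparams=fedjax.PaddedBatchHParams(batch_size=128),
    server_learning_rate=1.0,
)

server_state = algorithm.init(model.init(jax.random.PRNGKey(0)))
# server_state, client_diagnostics = algorithm.apply(server_state, clients)
```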
MimeLite
Mime Lite implementation.
- Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning
Sai Praneeth Karimireddy, Martin Jaggi, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh https://arxiv.org/abs/2008.03606
Reuses fedjax.algorithms.mime.ServerState
- fedjax.algorithms.mime_lite.mime_lite(per_example_loss, base_optimizer, client_batch_hparams, grads_batch_hparams, server_learning_rate, regularizer=None, client_delta_clip_norm=None)[source]
Builds mime lite.
- Parameters
  - per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – A function from (params, batch_example, rng) to a vector of loss values for each example in the batch. This is used in both the server gradient computation and gradient descent training.
  - base_optimizer (Optimizer) – Base optimizer to mimic.
  - client_batch_hparams (ShuffleRepeatBatchHParams) – Hyperparameters for batching client dataset for training.
  - grads_batch_hparams (PaddedBatchHParams) – Hyperparameters for batching client dataset for server gradient computation.
  - server_learning_rate (float) – Server learning rate.
  - regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer that only depends on params.
  - client_delta_clip_norm (Optional[float]) – Maximum allowed global norm per client update. Defaults to no clipping.
- Return type
FederatedAlgorithm
- Returns
FederatedAlgorithm
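The construction mirrors mime() above, with the extra client_delta_clip_norm knob; as before, the model constructor, helpers, and values are illustrative assumptions.

```python
import fedjax

model = fedjax.models.emnist.create_conv_model(only_digits=False)  # assumed model

algorithm = fedjax.algorithms.mime_lite.mime_lite(
    per_example_loss=fedjax.model_per_example_loss(model),
    base_optimizer=fedjax.optimizers.sgd(learning_rate=0.1),
    client_batch_hparams=fedjax.ShuffleRepeatBatchHParams(batch_size=20),
    grads_batch_hparams=fedjax.PaddedBatchHParams(batch_size=128),
    server_learning_rate=1.0,
    client_delta_clip_norm=10.0,  # clip each client update's global norm
)
```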