- class fedjax.aggregators.Aggregator(init, apply)
Interface for aggregation algorithms.
This interface defines the aggregation algorithm applied at each round. The aggregator state holds any round-specific parameters (e.g. number of bits) that are carried from round to round. The state is initialized by init, passed as input to apply, and returned as output from apply. We strongly recommend using fedjax.dataclass to define the state, as it provides immutability and type hinting and works by default with JAX transformations.
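The recommendation to keep state immutable can be illustrated with a minimal sketch. It uses the standard library's frozen `dataclasses.dataclass` as a stand-in for fedjax.dataclass. This gives the immutability and type hints described above but, unlike fedjax.dataclass, it does not register the class for JAX transformations. The `QuantizerState` class, its `num_bits` field, and `update_state` are hypothetical names chosen for illustration.

```python
import dataclasses


# Hypothetical round-specific state. A frozen stdlib dataclass stands in for
# fedjax.dataclass: it is immutable and type-hinted, but it is NOT registered
# as a JAX pytree, so it would not work directly inside jax.jit.
@dataclasses.dataclass(frozen=True)
class QuantizerState:
    num_bits: float = 0.0  # hypothetical: running count of bits communicated


def update_state(state: QuantizerState, round_bits: float) -> QuantizerState:
    """Returns a new state object instead of mutating the old one."""
    return dataclasses.replace(state, num_bits=state.num_bits + round_bits)


state = QuantizerState()
new_state = update_state(state, 128.0)
print(state.num_bits, new_state.num_bits)  # 0.0 128.0

try:
    state.num_bits = 1.0  # frozen dataclass: attribute assignment raises
except dataclasses.FrozenInstanceError:
    print("state is immutable")
```

Because every update produces a fresh object, the state returned from one round can be passed to the next without any risk of earlier rounds being modified in place.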
The expected usage of Aggregator is as follows:
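```python
import fedjax

aggregator = fedjax.aggregators.mean_aggregator()
state = aggregator.init()
for i in range(num_rounds):
    # num_rounds and compute_client_outputs are placeholders for your own
    # training loop; each round should yield (client_id, params, weight) tuples.
    clients_params_and_weights = compute_client_outputs(i)
    aggregated_params, state = aggregator.apply(clients_params_and_weights, state)
```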
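- Attributes
  - init – Returns the initial state of the aggregator.
  - apply – Callable[[Iterable[Tuple[bytes, Any, float]], Any], Tuple[Any, Any]]. Takes an iterable of (client_id, params, weight) tuples and the current aggregator state, and returns the aggregated params together with the new aggregator state.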
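To make the init/apply contract concrete, here is a minimal pure-Python sketch that mirrors the interface without depending on fedjax. `SimpleAggregator`, `weighted_mean_aggregator`, the integer round-counter state, and the flat dict-of-floats params are all hypothetical stand-ins (fedjax operates on arbitrary pytrees and ships its own `mean_aggregator`). The client tuples follow the `(client_id, params, weight)` shape given by the type of apply.

```python
import dataclasses
from typing import Any, Callable, Dict, Iterable, Tuple

Params = Dict[str, float]  # hypothetical: flat params instead of a pytree


@dataclasses.dataclass(frozen=True)
class SimpleAggregator:
    """Hypothetical mirror of the Aggregator(init, apply) interface."""
    init: Callable[[], Any]
    apply: Callable[[Iterable[Tuple[bytes, Params, float]], Any], Tuple[Params, Any]]


def weighted_mean_aggregator() -> SimpleAggregator:
    def init() -> int:
        # Hypothetical state: number of rounds aggregated so far.
        return 0

    def apply(clients_params_and_weights, state):
        # Assumes at least one client with a positive weight.
        total_weight = 0.0
        sums: Params = {}
        for _client_id, params, weight in clients_params_and_weights:
            total_weight += weight
            for name, value in params.items():
                sums[name] = sums.get(name, 0.0) + weight * value
        mean = {name: s / total_weight for name, s in sums.items()}
        # Return (aggregated params, new state), matching the usage loop.
        return mean, state + 1

    return SimpleAggregator(init, apply)


aggregator = weighted_mean_aggregator()
state = aggregator.init()
clients = [(b"client_a", {"w": 1.0}, 1.0), (b"client_b", {"w": 4.0}, 3.0)]
aggregated_params, state = aggregator.apply(clients, state)
print(aggregated_params, state)  # {'w': 3.25} 1
```

Packaging init and apply as plain callables keeps the aggregator itself stateless; everything that changes between rounds travels through the explicit state value.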
- fedjax.aggregators.uniform_stochastic_quantizer(num_levels, rng, encode_algorithm=None)
Returns the (weighted) mean of the input trees after uniform stochastic quantization, using the algorithm in https://arxiv.org/pdf/1611.00429.pdf.
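The quantization step can be sketched in NumPy, following the stochastic k-level scheme of the cited paper: levels are spaced evenly between a vector's minimum and maximum, and each coordinate is rounded up or down at random so that the quantized value equals the original in expectation. This is not fedjax's implementation. A NumPy `Generator` stands in for the JAX PRNGKey, plain arrays stand in for parameter trees, and `stochastic_quantize` and `quantized_weighted_mean` are hypothetical names.

```python
import numpy as np


def stochastic_quantize(x: np.ndarray, num_levels: int,
                        rng: np.random.Generator) -> np.ndarray:
    """Unbiased uniform stochastic quantization of x onto num_levels points."""
    assert num_levels >= 2, "need at least two levels"
    x_min, x_max = x.min(), x.max()
    if x_max == x_min:
        return x.copy()  # all coordinates equal: nothing to quantize
    step = (x_max - x_min) / (num_levels - 1)
    scaled = (x - x_min) / step  # position in units of levels, in [0, k-1]
    lower = np.floor(scaled)
    # Round up with probability equal to the fractional part, so E[q] = x.
    round_up = rng.random(x.shape) < (scaled - lower)
    # Clamp guards against floating-point overshoot at the top level.
    levels = np.minimum(lower + round_up, num_levels - 1)
    return x_min + levels * step


def quantized_weighted_mean(arrays, weights, num_levels, rng):
    """Weighted mean of independently quantized client arrays (hypothetical)."""
    quantized = [stochastic_quantize(a, num_levels, rng) for a in arrays]
    return np.average(np.stack(quantized), axis=0, weights=weights)


rng = np.random.default_rng(0)
x = np.array([0.0, 0.3, 0.5, 1.0])
q = stochastic_quantize(x, num_levels=3, rng=rng)  # entries lie in {0.0, 0.5, 1.0}
```

Because each client's quantization is unbiased, the weighted mean of the quantized arrays is also an unbiased estimate of the weighted mean of the originals; fewer levels mean fewer bits per coordinate but higher variance.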
- Parameters
  - num_levels (int) – number of levels of quantization.
  - rng (Array) – PRNGKey used for compression.
  - encode_algorithm (Optional[str]) – None or "arithmetic".
- Return type