fedjax.aggregators
FedJAX aggregators.
- class fedjax.aggregators.Aggregator(init, apply)
Interface for aggregation algorithms.
This interface defines aggregator algorithms that are used at each round. Aggregator state contains any round-specific parameters (e.g. number of bits) that are passed from round to round. This state is initialized by init, and is passed as input into, and returned as output from, apply. We strongly recommend using fedjax.dataclass to define state, as it provides immutability and type hinting, and works by default with JAX transformations.
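A minimal sketch of state that is carried from round to round without being mutated. It uses a frozen stdlib dataclass as a stand-in for fedjax.dataclass (an assumption made only to keep the example dependency-free); ExampleAggregatorState, its fields, and next_state are hypothetical names.

```python
import dataclasses


# Stand-in for fedjax.dataclass: a frozen stdlib dataclass gives the same
# immutability guarantee. The fields are hypothetical round-specific values.
@dataclasses.dataclass(frozen=True)
class ExampleAggregatorState:
  total_bits: float  # e.g. bits communicated so far (illustrative only)
  round_num: int


def next_state(state: ExampleAggregatorState,
               bits_this_round: float) -> ExampleAggregatorState:
  """Builds the state for the next round without mutating the input."""
  return dataclasses.replace(
      state,
      total_bits=state.total_bits + bits_this_round,
      round_num=state.round_num + 1)


state0 = ExampleAggregatorState(total_bits=0.0, round_num=0)
state1 = next_state(state0, bits_this_round=128.0)
print(state0)  # ExampleAggregatorState(total_bits=0.0, round_num=0)
print(state1)  # ExampleAggregatorState(total_bits=128.0, round_num=1)
```

Because the dataclass is frozen, assigning to a field of state1 raises dataclasses.FrozenInstanceError, so each round builds a new state instead of editing the old one.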
The expected usage of Aggregator is as follows:
    aggregator = mean_aggregator()
    state = aggregator.init()
    for i in range(num_rounds):
      # Placeholder for the caller's code: yields (client_id, params, weight) tuples.
      clients_params_and_weights = compute_client_outputs(i)
      aggregated_params, state = aggregator.apply(clients_params_and_weights, state)
- init
Returns the initial state of the aggregator.
- Type:
Callable[[], Any]
- apply
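Takes an iterable of (client_id, params, weight) tuples and the current aggregator state; returns the aggregated params and the new aggregator state, in that order.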
- Type:
Callable[[Iterable[Tuple[bytes, Any, float]], Any], Tuple[Any, Any]]
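The sketch below shows one way to fill in both fields by hand for a weighted mean. It is not fedjax's mean_aggregator: AggregatorSketch and weighted_mean_aggregator_sketch are hypothetical names, params are flat dicts of floats rather than JAX pytrees, and the state is a plain round counter. The apply function consumes (client_id, params, weight) tuples, matching the Type above, and returns the aggregated params before the new state.

```python
import dataclasses
from typing import Any, Callable, Dict, Iterable, Tuple

# Hypothetical simplification: params are flat dicts of floats rather than
# JAX pytrees, and the aggregator state is an int counting rounds.
Params = Dict[str, float]
ClientTriple = Tuple[bytes, Params, float]  # (client_id, params, weight)


@dataclasses.dataclass(frozen=True)
class AggregatorSketch:
  """Local stand-in with the same two fields as Aggregator above."""
  init: Callable[[], Any]
  apply: Callable[[Iterable[ClientTriple], Any], Tuple[Params, Any]]


def weighted_mean_aggregator_sketch() -> AggregatorSketch:
  def init() -> int:
    return 0  # No rounds aggregated yet.

  def apply(clients_params_and_weights: Iterable[ClientTriple],
            state: int) -> Tuple[Params, int]:
    weighted_sums: Params = {}
    total_weight = 0.0
    for _client_id, params, weight in clients_params_and_weights:
      for name, value in params.items():
        weighted_sums[name] = weighted_sums.get(name, 0.0) + weight * value
      total_weight += weight
    mean = {name: s / total_weight for name, s in weighted_sums.items()}
    # Aggregated params come first and the new state second.
    return mean, state + 1

  return AggregatorSketch(init=init, apply=apply)


aggregator = weighted_mean_aggregator_sketch()
state = aggregator.init()
clients = [(b'client_a', {'w': 1.0}, 1.0), (b'client_b', {'w': 4.0}, 3.0)]
aggregated_params, state = aggregator.apply(clients, state)
print(aggregated_params, state)  # {'w': 3.25} 1
```

Here the result is (1.0 × 1.0 + 4.0 × 3.0) / (1.0 + 3.0) = 3.25, and the state advances from 0 to 1.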
Mean
Quantization
- fedjax.aggregators.uniform_stochastic_quantizer(num_levels, rng, encode_algorithm=None)
Returns a compression aggregator that computes the (weighted) mean of the input trees after uniformly quantizing them with the uniform stochastic algorithm in https://arxiv.org/pdf/1611.00429.pdf.
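As a worked statement of the rounding rule (in our own notation, following the stochastic k-level quantization of the cited paper, with $k$ playing the role of num_levels), each coordinate $x_j$ of a vector $x$ is rounded to one of its two neighbouring levels so that the quantized value $Y_j$ is unbiased:

```latex
\begin{align*}
s &= \frac{x_{\max} - x_{\min}}{k - 1}, \qquad B(r) = x_{\min} + r\,s, \quad r = 0, \dots, k - 1, \\
\text{for } x_j \in [B(r), B(r+1)]: \quad
Y_j &= \begin{cases}
  B(r+1) & \text{with probability } \dfrac{x_j - B(r)}{s}, \\
  B(r)   & \text{otherwise,}
\end{cases} \\
\mathbb{E}[Y_j] &= B(r) + s \cdot \frac{x_j - B(r)}{s} = x_j.
\end{align*}
```

By linearity of expectation, a weighted mean of such quantized client vectors is then an unbiased estimate of the weighted mean of the original vectors.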
- Parameters:
num_levels (int) – number of levels of quantization.
rng (Array) – PRNGKey used for compression.
encode_algorithm (Optional[str]) – either None or "arithmetic".
- Return type:
Aggregator
- Returns:
Compression aggregator.
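The following is a minimal pure-Python sketch of what the returned aggregator computes, not fedjax's implementation: each client vector is quantized with the stochastic rule, then the quantized vectors are averaged by weight. It assumes flat lists of floats instead of JAX pytrees, uses random.Random in place of the rng PRNGKey, and ignores encode_algorithm; stochastic_quantize and quantized_weighted_mean are hypothetical names.

```python
import random
from typing import Iterable, List, Sequence, Tuple


def stochastic_quantize(v: Sequence[float], num_levels: int,
                        rng: random.Random) -> List[float]:
  """Rounds each entry of v to one of num_levels evenly spaced levels.

  Levels span [min(v), max(v)]; each entry goes to the level just above it
  with probability proportional to its distance from the level below, so
  the quantized value equals the entry in expectation.
  """
  if num_levels < 2:
    raise ValueError('num_levels must be at least 2')
  v_min, v_max = min(v), max(v)
  if v_max == v_min:
    return list(v)  # A constant vector already sits on a level.
  step = (v_max - v_min) / (num_levels - 1)
  quantized = []
  for x in v:
    r = min(int((x - v_min) // step), num_levels - 2)  # index of level below
    lower = v_min + r * step
    p_up = (x - lower) / step  # probability of rounding up to lower + step
    quantized.append(lower + step if rng.random() < p_up else lower)
  return quantized


def quantized_weighted_mean(
    clients: Iterable[Tuple[bytes, Sequence[float], float]],
    num_levels: int, seed: int) -> List[float]:
  """Quantizes every client vector, then takes the weighted mean."""
  rng = random.Random(seed)  # Stand-in for a JAX PRNGKey.
  sums: List[float] = []
  total_weight = 0.0
  for _client_id, v, weight in clients:
    q = stochastic_quantize(v, num_levels, rng)
    if not sums:
      sums = [0.0] * len(q)
    sums = [s + weight * qi for s, qi in zip(sums, q)]
    total_weight += weight
  return [s / total_weight for s in sums]


clients = [(b'a', [0.0, 0.5, 1.0], 1.0), (b'b', [1.0, 0.5, 0.0], 3.0)]
print(quantized_weighted_mean(clients, num_levels=3, seed=0))  # [0.75, 0.5, 0.25]
```

In this example every input already lies on one of the three levels, so quantization leaves it unchanged and the output is exact; inputs between levels would give a random result whose expectation is the true weighted mean.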