fedjax.aggregators

FedJAX aggregators.

class fedjax.aggregators.Aggregator(init, apply)

Interface for aggregation algorithms.

This interface defines the aggregation algorithms used at each round. Aggregator state holds any round-specific parameters (e.g. number of bits) that are carried from one round to the next. This state is created by init, then passed as input to apply and returned as output from it. We strongly recommend using fedjax.dataclass to define state, because it provides immutability and type hinting and works by default with JAX transformations.
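
A minimal sketch of the recommended immutable-state pattern follows. It uses the standard library's frozen dataclasses.dataclass as a stand-in for fedjax.dataclass, so it runs without FedJAX installed. The CountingState class, its num_rounds and num_bits fields, and the next_state helper are hypothetical names chosen for illustration. Each round builds a new state object with dataclasses.replace instead of mutating the previous one.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class CountingState:
  """Hypothetical aggregator state carried from round to round."""
  num_rounds: int = 0
  num_bits: float = 0.0


def next_state(state: CountingState, bits_this_round: float) -> CountingState:
  # Frozen dataclasses cannot be modified in place, so build a new instance.
  return dataclasses.replace(
      state,
      num_rounds=state.num_rounds + 1,
      num_bits=state.num_bits + bits_this_round)


state = CountingState()
for bits in [32.0, 16.0, 8.0]:
  state = next_state(state, bits)
print(state)  # CountingState(num_rounds=3, num_bits=56.0)
```

Because the old state is never modified, any earlier state object can still be inspected after later rounds have run.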

The expected usage of Aggregator is as follows:

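import fedjax

num_rounds = 10  # Illustrative value.
aggregator = fedjax.aggregators.mean_aggregator()
state = aggregator.init()
for i in range(num_rounds):
  # compute_client_outputs is a placeholder for user code that returns an
  # iterable of (client_id, params, weight) tuples for round i.
  clients_params_and_weights = compute_client_outputs(i)
  aggregated_params, state = aggregator.apply(clients_params_and_weights,
                                              state)

init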

Returns initial state of aggregator.

Type:

Callable[[], Any]

apply

Returns the aggregated params and the new aggregator state.

Type:

Callable[[Iterable[Tuple[bytes, Any, float]], Any], Tuple[Any, Any]]
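
To make the two signatures above concrete, here is a sketch of a custom aggregator built from a plain init function and a plain apply function. It makes several assumptions. A local NamedTuple named Aggregator only mirrors the documented (init, apply) fields. Params are represented as a flat dict of floats rather than a JAX pytree. The hypothetical total_weight_aggregator sums client params without weighting, while its state accumulates the total client weight seen so far.

```python
from typing import Any, Callable, Dict, Iterable, NamedTuple, Tuple

Params = Dict[str, float]  # Simplified stand-in for a parameter pytree.


class Aggregator(NamedTuple):
  """Local stand-in mirroring the (init, apply) fields documented above."""
  init: Callable[[], Any]
  apply: Callable[[Iterable[Tuple[bytes, Any, float]], Any], Tuple[Any, Any]]


def total_weight_aggregator() -> Aggregator:
  """Hypothetical aggregator: unweighted sum of params; state is total weight."""

  def init() -> float:
    # Initial state: no client weight has been seen yet.
    return 0.0

  def apply(clients_params_and_weights: Iterable[Tuple[bytes, Params, float]],
            state: float) -> Tuple[Params, float]:
    summed: Params = {}
    for _client_id, params, weight in clients_params_and_weights:
      for name, value in params.items():
        summed[name] = summed.get(name, 0.0) + value
      state += weight  # Accumulate the weight of every client seen.
    # Aggregated params first, new state second.
    return summed, state

  return Aggregator(init=init, apply=apply)


aggregator = total_weight_aggregator()
state = aggregator.init()
clients = [(b'client_a', {'w': 1.0}, 2.0), (b'client_b', {'w': 3.0}, 1.0)]
summed, state = aggregator.apply(clients, state)
print(summed, state)  # {'w': 4.0} 3.0
```

Because apply returns the aggregated params first and the new state second, the caller unpacks it the same way as in the usage loop above.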

Mean

fedjax.aggregators.mean_aggregator()

Builds a (weighted) mean aggregator.

Return type:

Aggregator
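
The weighted mean is sum_i(w_i * params_i) / sum_i(w_i), applied leaf by leaf. The sketch below computes it in plain Python over flat dicts of floats. It is not FedJAX's own implementation. The weighted_mean helper and the Params alias are hypothetical, and raising an error on zero total weight is a design choice of this sketch.

```python
from typing import Dict, Iterable, Tuple

Params = Dict[str, float]  # Simplified stand-in for a parameter pytree.


def weighted_mean(
    clients_params_and_weights: Iterable[Tuple[bytes, Params, float]]) -> Params:
  """Computes sum_i(w_i * params_i) / sum_i(w_i) for each parameter name."""
  weighted_sum: Params = {}
  total_weight = 0.0
  for _client_id, params, weight in clients_params_and_weights:
    for name, value in params.items():
      weighted_sum[name] = weighted_sum.get(name, 0.0) + weight * value
    total_weight += weight
  if total_weight == 0.0:
    raise ValueError('Total client weight must be positive.')
  return {name: s / total_weight for name, s in weighted_sum.items()}


# Each entry is (client_id, params, weight); the weights are illustrative.
clients = [
    (b'client_a', {'w': 1.0, 'b': 0.0}, 1.0),
    (b'client_b', {'w': 4.0, 'b': 2.0}, 3.0),
]
print(weighted_mean(clients))  # {'w': 3.25, 'b': 1.5}
```

Setting every weight to the same value reduces this to an unweighted mean, which is why the name puts "weighted" in parentheses.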

Quantization

fedjax.aggregators.uniform_stochastic_quantizer(num_levels, rng, encode_algorithm=None)

Builds an aggregator that returns the (weighted) mean of uniformly quantized input trees, using the uniform stochastic quantization algorithm in https://arxiv.org/pdf/1611.00429.pdf.

Parameters:
  • num_levels (int) – number of levels of quantization.

  • rng (Array) – PRNGKey used for compression.

  • encode_algorithm (Optional[str]) – encoding algorithm; either None or "arithmetic".

Return type:

Aggregator

Returns:

Compression aggregator.
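
The following sketch builds intuition for the quantization step. It implements uniform stochastic k-level quantization, in the spirit of the referenced paper, for a single list of floats. Values are mapped onto num_levels evenly spaced levels between the list's minimum and maximum, then rounded up or down at random so that the result is unbiased.

This is an illustrative sketch only, and it departs from the documented aggregator in several ways:
  • It uses Python's random.Random in place of a JAX PRNGKey.
  • It quantizes one flat list rather than a tree.
  • It omits the weighted averaging and any encode_algorithm handling.
  • The name uniform_stochastic_quantize is hypothetical.

```python
import random
from typing import List


def uniform_stochastic_quantize(values: List[float], num_levels: int,
                                rng: random.Random) -> List[float]:
  """Stochastically rounds each value to one of num_levels evenly spaced levels.

  Levels span [min(values), max(values)]. A value between two adjacent levels
  is rounded up with probability equal to its fractional distance above the
  lower level, so each quantized value equals the input in expectation.
  """
  if num_levels < 2:
    raise ValueError('num_levels must be at least 2.')
  lo, hi = min(values), max(values)
  if hi == lo:
    return list(values)  # All values already sit on a single level.
  step = (hi - lo) / (num_levels - 1)
  quantized = []
  for v in values:
    position = (v - lo) / step  # Fractional level index in [0, num_levels - 1].
    lower = min(int(position), num_levels - 2)  # Index of the level below v.
    prob_up = position - lower  # Probability of rounding up to the next level.
    level = lower + 1 if rng.random() < prob_up else lower
    quantized.append(lo + level * step)
  return quantized


rng = random.Random(0)
x = [0.0, 0.3, 0.55, 1.0]
print(uniform_stochastic_quantize(x, num_levels=3, rng=rng))
```

Averaging many independent quantizations of the same list recovers the original values. This unbiasedness is what makes taking a mean of quantized client updates a reasonable estimate of the unquantized mean.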