fedjax.training

FedJAX training utilities.

Federated experiment

fedjax.training.run_federated_experiment(algorithm, init_state, client_sampler, config, periodic_eval_fn_map=None, final_eval_fn_map=None)

Runs the training loop of a federated algorithm experiment.

Parameters:
  • algorithm (FederatedAlgorithm) – Federated algorithm to use.

  • init_state (Any) – Initial server state.

  • client_sampler (ClientSampler) – Sampler for training clients.

  • config (FederatedExperimentConfig) – Experiment configuration: output directory, number of rounds, and checkpoint and evaluation frequencies.

  • periodic_eval_fn_map (Optional[Mapping[str, Any]]) – Mapping from names to evaluation functions (EvaluationFn or TrainClientsEvaluationFn) that are run repeatedly during federated training. How often they run is set by FederatedExperimentConfig.eval_frequency.

  • final_eval_fn_map (Optional[Mapping[str, EvaluationFn]]) – Mapping of name to evaluation functions that are run at the very end of federated training. Typically, full test evaluation functions will be set here.

Return type:

Any

Returns:

Final state of the input federated algorithm after training.
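
To make the control flow concrete, here is a simplified, pure-Python sketch. It is not FedJAX's implementation: `ToyAlgorithm`, `ToySampler`, and `run_toy_experiment` are hypothetical stand-ins, checkpointing is omitted, evaluation functions are plain callables, rounds are assumed to be numbered from 1, and `apply` returning a (new state, diagnostics) pair is an assumption. Periodic evaluation fires on rounds divisible by `eval_frequency`, which is one simple schedule consistent with the config documentation; the real loop's exact schedule may differ.

```python
from typing import Any, Callable, Dict, List, Mapping, Tuple

Client = Tuple[str, float]  # hypothetical (client_id, value) pair
EvalFn = Callable[[Any, int], Mapping[str, float]]


class ToyAlgorithm:
    """Hypothetical algorithm: state is a float that grows by the clients' mean value."""

    def apply(self, state: float, clients: List[Client]) -> Tuple[float, dict]:
        update = sum(value for _, value in clients) / len(clients)
        return state + update, {}  # (new state, diagnostics)


class ToySampler:
    """Hypothetical sampler that cycles deterministically through a fixed client list."""

    def __init__(self, clients: List[Client], per_round: int):
        self._clients = clients
        self._per_round = per_round
        self._cursor = 0

    def sample(self) -> List[Client]:
        n = len(self._clients)
        batch = [self._clients[(self._cursor + i) % n] for i in range(self._per_round)]
        self._cursor += self._per_round
        return batch


def run_toy_experiment(algorithm, init_state, sampler, num_rounds: int,
                       eval_frequency: int,
                       periodic_eval_fn_map: Dict[str, EvalFn],
                       final_eval_fn_map: Dict[str, EvalFn]):
    """Returns (final_state, log), where log holds (round_num, name, metrics) entries."""
    state, log = init_state, []
    for round_num in range(1, num_rounds + 1):
        clients = sampler.sample()
        state, _ = algorithm.apply(state, clients)
        # Periodic evaluation; an eval_frequency <= 0 disables it.
        if eval_frequency > 0 and round_num % eval_frequency == 0:
            for name, eval_fn in periodic_eval_fn_map.items():
                log.append((round_num, name, eval_fn(state, round_num)))
    # Final evaluation runs exactly once, after the last round.
    for name, eval_fn in final_eval_fn_map.items():
        log.append((num_rounds, name, eval_fn(state, num_rounds)))
    return state, log


clients = [("a", 1.0), ("b", 3.0), ("c", 5.0)]
final_state, log = run_toy_experiment(
    ToyAlgorithm(), 0.0, ToySampler(clients, per_round=2), num_rounds=4,
    eval_frequency=2,
    periodic_eval_fn_map={"periodic": lambda s, r: {"state": s}},
    final_eval_fn_map={"final": lambda s, r: {"state": s}})
print(final_state)  # 11.0
```

Separating the periodic and final maps lets cheap evaluations (for example, on a few sampled clients) run often, while an expensive full evaluation runs only once.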

class fedjax.training.FederatedExperimentConfig(root_dir: str, num_rounds: int, checkpoint_frequency: int = 0, num_checkpoints_to_keep: int = 1, eval_frequency: int = 0)

Common configurations of a federated experiment.

Attributes:

  • root_dir (str) – Root directory for experiment outputs (e.g. metrics).

  • num_rounds (int) – Number of federated training rounds.

  • checkpoint_frequency (int) – Checkpoint frequency in rounds. If <= 0, no checkpointing is done.

  • num_checkpoints_to_keep (int) – Maximum number of checkpoints to keep.

  • eval_frequency (int) – Evaluation frequency in rounds. If <= 0, no evaluation is done.
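
The defaults mean that a config built with only `root_dir` and `num_rounds` neither checkpoints nor evaluates periodically. The sketch below mirrors the documented fields in a hypothetical `ToyExperimentConfig` dataclass and, assuming 1-based rounds and a plain "every N rounds" schedule, lists the rounds on which an action would fire. The helper `scheduled_rounds` is illustrative only.

```python
import dataclasses
from typing import List


@dataclasses.dataclass(frozen=True)
class ToyExperimentConfig:
    """Hypothetical mirror of the documented FederatedExperimentConfig fields and defaults."""
    root_dir: str
    num_rounds: int
    checkpoint_frequency: int = 0
    num_checkpoints_to_keep: int = 1
    eval_frequency: int = 0


def scheduled_rounds(num_rounds: int, frequency: int) -> List[int]:
    """Rounds (assumed 1-based) on which an action with this frequency fires.

    A frequency <= 0 disables the action, as the attribute docs state.
    """
    if frequency <= 0:
        return []
    return [r for r in range(1, num_rounds + 1) if r % frequency == 0]


config = ToyExperimentConfig(root_dir="/tmp/exp", num_rounds=10, eval_frequency=4)
print(scheduled_rounds(config.num_rounds, config.eval_frequency))        # [4, 8]
print(scheduled_rounds(config.num_rounds, config.checkpoint_frequency))  # []
```

With FedJAX installed, the equivalent configuration would be `fedjax.training.FederatedExperimentConfig(root_dir='/tmp/exp', num_rounds=10, eval_frequency=4)`; since `checkpoint_frequency` keeps its default of 0, no checkpoints would be written.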

class fedjax.training.EvaluationFn

Evaluation function that receives only the state (and the round number) at each call.

Typically used for full evaluation or evaluation on sampled clients from a test set.

abstract __call__(state, round_num)

Runs evaluation.

Return type:

Mapping[str, Array]
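
A custom evaluation only needs to implement `__call__(state, round_num)` and return named metrics. So that the sketch runs without FedJAX, it defines a local `EvaluationFn` stand-in that mirrors the documented interface; `ParamSqNormEvaluationFn` and `ToyState` are hypothetical, and plain floats replace the JAX arrays a real implementation would return.

```python
import abc
import collections
from typing import Any, Mapping


class EvaluationFn(abc.ABC):
    """Local stand-in mirroring the documented fedjax.training.EvaluationFn interface."""

    @abc.abstractmethod
    def __call__(self, state: Any, round_num: int) -> Mapping[str, Any]:
        """Evaluates `state` at round `round_num` and returns named metrics."""


class ParamSqNormEvaluationFn(EvaluationFn):
    """Hypothetical evaluation: squared L2 norm of a flat list of parameters."""

    def __call__(self, state, round_num):
        # Simplification: real FedJAX params are pytrees of JAX arrays, not lists.
        return {"param_sq_norm": sum(p * p for p in state.params)}


ToyState = collections.namedtuple("ToyState", ["params"])  # hypothetical server state
eval_fn = ParamSqNormEvaluationFn()
print(eval_fn(ToyState(params=[3.0, 4.0]), round_num=7))  # {'param_sq_norm': 25.0}
```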

class fedjax.training.ModelFullEvaluationFn(fd, model, batch_hparams)

Bases: EvaluationFn

Evaluation on an entire federated dataset using the centralized model.

__call__(state, round_num)

Runs evaluation.

Return type:

Mapping[str, Array]

__init__(fd, model, batch_hparams)

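
As an illustration, assume that evaluating an entire federated dataset with one centralized model means pooling every client's examples and scoring them with the same parameters. The sketch uses a hypothetical threshold classifier and a dict-based dataset in place of FederatedData, Model, and batch_hparams (batching is not modeled). For contrast, it also computes an unweighted per-client average, which gives a one-example client as much weight as a large one.

```python
from typing import Dict, List, Tuple

# Hypothetical federated dataset: client_id -> list of (feature, label) examples.
Dataset = Dict[str, List[Tuple[float, int]]]


def predict(threshold: float, x: float) -> int:
    """Hypothetical centralized model: a 1-D threshold classifier."""
    return 1 if x >= threshold else 0


def full_accuracy(dataset: Dataset, threshold: float) -> float:
    """Pools every client's examples and scores them with the same model."""
    correct = total = 0
    for examples in dataset.values():
        for x, y in examples:
            correct += int(predict(threshold, x) == y)
            total += 1
    return correct / total


def mean_client_accuracy(dataset: Dataset, threshold: float) -> float:
    """For contrast: unweighted average of per-client accuracies."""
    per_client = [sum(int(predict(threshold, x) == y) for x, y in examples) / len(examples)
                  for examples in dataset.values()]
    return sum(per_client) / len(per_client)


test_fd = {"big": [(0.9, 1), (0.8, 1), (0.1, 0), (0.2, 1)], "small": [(0.7, 0)]}
print(full_accuracy(test_fd, 0.5))         # 0.6 (3 of 5 examples)
print(mean_client_accuracy(test_fd, 0.5))  # 0.375 (mean of 0.75 and 0.0)
```
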
class fedjax.training.ModelSampleClientsEvaluationFn(client_sampler, model, batch_hparams)

Bases: EvaluationFn

Evaluation on sampled clients using the centralized model.

The state to be evaluated must contain a params field.

__call__(state, round_num)

Runs evaluation.

Return type:

Mapping[str, Array]

__init__(client_sampler, model, batch_hparams)

class fedjax.training.TrainClientsEvaluationFn

Evaluation function that is also fed the training clients at every call.

Typically used for evaluation on the training clients sampled in the current round.

abstract __call__(state, round_num, train_clients)

Runs evaluation.

Return type:

Mapping[str, Array]
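
The extra `train_clients` argument lets a metric depend on which clients participated in the round. As before, the sketch defines a local stand-in for the abstract class so that it runs without FedJAX. `TrainClientSizesFn` is hypothetical and assumes each client is a `(client_id, examples)` pair; FedJAX's client tuples may have a different shape.

```python
import abc
from typing import Any, Mapping, Sequence


class TrainClientsEvaluationFn(abc.ABC):
    """Local stand-in mirroring the documented fedjax.training.TrainClientsEvaluationFn."""

    @abc.abstractmethod
    def __call__(self, state: Any, round_num: int,
                 train_clients: Sequence[Any]) -> Mapping[str, Any]:
        """Evaluates `state` using the clients sampled for this round."""


class TrainClientSizesFn(TrainClientsEvaluationFn):
    """Hypothetical metric: how much data the current round's clients hold."""

    def __call__(self, state, round_num, train_clients):
        # Assumes (client_id, examples) pairs; the state itself is not used here.
        sizes = [len(examples) for _, examples in train_clients]
        return {
            "num_clients": len(sizes),
            "num_examples": sum(sizes),
            "max_client_examples": max(sizes, default=0),
        }


round_clients = [("a", [1, 2, 3]), ("b", [4])]
print(TrainClientSizesFn()(None, 1, round_clients))
# {'num_clients': 2, 'num_examples': 4, 'max_client_examples': 3}
```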

class fedjax.training.ModelTrainClientsEvaluationFn(model, batch_hparams)

Bases: TrainClientsEvaluationFn

Evaluation on training clients using the centralized model.

The state to be evaluated must contain a params field.

__call__(state, round_num, train_clients)

Runs evaluation.

Return type:

Mapping[str, Array]

__init__(model, batch_hparams)

fedjax.training.set_tf_cpu_only()

Restricts TensorFlow device visibility to only CPU.

TensorFlow is only used for data loading, so we prevent it from allocating GPU/TPU memory.

fedjax.training.load_latest_checkpoint(root_dir)

Loads the latest checkpoint in root_dir together with its round number. Returns None if no checkpoint is found.

Return type:

Optional[Tuple[Any, int]]

fedjax.training.save_checkpoint(root_dir, state, round_num=0, keep=1)

Saves state as the checkpoint for round round_num in root_dir, then cleans up older checkpoints so that at most keep remain.
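
The save-then-prune and load-latest pattern can be sketched with the standard library. The file naming (`checkpoint_<round>.pkl`), the use of `pickle`, and the function names `save_checkpoint_sketch` and `load_latest_checkpoint_sketch` are all hypothetical; FedJAX's actual on-disk format is not modeled. The sketches mirror only the documented signatures and the `Optional[Tuple[Any, int]]` return type.

```python
import os
import pickle
import re
import tempfile
from typing import Any, List, Optional, Tuple

# Hypothetical file naming; FedJAX's actual checkpoint format is not modeled here.
_CHECKPOINT_RE = re.compile(r"^checkpoint_(\d+)\.pkl$")


def _list_checkpoints(root_dir: str) -> List[Tuple[int, str]]:
    """Returns (round_num, path) pairs for existing checkpoints, oldest first."""
    found = []
    for name in os.listdir(root_dir):
        match = _CHECKPOINT_RE.match(name)
        if match:
            found.append((int(match.group(1)), os.path.join(root_dir, name)))
    return sorted(found)


def save_checkpoint_sketch(root_dir: str, state: Any, round_num: int = 0, keep: int = 1) -> None:
    """Writes `state` for `round_num`, then deletes all but the newest `keep` checkpoints."""
    os.makedirs(root_dir, exist_ok=True)
    path = os.path.join(root_dir, f"checkpoint_{round_num:08d}.pkl")
    with open(path, "wb") as f:
        pickle.dump(state, f)
    keep = max(keep, 1)  # always retain at least the newest checkpoint
    for _, stale_path in _list_checkpoints(root_dir)[:-keep]:
        os.remove(stale_path)


def load_latest_checkpoint_sketch(root_dir: str) -> Optional[Tuple[Any, int]]:
    """Returns (state, round_num) for the newest checkpoint, or None if there is none."""
    if not os.path.isdir(root_dir):
        return None
    checkpoints = _list_checkpoints(root_dir)
    if not checkpoints:
        return None
    round_num, path = checkpoints[-1]
    with open(path, "rb") as f:
        return pickle.load(f), round_num


with tempfile.TemporaryDirectory() as root:
    for r in (1, 2, 3):
        save_checkpoint_sketch(root, {"w": r}, round_num=r, keep=2)
    remaining = sorted(os.listdir(root))
    latest = load_latest_checkpoint_sketch(root)
print(remaining)  # ['checkpoint_00000002.pkl', 'checkpoint_00000003.pkl']
print(latest)     # ({'w': 3}, 3)
```

Zero-padding the round number keeps file names in round order when sorted lexically, although the sketch parses the integer anyway.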

class fedjax.training.Logger(root_dir=None)

Class to encapsulate tf.summary.SummaryWriter logging logic.

__init__(root_dir=None)

Initializes summary writers and log directory.

log(writer_name, metric_name, metric_value, round_num)

Records a metric using the specified summary writer.

Logs at INFO verbosity. In addition, if root_dir is set:

  • if metric_value is a scalar convertible to a float32 Tensor, writes a scalar summary;

  • if metric_value is a vector convertible to a float32 Tensor, writes a histogram summary.

Parameters:
  • writer_name (str) – Name of summary writer.

  • metric_name (str) – Name of metric to log.

  • metric_value (Any) – Value of metric to log.

  • round_num (int) – Round number to log.
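
The branching described for `log` can be sketched without TensorFlow. `summary_kind` and `ToyLogger` are hypothetical: the real class writes through TensorFlow summary writers (TensorFlow's `tf.summary.scalar` and `tf.summary.histogram` are the natural targets), whereas the sketch only records which kind of summary it would write. Its "real number" check is an approximation of "convertible to a float32 Tensor".

```python
import logging
import numbers
from typing import Any, List, Optional, Tuple


def _is_real(x: Any) -> bool:
    return isinstance(x, numbers.Real) and not isinstance(x, bool)


def summary_kind(metric_value: Any) -> Optional[str]:
    """Hypothetical classifier following the Logger.log description.

    Returns "scalar" for a single real number, "histogram" for a non-empty
    flat list/tuple of real numbers, and None for anything else.
    """
    if _is_real(metric_value):
        return "scalar"
    if (isinstance(metric_value, (list, tuple)) and metric_value
            and all(_is_real(v) for v in metric_value)):
        return "histogram"
    return None


class ToyLogger:
    """Hypothetical stand-in that records summaries in memory instead of on disk."""

    def __init__(self, root_dir: Optional[str] = None):
        self.root_dir = root_dir
        self.summaries: List[Tuple[str, str, str, int]] = []

    def log(self, writer_name: str, metric_name: str, metric_value: Any, round_num: int) -> None:
        logging.info("round %d %s/%s = %s", round_num, writer_name, metric_name, metric_value)
        if self.root_dir is None:
            return  # without root_dir, only the INFO log line is emitted
        kind = summary_kind(metric_value)
        if kind is not None:
            self.summaries.append((writer_name, metric_name, kind, round_num))


logger = ToyLogger(root_dir="/tmp/exp")  # nothing is written to this path
logger.log("train", "loss", 0.25, round_num=1)
logger.log("train", "client_losses", [0.2, 0.3], round_num=1)
logger.log("train", "note", "hello", round_num=1)
print(logger.summaries)
# [('train', 'loss', 'scalar', 1), ('train', 'client_losses', 'histogram', 1)]
```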

Tasks

Registry of standard tasks.

Each task is represented as a (train federated data, test federated data, model) tuple.

fedjax.training.ALL_TASKS = ('EMNIST_CONV', 'EMNIST_LOGISTIC', 'EMNIST_DENSE', 'SHAKESPEARE_CHARACTER', 'STACKOVERFLOW_WORD', 'CIFAR100_LOGISTIC')

fedjax.training.get_task(name, mode='sqlite', cache_dir=None)

Gets a standard task.

Parameters:
  • name (str) – Name of the task to get. Must be one of fedjax.training.ALL_TASKS.

  • mode (str) – Data loading mode, ‘sqlite’ (the default).

  • cache_dir (Optional[str]) – Directory to cache files in ‘sqlite’ mode.

Return type:

Tuple[FederatedData, FederatedData, Model]

Returns:

(train federated data, test federated data, model) tuple.

Structured flags

Structured flags commonly used in experiment binaries.

Structured flags are often used to construct complex structures from multiple simple flags (e.g. an optimizer can be created by controlling the learning rate and other hyperparameters).
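
The idea of a named group of simple flags that together build one structured value can be sketched with the standard-library argparse. This is not how FedJAX implements these classes. `OptimizerArgs`, `OptimizerSpec`, the flag names `learning_rate` and `momentum`, and the `<name>_` prefix convention are all hypothetical; only the `name=None, default_optimizer='sgd'` parameters echo the documented OptimizerFlags signature.

```python
import argparse
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class OptimizerSpec:
    """Hypothetical stand-in for an optimizer assembled from several simple flags."""
    name: str
    learning_rate: float
    momentum: float


class OptimizerArgs:
    """Registers a group of optimizer flags, optionally under a name prefix."""

    def __init__(self, parser: argparse.ArgumentParser, name: Optional[str] = None,
                 default_optimizer: str = "sgd"):
        self._prefix = f"{name}_" if name else ""
        parser.add_argument(f"--{self._prefix}optimizer", default=default_optimizer)
        parser.add_argument(f"--{self._prefix}learning_rate", type=float, default=0.1)
        parser.add_argument(f"--{self._prefix}momentum", type=float, default=0.0)

    def get(self, args: argparse.Namespace) -> OptimizerSpec:
        """Combines the parsed simple flags into one structured value."""
        values = vars(args)
        return OptimizerSpec(
            name=values[f"{self._prefix}optimizer"],
            learning_rate=values[f"{self._prefix}learning_rate"],
            momentum=values[f"{self._prefix}momentum"],
        )


parser = argparse.ArgumentParser(allow_abbrev=False)
server_opt = OptimizerArgs(parser)                 # --optimizer, --learning_rate, --momentum
client_opt = OptimizerArgs(parser, name="client")  # --client_optimizer, ...
args = parser.parse_args(["--optimizer", "adam", "--client_learning_rate", "0.02"])
print(server_opt.get(args))  # OptimizerSpec(name='adam', learning_rate=0.1, momentum=0.0)
print(client_opt.get(args))  # OptimizerSpec(name='sgd', learning_rate=0.02, momentum=0.0)
```

The prefix is what allows two optimizers (for example, server and client) to coexist in one binary without their flags colliding.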

class fedjax.training.structured_flags.BatchHParamsFlags(name=None, default_batch_size=128)

Constructs BatchHParams from flags.

class fedjax.training.structured_flags.FederatedExperimentConfigFlags(name=None)

Constructs FederatedExperimentConfig from flags.

class fedjax.training.structured_flags.NamedFlags(name)

A group of flags with an optional named prefix.

class fedjax.training.structured_flags.OptimizerFlags(name=None, default_optimizer='sgd')

Constructs a fedjax.Optimizer from flags.

get()

Gets the specified optimizer.

Return type:

Optimizer

class fedjax.training.structured_flags.PaddedBatchHParamsFlags(name=None, default_batch_size=128)

Constructs PaddedBatchHParams from flags.

class fedjax.training.structured_flags.ShuffleRepeatBatchHParamsFlags(name=None, default_batch_size=128)

Constructs ShuffleRepeatBatchHParams from flags.

class fedjax.training.structured_flags.TaskFlags(name=None)

Constructs a standard task tuple from flags.