FedJAX documentation
FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX prioritizes ease of use and is intended to be useful for anyone with knowledge of NumPy.
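To make "custom FL algorithm" concrete, here is a minimal sketch of one well-known FL algorithm, federated averaging (FedAvg), on a one-feature linear regression. It is written in plain Python so it runs without dependencies, and it is not FedJAX code. The helper names `grad_squared_error`, `client_update`, and `federated_averaging_round`, the learning rate, and the toy data are all hypothetical choices made for illustration.

```python
# Plain-Python sketch of federated averaging (FedAvg) rounds.
# Hypothetical helper names; this is NOT the FedJAX API.

def grad_squared_error(params, example):
    """Gradient of 0.5 * (params . x - y) ** 2 with respect to params."""
    x, y = example
    residual = sum(w * xi for w, xi in zip(params, x)) - y
    return [residual * xi for xi in x]


def client_update(params, data, lr):
    """Runs one pass of SGD over a client's examples, starting from the server params."""
    local = list(params)
    for example in data:
        grad = grad_squared_error(local, example)
        local = [w - lr * g for w, g in zip(local, grad)]
    return local


def federated_averaging_round(params, clients, lr=0.05):
    """Trains locally on every client, then averages results weighted by example count."""
    total = sum(len(data) for data in clients)
    new_params = [0.0] * len(params)
    for data in clients:
        local = client_update(params, data, lr)
        weight = len(data) / total
        new_params = [acc + weight * w for acc, w in zip(new_params, local)]
    return new_params


# Two clients whose (features, target) pairs all follow y = 2 * x.
clients = [
    [([1.0], 2.0), ([2.0], 4.0)],
    [([3.0], 6.0)],
]
params = [0.0]
for _ in range(50):
    params = federated_averaging_round(params, clients)
print(params)  # close to [2.0]
```

Weighting each client's result by its number of examples means clients with more data pull the shared parameters proportionally harder. In FedJAX itself, parameters and gradients would be handled with JAX rather than hand-written list arithmetic.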
Developer documentation
API reference
- fedjax core
  - Subpackages
  - Federated algorithm
    - FederatedAlgorithm
  - Federated data
    - FederatedData
    - SubsetFederatedData
    - SQLiteFederatedData
    - InMemoryFederatedData
    - FederatedDataBuilder
    - SQLiteFederatedDataBuilder
    - ClientPreprocessor
    - shuffle_repeat_batch_federated_data()
    - padded_batch_federated_data()
    - RepeatableIterator
  - Client dataset
    - ClientDataset
    - BatchPreprocessor
    - buffered_shuffle_batch_client_datasets()
    - padded_batch_client_datasets()
  - For each client
    - for_each_client()
    - for_each_client_backend()
    - set_for_each_client_backend()
  - Model
    - Model
    - create_model_from_haiku()
    - create_model_from_stax()
    - evaluate_model()
    - model_grad()
    - model_per_example_loss()
    - evaluate_average_loss()
    - ModelEvaluator
    - AverageLossEvaluator
    - grad()
- fedjax.aggregators
- fedjax.algorithms
- fedjax.datasets
- fedjax.models
- fedjax.training
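The "For each client" entries center on `for_each_client()`. As a rough mental model, assuming it is organized around three callbacks (a `client_init` that builds per-client state from a shared input and a client input, a `client_step` applied to each of the client's batches, and a `client_final` that turns the final state into the client's output), its control flow can be sketched in plain Python. `for_each_client_sketch` and the counting callbacks are hypothetical; the real function may also compile the callbacks with JAX and run them on a backend selected through `set_for_each_client_backend()`, so check its API entry for the exact signature.

```python
# Plain-Python sketch of a per-client init/step/final loop.
# Assumed structure only; this is NOT the FedJAX implementation.

def for_each_client_sketch(client_init, client_step, client_final):
    """Returns a function mapping (shared_input, clients) to per-client outputs."""

    def run(shared_input, clients):
        # Each client is a (client_id, client_batches, client_input) tuple.
        for client_id, client_batches, client_input in clients:
            state = client_init(shared_input, client_input)
            for batch in client_batches:
                state = client_step(state, batch)
            yield client_id, client_final(shared_input, state)

    return run


# Example: count values above a shared limit, starting from a per-client offset.
def client_init(shared_input, client_input):
    return {"limit": shared_input["limit"], "count": client_input["start"]}


def client_step(state, batch):
    above = sum(1 for x in batch["x"] if x > state["limit"])
    return {"limit": state["limit"], "count": state["count"] + above}


def client_final(shared_input, state):
    del shared_input  # Unused in this example.
    return state["count"]


clients = [
    (b"cid0", [{"x": [1, 2, 3, 4]}, {"x": [1, 2, 3]}], {"start": 2}),
    (b"cid1", [{"x": [1, 2]}, {"x": [1, 2, 3, 4, 5]}], {"start": 0}),
    (b"cid2", [{"x": [1]}], {"start": 1}),
]
run = for_each_client_sketch(client_init, client_step, client_final)
results = list(run({"limit": 2}, clients))
print(results)  # [(b'cid0', 5), (b'cid1', 3), (b'cid2', 1)]
```

Keeping the per-batch step separate from initialization and finalization means the step function always sees the same state structure, which presumably makes it a natural unit to compile once and reuse across batches and clients.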