fedjax.models
Model implementations for the FedJAX experimental API.
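- fedjax.models.emnist: EMNIST models.
- fedjax.models.shakespeare: Shakespeare recurrent models.
- fedjax.models.stackoverflow: Stack Overflow recurrent models.
- fedjax.models.toy_regression: Toy regression models.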
EMNIST
EMNIST models.
- class fedjax.models.emnist.ConvDropoutModule(num_classes)
Bases: hk.Module
Custom Haiku module for a CNN with dropout.
This must be defined as a custom hk.Module because hk.Sequential passes only a single positional argument from one layer to the next, which leaves no way to pass extra arguments, such as a training flag, to a dropout layer inside the sequence.
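To illustrate why a custom module is needed, the plain-Python sketch below mimics the argument forwarding of a sequential container. It assumes, as we understand Haiku's hk.Sequential, that extra arguments reach only the first layer while every later layer receives just the previous output. MiniSequential, scale, dropout, and CustomModuleSketch are hypothetical stand-ins, not FedJAX or Haiku code.

```python
from typing import Callable, List, Sequence


class MiniSequential:
    """Hypothetical container mimicking hk.Sequential's argument forwarding."""

    def __init__(self, layers: Sequence[Callable]):
        self.layers = list(layers)

    def __call__(self, inputs, *args, **kwargs):
        out = inputs
        for i, layer in enumerate(self.layers):
            # Extra arguments reach the first layer only; later layers get
            # the previous output as their single positional argument.
            out = layer(out, *args, **kwargs) if i == 0 else layer(out)
        return out


seen_flags: List[bool] = []


def scale(x, is_train=True):
    # Stand-in for a conv/dense layer; accepts but ignores the flag.
    return [2.0 * v for v in x]


def dropout(x, is_train=True):
    # Stand-in for dropout: records the mode it ran in. A real dropout
    # layer would randomly zero activations only when is_train is True.
    seen_flags.append(is_train)
    return x


class CustomModuleSketch:
    """Hypothetical custom module: threads is_train to every layer explicitly."""

    def __call__(self, x, is_train=True):
        x = scale(x)
        return dropout(x, is_train=is_train)


MiniSequential([scale, dropout])([1.0, 2.0], is_train=False)
CustomModuleSketch()([1.0, 2.0], is_train=False)
# seen_flags is now [True, False]: inside the container the dropout layer fell
# back to its default and ran in training mode; the custom module did not.
```

Writing the forward pass by hand, as the custom module does, makes the training flag an explicit argument of every layer that needs it.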
- fedjax.models.emnist.create_conv_model(only_digits=False)
Creates an EMNIST CNN model with dropout using Haiku.
Matches the model used in:
- Adaptive Federated Optimization
Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. https://arxiv.org/abs/2003.00295
- Parameters:
only_digits (bool) – Whether to use only the 10 digit classes [0-9] or also include lowercase and uppercase letters for a total of 62 classes.
- Return type:
Model
- Returns:
Model
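- fedjax.models.emnist.create_dense_model(only_digits=False, hidden_units=200)
Creates an EMNIST dense network using Haiku.
- Parameters:
only_digits (bool) – Whether to use only the 10 digit classes [0-9] or also include lowercase and uppercase letters for a total of 62 classes.
hidden_units (int) – Number of units in each hidden layer.
- Return type:
Model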
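The entry above does not spell out the layer layout. Assuming flattened 28x28 inputs and two fully connected hidden layers of hidden_units units each, followed by an output layer with one unit per class (an assumption, not confirmed here), the parameter count works out as follows. dense_param_count is a hypothetical helper.

```python
def dense_param_count(only_digits: bool = False, hidden_units: int = 200,
                      input_pixels: int = 28 * 28,
                      num_hidden_layers: int = 2) -> int:
    # Assumed architecture: flatten -> [Linear(hidden_units) -> ReLU] x
    # num_hidden_layers -> Linear(num_classes). Each Linear layer has a
    # weight matrix (fan_in x fan_out) and a bias vector (fan_out).
    num_classes = 10 if only_digits else 62
    total = 0
    fan_in = input_pixels
    for _ in range(num_hidden_layers):
        total += fan_in * hidden_units + hidden_units
        fan_in = hidden_units
    total += fan_in * num_classes + num_classes
    return total
```

Under these assumptions the defaults give 157,000 + 40,200 + 12,462 = 209,662 parameters, or 199,210 with only_digits=True; almost all of them sit in the first layer, so hidden_units is the main size knob.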
Shakespeare
Shakespeare recurrent models.
- fedjax.models.shakespeare.create_lstm_model(vocab_size=86, embed_size=8, lstm_hidden_size=256, lstm_num_layers=2)
Creates an LSTM language model.
Character-level LSTM language model for Shakespeare. The defaults match the model used in:
- Communication-Efficient Learning of Deep Networks from Decentralized Data
H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017. https://arxiv.org/abs/1602.05629
- Parameters:
vocab_size (int) – The number of possible output characters. This does not include special tokens like PAD, BOS, EOS, or OOV.
embed_size (int) – Embedding size for each character.
lstm_hidden_size (int) – Hidden size for LSTM cells.
lstm_num_layers (int) – Number of LSTM layers.
- Return type:
Model
- Returns:
Model
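To make embed_size, lstm_hidden_size, and lstm_num_layers concrete, here is a minimal NumPy sketch of a stacked character-level LSTM language model: embed each character id, run the sequence through lstm_num_layers LSTM layers in turn, and project every time step to logits. It is an illustration under assumptions, not FedJAX's Haiku implementation: the output size num_outputs (hypothetically vocab_size plus whatever special tokens are used), the gate layout, and the initialization are choices made for this sketch.

```python
import numpy as np


def _sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def lstm_cell(layer, x, h, c):
    # One affine map on [x, h] produces all four gate pre-activations.
    z = np.concatenate([x, h], axis=-1) @ layer["w"] + layer["b"]
    i, g, f, o = np.split(z, 4, axis=-1)
    c = _sigmoid(f) * c + _sigmoid(i) * np.tanh(g)
    h = _sigmoid(o) * np.tanh(c)
    return h, c


def init_params(rng, num_outputs, embed_size, lstm_hidden_size, lstm_num_layers):
    params = {"embed": 0.1 * rng.normal(size=(num_outputs, embed_size)), "lstm": []}
    in_size = embed_size
    for _ in range(lstm_num_layers):
        params["lstm"].append({
            "w": 0.1 * rng.normal(size=(in_size + lstm_hidden_size, 4 * lstm_hidden_size)),
            "b": np.zeros(4 * lstm_hidden_size),
        })
        in_size = lstm_hidden_size  # later layers consume the previous layer's output
    params["out_w"] = 0.1 * rng.normal(size=(lstm_hidden_size, num_outputs))
    params["out_b"] = np.zeros(num_outputs)
    return params


def forward(params, tokens):
    # tokens: int array [batch, time] -> logits [batch, time, num_outputs]
    batch, time = tokens.shape
    x = params["embed"][tokens]
    for layer in params["lstm"]:
        hidden_size = layer["b"].shape[0] // 4
        h = np.zeros((batch, hidden_size))
        c = np.zeros((batch, hidden_size))
        outputs = []
        for t in range(time):
            h, c = lstm_cell(layer, x[:, t], h, c)
            outputs.append(h)
        x = np.stack(outputs, axis=1)
    return x @ params["out_w"] + params["out_b"]


rng = np.random.default_rng(0)
params = init_params(rng, num_outputs=12, embed_size=8,
                     lstm_hidden_size=16, lstm_num_layers=2)
tokens = rng.integers(0, 12, size=(2, 5))
logits = forward(params, tokens)  # shape (2, 5, 12)
```

With the documented defaults, each character is embedded into only 8 dimensions, while each of the 2 stacked layers carries a 256-unit hidden state, so most of the capacity lives in the recurrent layers.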
Stack Overflow
Stack Overflow recurrent models.
- fedjax.models.stackoverflow.create_lstm_model(vocab_size=10000, embed_size=96, lstm_hidden_size=670, lstm_num_layers=1, share_input_output_embeddings=False, expected_length=None)
Creates an LSTM language model.
Word-level LSTM language model for Stack Overflow. The defaults match the model used in:
- Adaptive Federated Optimization
Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. https://arxiv.org/abs/2003.00295
- Parameters:
vocab_size (int) – The number of possible output words. This does not include special tokens like PAD, BOS, EOS, or OOV.
embed_size (int) – Embedding size for each word.
lstm_hidden_size (int) – Hidden size for LSTM cells.
lstm_num_layers (int) – Number of LSTM layers.
share_input_output_embeddings (bool) – Whether to share the input embeddings with the output layer that produces the logits.
expected_length (Optional[float]) – Expected average sentence length, used to scale the training loss by 1. / expected_length. Dividing by this constant turns the total loss over all words in a sentence into approximately per-word cross-entropy values, without dividing by the actual number of words, which can vary across batches. Defaults to no scaling.
- Return type:
Model
- Returns:
Model
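The sketch below illustrates the two less obvious options in plain NumPy, under stated assumptions. share_input_output_embeddings is shown as weight tying, where the output logits reuse the input embedding matrix; this requires the LSTM output to be projected to embed_size first. expected_length is shown as a constant divisor applied to the summed per-sentence loss. tied_logits and sequence_loss are hypothetical helpers, and averaging the per-sentence losses over the batch is an assumption rather than documented FedJAX behavior.

```python
import numpy as np


def tied_logits(hidden, embedding, out_bias):
    # Weight tying: reuse the input embedding matrix [V, E] as the output
    # projection. hidden must already have size E: [B, T, E] -> [B, T, V].
    return hidden @ embedding.T + out_bias


def sequence_loss(logits, targets, mask, expected_length=None):
    # Numerically stable log-softmax, then per-token cross-entropy.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    token_ce = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    # Total loss over the real (non-padding) words of each sentence.
    per_sentence = (token_ce * mask).sum(axis=-1)
    if expected_length is not None:
        # Same constant for every batch, unlike dividing by mask.sum().
        per_sentence = per_sentence / expected_length
    return per_sentence.mean()


rng = np.random.default_rng(0)
vocab, embed_size = 6, 3
embedding = rng.normal(size=(vocab, embed_size))
hidden = rng.normal(size=(2, 4, embed_size))
logits = tied_logits(hidden, embedding, np.zeros(vocab))  # shape (2, 4, 6)

# With all-zero logits every word costs log(6); sentence lengths are 4 and 2.
zeros = np.zeros((2, 4, vocab))
targets = np.zeros((2, 4), dtype=int)
mask = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0]])
unscaled = sequence_loss(zeros, targets, mask)                      # 3 * log(6)
scaled = sequence_loss(zeros, targets, mask, expected_length=3.0)   # log(6)
```

Because expected_length is a constant, a batch of short sentences and a batch of long sentences are scaled by the same factor. Dividing by mask.sum() instead would change the scale from batch to batch. Tying the embeddings removes a separate vocab_size x embed_size output matrix, which is sizable at the default vocab_size=10000.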
Toy Regression
Toy regression models.
- fedjax.models.toy_regression.create_regression_model()
Creates a toy regression model.
Matches the model used in:
- Communication-Efficient Agnostic Federated Averaging
Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh. https://arxiv.org/abs/2104.02748
- Return type:
Model
- Returns:
Model