fedjax.models

Model implementations for the FedJAX experimental API.

  • fedjax.models.emnist – EMNIST models.

  • fedjax.models.shakespeare – Shakespeare recurrent models.

  • fedjax.models.stackoverflow – Stack Overflow recurrent models.

  • fedjax.models.toy_regression – Toy regression models.

EMNIST

EMNIST models.

class fedjax.models.emnist.ConvDropoutModule(num_classes)

Bases: Module

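Custom Haiku module for a CNN with dropout.

This must be defined as a custom hk.Module because hk.Sequential passes only a single positional argument (the previous layer's output) to each layer, so it cannot forward the extra train/eval argument that dropout depends on.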
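The constraint can be illustrated with a plain-Python stand-in. `SimpleSequential`, `dropout`, and `CustomModule` below are hypothetical names, not Haiku or FedJAX code; the stand-in only mimics the behaviour described above, where each layer receives nothing but the previous layer's output.

```python
import numpy as np


class SimpleSequential:
    """Hypothetical stand-in: threads only the activation through each layer."""

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)  # no way to pass an extra is_train flag here
        return x


def relu(x):
    return np.maximum(x, 0.0)


def dropout(x, is_train, rate=0.5, rng=None):
    """Hypothetical dropout layer that needs an extra train/eval argument."""
    if not is_train:
        return x
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)


class CustomModule:
    """Calls each layer explicitly, so it can forward is_train to dropout."""

    def __init__(self, rate=0.5, seed=0):
        self.rate = rate
        self.rng = np.random.default_rng(seed)

    def __call__(self, x, is_train):
        x = relu(x)
        return dropout(x, is_train, self.rate, self.rng)


x = np.array([-1.0, 2.0, 3.0])
try:
    SimpleSequential([relu, dropout])(x)  # dropout() is called without is_train
    failed = False
except TypeError:
    failed = True
```

The sequential version fails because `dropout` never receives `is_train`, while the custom module works in both modes, which is the motivation for defining ConvDropoutModule by hand.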

class fedjax.models.emnist.Dropout(rate=0.5)

Bases: Module

Dropout Haiku module.
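A minimal NumPy sketch of what a dropout layer with a given `rate` computes. It assumes the inverted-dropout convention used by Haiku's `hk.dropout` (each surviving unit is scaled by 1 / (1 - rate)) and a pass-through at evaluation time; `inverted_dropout` is a hypothetical helper, not FedJAX's implementation.

```python
import numpy as np


def inverted_dropout(x, rate, rng, is_train=True):
    """Zero each element with probability `rate` and rescale the survivors.

    Dividing survivors by (1 - rate) keeps the expected value of each output
    equal to its input, so no rescaling is needed at evaluation time.
    """
    if not is_train or rate == 0.0:
        return x
    keep_prob = 1.0 - rate
    mask = rng.random(x.shape) < keep_prob
    return np.where(mask, x / keep_prob, 0.0)


rng = np.random.default_rng(0)
x = np.ones((1000, 100))
y = inverted_dropout(x, rate=0.5, rng=rng)
print(y.mean())  # close to 1.0 on average
```

With the default `rate=0.5`, every output is either 0 or 2 times its input, and the mean over many units stays close to the input mean.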

fedjax.models.emnist.create_conv_model(only_digits=False)

Creates an EMNIST CNN model with dropout using Haiku.

Matches the model used in:

Adaptive Federated Optimization

Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. https://arxiv.org/abs/2003.00295

Parameters:

only_digits (bool) – Whether to use only the 10 digit classes [0-9], or to also include lower- and upper-case letters for a total of 62 classes.

Return type:

Model

Returns:

Model

fedjax.models.emnist.create_dense_model(only_digits=False, hidden_units=200)

Creates an EMNIST dense network using Haiku.

Return type:

Model

fedjax.models.emnist.create_logistic_model(only_digits=False)

Creates an EMNIST logistic regression model using Haiku.

Return type:

Model
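A NumPy sketch of the computation a logistic model performs on EMNIST, assuming 28×28 grayscale inputs flattened to 784 features and a single linear layer followed by softmax over 10 or 62 classes (depending on `only_digits`). The helper names, parameter layout, and initialization are hypothetical illustrations, not FedJAX's internals.

```python
import numpy as np


def num_emnist_classes(only_digits):
    # 10 digit classes, or digits plus upper- and lower-case letters.
    return 10 if only_digits else 62


def init_logistic_params(only_digits, rng):
    num_classes = num_emnist_classes(only_digits)
    return {
        "w": rng.normal(scale=0.01, size=(28 * 28, num_classes)),
        "b": np.zeros(num_classes),
    }


def logistic_forward(params, images):
    # images: [batch, 28, 28] -> flattened [batch, 784] -> logits [batch, C]
    x = images.reshape(images.shape[0], -1)
    return x @ params["w"] + params["b"]


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
params = init_logistic_params(only_digits=False, rng=rng)
images = rng.random((4, 28, 28))
probs = softmax(logistic_forward(params, images))
print(probs.shape)  # (4, 62)
```

Under these assumptions the model has 784 × 62 + 62 = 48,670 parameters for all 62 classes.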

fedjax.models.emnist.create_stax_dense_model(only_digits=False, hidden_units=200)

Creates an EMNIST dense network using stax.

Return type:

Model

Shakespeare

Shakespeare recurrent models.

fedjax.models.shakespeare.create_lstm_model(vocab_size=86, embed_size=8, lstm_hidden_size=256, lstm_num_layers=2)

Creates an LSTM language model.

Character-level LSTM language model for Shakespeare. The defaults match the model used in:

Communication-Efficient Learning of Deep Networks from Decentralized Data

H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS 2017. https://arxiv.org/abs/1602.05629

Parameters:
  • vocab_size (int) – The number of possible output characters. This does not include special tokens like PAD, BOS, EOS, or OOV.

  • embed_size (int) – Embedding size for each character.

  • lstm_hidden_size (int) – Hidden size for LSTM cells.

  • lstm_num_layers (int) – Number of LSTM layers.

Return type:

Model

Returns:

Model.
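A back-of-the-envelope parameter count for the default configuration. This sketch assumes four special tokens (PAD, BOS, EOS, OOV) on top of `vocab_size`, so the embedding and output layers cover `vocab_size + 4` tokens; a standard LSTM cell with one weight matrix over the concatenated [input, hidden] vector and one bias per gate (as in Haiku's `hk.LSTM`); and a biased output layer from the last LSTM layer to the full vocabulary. `lstm_lm_params` is a hypothetical helper, and FedJAX's exact count may differ if any assumption does not hold.

```python
def lstm_layer_params(input_size, hidden_size):
    # Four gates, each with weights over [input, hidden] plus a bias.
    return 4 * (hidden_size * (input_size + hidden_size) + hidden_size)


def lstm_lm_params(vocab_size=86, embed_size=8, lstm_hidden_size=256,
                   lstm_num_layers=2, num_special_tokens=4):
    # Assumption: PAD, BOS, EOS and OOV are added to vocab_size.
    full_vocab = vocab_size + num_special_tokens
    total = full_vocab * embed_size  # input embedding table
    input_size = embed_size
    for _ in range(lstm_num_layers):
        total += lstm_layer_params(input_size, lstm_hidden_size)
        input_size = lstm_hidden_size  # deeper layers read the hidden state
    total += lstm_hidden_size * full_vocab + full_vocab  # output logits layer
    return total


print(lstm_lm_params())  # 820522 under these assumptions
```

Under these assumptions the two LSTM layers (271,360 and 525,312 parameters) account for about 97% of the total, so `lstm_hidden_size` dominates the model size.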

Stack Overflow

Stack Overflow recurrent models.

fedjax.models.stackoverflow.create_lstm_model(vocab_size=10000, embed_size=96, lstm_hidden_size=670, lstm_num_layers=1, share_input_output_embeddings=False, expected_length=None)

Creates an LSTM language model.

Word-level LSTM language model for Stack Overflow. The defaults match the model used in:

Adaptive Federated Optimization

Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan. https://arxiv.org/abs/2003.00295

Parameters:
  • vocab_size (int) – The number of possible output words. This does not include special tokens like PAD, BOS, EOS, or OOV.

  • embed_size (int) – Embedding size for each word.

  • lstm_hidden_size (int) – Hidden size for LSTM cells.

  • lstm_num_layers (int) – Number of LSTM layers.

  • share_input_output_embeddings (bool) – Whether to share the input embedding matrix with the output logits layer (weight tying).

  • expected_length (Optional[float]) – Expected average sentence length, used to scale the training loss by 1. / expected_length. Scaling by this constant turns the total loss summed over all words in a sentence into an approximate per-word cross entropy, without dividing by the actual number of words, which can vary across batches. Defaults to None (no scaling).
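A minimal NumPy sketch of the `expected_length` scaling: per-token cross entropy is masked at padding positions, summed per sentence, and multiplied by the constant 1 / expected_length, contrasted with dividing by each sentence's actual word count. The masking convention and all function names are assumptions for illustration, not FedJAX's loss code.

```python
import numpy as np


def token_cross_entropy(logits, labels):
    # logits: [batch, time, vocab]; labels: [batch, time] integer ids.
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]


def sentence_losses(logits, labels, mask, expected_length=None):
    """Summed per-sentence loss, optionally scaled by 1 / expected_length."""
    total = (token_cross_entropy(logits, labels) * mask).sum(axis=-1)
    if expected_length is None:
        return total  # no scaling
    return total / expected_length  # same constant for every batch


def per_word_losses(logits, labels, mask):
    """Alternative: divide by each sentence's actual number of words."""
    total = (token_cross_entropy(logits, labels) * mask).sum(axis=-1)
    return total / mask.sum(axis=-1)


# Two sentences of length 2 and 4 (padded to 4), vocabulary of 5 words,
# uniform logits so every token costs log(5).
logits = np.zeros((2, 4, 5))
labels = np.zeros((2, 4), dtype=int)
mask = np.array([[1, 1, 0, 0], [1, 1, 1, 1]], dtype=float)
scaled = sentence_losses(logits, labels, mask, expected_length=3.0)
per_word = per_word_losses(logits, labels, mask)
```

With the constant factor, the longer sentence still contributes twice the loss of the shorter one, whereas per-word division gives both sentences the same weight and ties the normalizer to the batch contents.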

Return type:

Model

Returns:

Model.

Toy Regression

Toy regression models.

fedjax.models.toy_regression.create_regression_model()

Creates a toy regression model.

Matches the model used in:

Communication-Efficient Agnostic Federated Averaging

Jae Ro, Mingqing Chen, Rajiv Mathews, Mehryar Mohri, Ananda Theertha Suresh. https://arxiv.org/abs/2104.02748

Return type:

Model

Returns:

Model