fedjax.optimizers¶
Lightweight library for working with optimizers.
Optimizer(init, apply) – Wraps different optimizer libraries in a common interface.
create_optimizer_from_optax(opt) – Creates optimizer from optax gradient transformation chain.
ignore_grads_haiku(optimizer, non_trainable_names) – Modifies optimizer to ignore gradients for non_trainable_names.
adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-06) – The Adagrad optimizer.
adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0) – The classic Adam optimiser.
rmsprop(learning_rate, decay=0.9, eps=1e-08, initial_scale=0.0, centered=False, momentum=None, nesterov=False) – A flexible RMSProp optimiser.
sgd(learning_rate, momentum=None, nesterov=False) – A canonical Stochastic Gradient Descent optimiser.

class fedjax.optimizers.Optimizer(init, apply)¶
Wraps different optimizer libraries in a common interface.
Works with optax.
The expected usage of Optimizer is as follows:

```python
# One step of SGD.
params = {'w': jnp.array([1, 1, 1])}
grads = {'w': jnp.array([2, 3, 4])}
optimizer = fedjax.optimizers.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
opt_state, params = optimizer.apply(grads, opt_state, params)
print(params)
# {'w': DeviceArray([0.8, 0.7, 0.6], dtype=float32)}
```

init¶
Initializes (possibly empty) PyTree of statistics (optimizer state) given the input model parameters.

Type
Callable[[Any], Union[jnp.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]

apply¶
Transforms and applies the input gradients to update the optimizer state and model parameters.

Type
Callable[[Any, Union[jnp.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]], Any], Tuple[Union[jnp.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]], Any]]


fedjax.optimizers.create_optimizer_from_optax(opt)¶
Creates optimizer from optax gradient transformation chain.

Return type
Optimizer
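
For example, a minimal sketch (assuming optax is available) that chains gradient clipping with SGD:

```python
import optax
import fedjax

# Any optax gradient transformation chain can be wrapped into an Optimizer.
optimizer = fedjax.optimizers.create_optimizer_from_optax(
    optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip gradients to global norm <= 1.
        optax.sgd(learning_rate=0.1),
    ))
```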

fedjax.optimizers.ignore_grads_haiku(optimizer, non_trainable_names)¶
Modifies optimizer to ignore gradients for non_trainable_names.
Nontrainable parameters will have their values set to None when passed as input into the Optimizer to prevent any updates.
NOTE: This will only work with models implemented in haiku.

Parameters
optimizer (Optimizer) – Base Optimizer.
non_trainable_names (List[Tuple[str, str]]) – List of tuples of haiku module names and names of given entries in the module data bundle (e.g. parameter name). This list of names will be used to select the nontrainable parameters.

Return type
Optimizer

Returns
Optimizer that will ignore gradients for the nontrainable parameters.
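
A minimal sketch; the module and parameter names ('linear', 'b') are hypothetical and depend on how the haiku model was built:

```python
import fedjax

# Freeze the bias of a haiku module named 'linear' (hypothetical names).
base_optimizer = fedjax.optimizers.adam(learning_rate=1e-3)
optimizer = fedjax.optimizers.ignore_grads_haiku(
    base_optimizer, non_trainable_names=[('linear', 'b')])
```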

fedjax.optimizers.adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-06)¶
The Adagrad optimizer.
Adagrad is an algorithm for gradient-based optimisation that anneals the learning rate for each parameter during the course of training.
WARNING: Adagrad’s main limit is the monotonic accumulation of squared gradients in the denominator: since all terms are >0, the sum keeps growing during training and the learning rate eventually becomes vanishingly small.

References
[Duchi et al, 2011](https://jmlr.org/papers/v12/duchi11a.html)

Parameters
learning_rate (Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]]) – This is a fixed global scaling factor.
initial_accumulator_value (float) – Initialisation for the accumulator.
eps (float) – A small constant applied to denominator inside of the square root (as in RMSProp) to avoid dividing by zero when rescaling.

Return type
Optimizer

Returns
The corresponding Optimizer.
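
A minimal one-step sketch, following the Optimizer usage shown above:

```python
import jax.numpy as jnp
import fedjax

params = {'w': jnp.array([1.0, 1.0, 1.0])}
grads = {'w': jnp.array([0.5, 0.5, 0.5])}
optimizer = fedjax.optimizers.adagrad(learning_rate=0.01)
opt_state = optimizer.init(params)
opt_state, params = optimizer.apply(grads, opt_state, params)
```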

fedjax.optimizers.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0)¶
The classic Adam optimiser.
Adam is an SGD variant with learning rate adaptation. The learning_rate used for each weight is computed from estimates of first- and second-order moments of the gradients (using suitable exponential moving averages).

References
[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

Parameters
learning_rate (Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]]) – This is a fixed global scaling factor.
b1 (float) – The exponential decay rate to track the first moment of past gradients.
b2 (float) – The exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

Return type
Optimizer

Returns
The corresponding Optimizer.
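
Since learning_rate can be either a fixed float or a schedule (a callable from step count to scale), a minimal sketch assuming optax's exponential decay schedule:

```python
import optax
import fedjax

# Fixed learning rate.
optimizer = fedjax.optimizers.adam(learning_rate=1e-3)

# Or a schedule mapping step count to a scaling factor.
schedule = optax.exponential_decay(
    init_value=1e-3, transition_steps=1000, decay_rate=0.9)
optimizer = fedjax.optimizers.adam(learning_rate=schedule)
```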

fedjax.optimizers.rmsprop(learning_rate, decay=0.9, eps=1e-08, initial_scale=0.0, centered=False, momentum=None, nesterov=False)¶
A flexible RMSProp optimiser.
RMSProp is an SGD variant with learning rate adaptation. The learning_rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy-to-configure RMSProp optimiser that can be used to switch between several of these variants.

References
[Tieleman and Hinton, 2012](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
[Graves, 2013](https://arxiv.org/abs/1308.0850)

Parameters
learning_rate (Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]]) – This is a fixed global scaling factor.
decay (float) – The decay used to track the magnitude of previous gradients.
eps (float) – A small numerical constant to avoid dividing by zero when rescaling.
initial_scale (float) – Initialisation of accumulators tracking the magnitude of previous updates. PyTorch uses 0, TF1 uses 1. When reproducing results from a paper, verify the value used by the authors.
centered (bool) – Whether the second moment or the variance of the past gradients is used to rescale the latest gradients.
momentum (Optional[float]) – The decay rate used by the momentum term. When it is set to None, momentum is not used at all.
nesterov (bool) – Whether Nesterov momentum is used.

Return type
Optimizer

Returns
The corresponding Optimizer.
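
A minimal sketch of the centered variant with momentum, using only parameters from the signature above:

```python
import fedjax

# Centered RMSProp (Graves, 2013) with heavy-ball momentum.
optimizer = fedjax.optimizers.rmsprop(
    learning_rate=1e-3, decay=0.9, centered=True, momentum=0.9)
```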

fedjax.optimizers.sgd(learning_rate, momentum=None, nesterov=False)¶
A canonical Stochastic Gradient Descent optimiser.
This implements stochastic gradient descent. It also includes support for momentum and Nesterov acceleration, as these are standard practice when using stochastic gradient descent to train deep neural networks.

References
[Sutskever et al, 2013](http://proceedings.mlr.press/v28/sutskever13.pdf)

Parameters
learning_rate (Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]]) – This is a fixed global scaling factor.
momentum (Optional[float]) – The decay rate used by the momentum term. When it is set to None, momentum is not used at all.
nesterov (bool) – Whether Nesterov momentum is used.

Return type
Optimizer

Returns
The corresponding Optimizer.
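
A minimal sketch with Nesterov momentum enabled:

```python
import fedjax

# SGD with Nesterov-accelerated momentum.
optimizer = fedjax.optimizers.sgd(
    learning_rate=0.1, momentum=0.9, nesterov=True)
```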