fedjax.optimizers

Lightweight library for working with optimizers.

fedjax.optimizers.Optimizer

Wraps different optimizer libraries in a common interface.

fedjax.optimizers.create_optimizer_from_optax

Creates an optimizer from an optax gradient transformation chain.

fedjax.optimizers.ignore_grads_haiku

Modifies an optimizer to ignore gradients for non_trainable_names.

fedjax.optimizers.adagrad

The Adagrad optimizer.

fedjax.optimizers.adam

The classic Adam optimizer.

fedjax.optimizers.rmsprop

A flexible RMSProp optimizer.

fedjax.optimizers.sgd

A canonical Stochastic Gradient Descent optimizer.


class fedjax.optimizers.Optimizer(init, apply)

Wraps different optimizer libraries in a common interface.

Currently works with optax gradient transformations (see create_optimizer_from_optax).

The expected usage of Optimizer is as follows:

```python
import fedjax
import jax.numpy as jnp

# One step of SGD (float parameters, so the update is not truncated to ints).
params = {'w': jnp.array([1.0, 1.0, 1.0])}
grads = {'w': jnp.array([2.0, 3.0, 4.0])}
optimizer = fedjax.optimizers.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
opt_state, params = optimizer.apply(grads, opt_state, params)
print(params)
# {'w': DeviceArray([0.8, 0.7, 0.6], dtype=float32)}
```

init

Initializes a (possibly empty) PyTree of statistics (the optimizer state) from the input model parameters.

Type

Callable[[Any], ArrayTree], where ArrayTree = Union[jax.numpy.ndarray, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

apply

Transforms and applies the input gradients to update the optimizer state and model parameters.

Type

Callable[[Any, ArrayTree, Any], Tuple[ArrayTree, Any]], called as apply(grads, opt_state, params) and returning (opt_state, params)
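A minimal sketch of the init/apply contract, written in plain JAX so that it runs without fedjax. MiniOptimizer and plain_sgd are hypothetical names used only for illustration; they are not part of fedjax. What matters is the calling convention: apply takes (grads, opt_state, params) and returns (opt_state, params).

```python
from typing import Any, Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp


class MiniOptimizer(NamedTuple):
    """Hypothetical stand-in mirroring the (init, apply) pair described above."""
    init: Callable[[Any], Any]
    apply: Callable[[Any, Any, Any], Tuple[Any, Any]]


def plain_sgd(learning_rate: float) -> MiniOptimizer:
    """Hypothetical SGD without momentum; its optimizer state is an empty PyTree."""

    def init(params):
        del params  # plain SGD keeps no statistics
        return ()

    def apply(grads, opt_state, params):
        # Argument order follows Optimizer.apply: (grads, opt_state, params).
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads)
        # Return order follows Optimizer.apply: (opt_state, params).
        return opt_state, new_params

    return MiniOptimizer(init, apply)


params = {'w': jnp.array([1.0, 1.0, 1.0])}
grads = {'w': jnp.array([2.0, 3.0, 4.0])}
opt = plain_sgd(learning_rate=0.1)
opt_state = opt.init(params)
opt_state, params = opt.apply(grads, opt_state, params)
```

Because the state and parameters are ordinary PyTrees, an apply function of this shape can also be wrapped in jax.jit.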

fedjax.optimizers.create_optimizer_from_optax(opt)

Creates an optimizer from an optax gradient transformation chain.

Return type

Optimizer

fedjax.optimizers.ignore_grads_haiku(optimizer, non_trainable_names)

Modifies an optimizer to ignore gradients for non_trainable_names.

Non-trainable parameters have their values set to None when passed as input into the Optimizer, which prevents any updates to them.

NOTE: This only works with models implemented in Haiku.

Parameters
  • optimizer (Optimizer) – Base Optimizer.

  • non_trainable_names (List[Tuple[str, str]]) – List of (module name, entry name) tuples, where the first element is a Haiku module name and the second names an entry in that module's data bundle (e.g. a parameter name). These pairs select the non-trainable parameters.

Return type

Optimizer

Returns

Optimizer that will ignore gradients for the non-trainable parameters.
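A sketch of the same idea without Haiku, using a Haiku-style nested dict ({module_name: {entry_name: array}}). Rather than the None-masking described above, this version splits the parameters into trainable and frozen parts, runs the base optimizer on the trainable part only, and merges the frozen part back unchanged. split, merge, ignore_grads and the SGD base optimizer are all hypothetical helpers for illustration.

```python
import jax
import jax.numpy as jnp


def split(tree, frozen):
    """Partition a {module: {name: array}} dict into (trainable, fixed)."""
    trainable, fixed = {}, {}
    for module_name, entries in tree.items():
        for name, value in entries.items():
            target = fixed if (module_name, name) in frozen else trainable
            target.setdefault(module_name, {})[name] = value
    return trainable, fixed


def merge(a, b):
    """Combine two disjoint {module: {name: array}} dicts."""
    out = {m: dict(entries) for m, entries in a.items()}
    for m, entries in b.items():
        out.setdefault(m, {}).update(entries)
    return out


def ignore_grads(init, apply, non_trainable_names):
    """Hypothetical wrapper: the base (init, apply) never sees frozen entries."""
    frozen = set(non_trainable_names)

    def new_init(params):
        trainable, _ = split(params, frozen)
        return init(trainable)

    def new_apply(grads, opt_state, params):
        trainable_grads, _ = split(grads, frozen)  # drop gradients of frozen entries
        trainable, fixed = split(params, frozen)
        opt_state, trainable = apply(trainable_grads, opt_state, trainable)
        return opt_state, merge(trainable, fixed)

    return new_init, new_apply


def sgd_init(params):
    return ()


def sgd_apply(grads, opt_state, params):
    return opt_state, jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


params = {'linear': {'w': jnp.ones(2), 'b': jnp.zeros(2)}}
grads = {'linear': {'w': jnp.ones(2), 'b': jnp.ones(2)}}
init, apply = ignore_grads(sgd_init, sgd_apply, [('linear', 'b')])
opt_state = init(params)
opt_state, params = apply(grads, opt_state, params)
```

After the step, 'w' has moved while the frozen bias 'b' is untouched even though its gradient was non-zero.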


fedjax.optimizers.adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-06)

The Adagrad optimizer.

Adagrad is a gradient-based optimisation algorithm that anneals the learning rate of each parameter over the course of training.

WARNING: Adagrad's main limitation is the monotonic accumulation of squared gradients in the denominator: since every term is non-negative, the sum never decreases during training, and the effective learning rate eventually becomes vanishingly small.
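To see the effect described in the warning, the sketch below applies the Adagrad rule to a scalar parameter with a constant gradient of 1.0: the accumulator starts at initial_accumulator_value, each squared gradient is added to it, and eps sits inside the square root. It is a plain jax.numpy illustration of the update rule, not fedjax's or optax's code, and it omits any special handling of a zero accumulator.

```python
import jax.numpy as jnp


def adagrad_steps(grads, learning_rate=0.1, initial_accumulator_value=0.1, eps=1e-6):
    """Sketch of the Adagrad rule for one scalar parameter; returns each step's update."""
    acc = jnp.asarray(initial_accumulator_value)
    updates = []
    for g in grads:
        acc = acc + g ** 2  # the sum of squared gradients only ever grows
        updates.append(-learning_rate * g / jnp.sqrt(acc + eps))
    return jnp.stack(updates)


# With a constant gradient of 1.0, the step size shrinks roughly like 1 / sqrt(t).
steps = adagrad_steps(jnp.ones(5))
```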

References

[Duchi et al, 2011](https://jmlr.org/papers/v12/duchi11a.html)

Parameters
  • learning_rate (Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]]) – A fixed global scaling factor, or a schedule: a callable that maps the step count to a learning rate.

  • initial_accumulator_value (float) – Initial value of the accumulator of squared gradients.

  • eps (float) – A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling.
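The learning_rate type above also admits a callable. A sketch of one such schedule, assuming (as in optax) that it is called with the step count; exponential_decay_schedule and its constants are hypothetical, and optax ships ready-made schedules such as optax.exponential_decay.

```python
import jax.numpy as jnp


def exponential_decay_schedule(step):
    """Hypothetical schedule: start at 0.1 and halve the learning rate every 1000 steps."""
    return 0.1 * 0.5 ** (jnp.asarray(step) / 1000.0)
```

Going by the signature, a callable like this could be passed wherever a float is accepted, e.g. fedjax.optimizers.adagrad(learning_rate=exponential_decay_schedule).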

Return type

Optimizer

Returns

The corresponding Optimizer.

fedjax.optimizers.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0)

The classic Adam optimizer.

Adam is an SGD variant with learning rate adaptation. The effective learning rate for each weight is computed from estimates of the first- and second-order moments of the gradients (using exponential moving averages).

References

[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

Parameters
  • learning_rate (Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]]) – A fixed global scaling factor, or a schedule: a callable that maps the step count to a learning rate.

  • b1 (float) – The exponential decay rate to track the first moment of past gradients.

  • b2 (float) – The exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant added to the denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling. This is needed, for example, when computing (meta-)gradients through Adam.
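A sketch of a single Adam step in jax.numpy, showing where b1, b2, eps and eps_root enter the update (bias-corrected moment estimates, eps_root inside the square root, eps outside it). adam_update is a hypothetical helper written from the description above, not fedjax's implementation.

```python
import jax.numpy as jnp


def adam_update(g, m, v, t, learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """Sketch of one Adam step; t is the 1-based step count. Returns (update, m, v)."""
    m = b1 * m + (1 - b1) * g          # EMA of gradients (first moment)
    v = b2 * v + (1 - b2) * g ** 2     # EMA of squared gradients (second moment)
    m_hat = m / (1 - b1 ** t)          # bias correction for the zero initialisation
    v_hat = v / (1 - b2 ** t)
    # eps_root sits inside the square root, eps outside it.
    update = -learning_rate * m_hat / (jnp.sqrt(v_hat + eps_root) + eps)
    return update, m, v


g = jnp.array([0.5, -2.0])
update, m, v = adam_update(g, jnp.zeros(2), jnp.zeros(2), t=1)
```

On the first step the bias-corrected moments reduce to g and g ** 2, so each coordinate moves by about learning_rate in the direction opposite to its gradient's sign, regardless of the gradient's magnitude.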

Return type

Optimizer

Returns

The corresponding Optimizer.

fedjax.optimizers.rmsprop(learning_rate, decay=0.9, eps=1e-08, initial_scale=0.0, centered=False, momentum=None, nesterov=False)

A flexible RMSProp optimizer.

RMSProp is an SGD variant with learning rate adaptation. The effective learning rate for each weight is scaled by an estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp appear in the literature; this alias provides an easy-to-configure RMSProp optimizer that can switch between several of them.

References

[Tieleman and Hinton, 2012](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
[Graves, 2013](https://arxiv.org/abs/1308.0850)

Parameters
  • learning_rate (Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]]) – A fixed global scaling factor, or a schedule: a callable that maps the step count to a learning rate.

  • decay (float) – The decay used to track the magnitude of previous gradients.

  • eps (float) – A small numerical constant to avoid dividing by zero when rescaling.

  • initial_scale (float) – Initialisation of accumulators tracking the magnitude of previous updates. PyTorch uses 0, TF1 uses 1. When reproducing results from a paper, verify the value used by the authors.

  • centered (bool) – If True, the variance of past gradients is used to rescale the latest gradients (centered RMSProp); if False, their uncentered second moment is used.

  • momentum (Optional[float]) – The decay rate used by the momentum term. If None, momentum is not used.

  • nesterov (bool) – Whether Nesterov momentum is used.
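A sketch of one RMSProp step without momentum, illustrating the decay, eps, initial_scale and centered options. It assumes eps is added inside the square root, which is our understanding of optax's rmsprop; rmsprop_update is a hypothetical helper, not fedjax's code.

```python
import jax.numpy as jnp


def rmsprop_update(g, nu, mu, learning_rate=0.01, decay=0.9, eps=1e-8, centered=False):
    """Sketch of one RMSProp step without momentum. Returns (update, nu, mu)."""
    nu = decay * nu + (1 - decay) * g ** 2      # EMA of squared gradients
    if centered:
        mu = decay * mu + (1 - decay) * g       # EMA of gradients
        scale = nu - mu ** 2                    # variance estimate
    else:
        scale = nu                              # uncentered second-moment estimate
    update = -learning_rate * g / jnp.sqrt(scale + eps)
    return update, nu, mu


g = jnp.array([1.0])
zeros = jnp.zeros(1)  # initial_scale = 0.0 (the PyTorch convention)
plain, _, _ = rmsprop_update(g, zeros, zeros)
centered, _, _ = rmsprop_update(g, zeros, zeros, centered=True)
tf1_style, _, _ = rmsprop_update(g, jnp.ones(1), zeros)  # initial_scale = 1.0 (TF1)
```

Since the variance can never exceed the second moment, the centered variant takes a slightly larger first step here; starting the accumulator at 1.0 instead of 0.0 makes the first step much smaller, which is why initial_scale matters when reproducing results.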

Return type

Optimizer

Returns

The corresponding Optimizer.

fedjax.optimizers.sgd(learning_rate, momentum=None, nesterov=False)

A canonical Stochastic Gradient Descent optimizer.

This implements stochastic gradient descent with optional momentum and Nesterov acceleration, since these are standard practice when training deep neural networks with stochastic gradient descent.

References

[Sutskever et al, 2013](http://proceedings.mlr.press/v28/sutskever13.pdf)

Parameters
  • learning_rate (Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]]) – A fixed global scaling factor, or a schedule: a callable that maps the step count to a learning rate.

  • momentum (Optional[float]) – The decay rate used by the momentum term. If None, momentum is not used.

  • nesterov (bool) – Whether Nesterov momentum is used.
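A sketch contrasting plain SGD, momentum and Nesterov momentum on a scalar parameter with a constant gradient of 1.0. It assumes the momentum buffer follows optax's trace convention (trace = g + momentum * trace, with Nesterov using g + momentum * trace as the direction), which the shared parameter names suggest fedjax delegates to; sgd_run is a hypothetical helper.

```python
import jax.numpy as jnp


def sgd_run(grads, learning_rate=0.1, momentum=None, nesterov=False):
    """Sketch: run SGD on a scalar parameter starting at 0.0; return its final value."""
    param = jnp.asarray(0.0)
    trace = jnp.asarray(0.0)
    for g in grads:
        if momentum is None:
            direction = g
        else:
            trace = g + momentum * trace  # momentum buffer (velocity)
            # Nesterov looks one step ahead along the updated velocity.
            direction = g + momentum * trace if nesterov else trace
        param = param - learning_rate * direction
    return param


grads = jnp.ones(2)  # constant gradient of 1.0 for two steps
plain = sgd_run(grads)
heavy_ball = sgd_run(grads, momentum=0.9)
nag = sgd_run(grads, momentum=0.9, nesterov=True)
```

After two steps the three variants reach -0.2, -0.29 and -0.461 respectively: momentum accumulates past gradients, and Nesterov applies that accumulation one step earlier.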

Return type

Optimizer

Returns

The corresponding Optimizer.