fedjax.tree_util

Utilities for working with tree-like container data structures.

In JAX, the term pytree refers to a tree-like structure built out of container-like Python objects. For more details, see https://jax.readthedocs.io/en/latest/pytrees.html.
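For illustration, the snippet below uses plain JAX (`jax.tree_util`), not fedjax. It shows that a nested dict of arrays is a pytree: the dicts are the containers, the arrays are the leaves, and the tree structure can be reused to rebuild a tree from new leaves.

```python
import jax
import jax.numpy as jnp

# A typical model-parameter pytree: dicts (containers) holding arrays (leaves).
params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "scale": jnp.array(2.0),
}

# Flatten into a list of leaves; dict keys are visited in sorted order.
leaves = jax.tree_util.tree_leaves(params)
print([leaf.shape for leaf in leaves])  # [(3,), (2, 3), ()]

# The recorded structure lets us rebuild a tree with the same shape from new leaves.
treedef = jax.tree_util.tree_structure(params)
doubled = jax.tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])
```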

fedjax.tree_util.tree_mean(pytrees_and_weights)

Returns the (weighted) mean of the input trees, using their associated weights.

Parameters:

pytrees_and_weights (Iterable[Tuple[Any, float]]) – Iterable of tuples of pytrees and associated weights.

Return type:

Any

Returns:

The (weighted) mean of the input trees.
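A minimal sketch of what a weighted tree mean computes, assuming the semantics sum_i(w_i * t_i) / sum_i(w_i) applied leaf by leaf. `weighted_tree_mean` is a hypothetical name written in plain JAX; it is not the fedjax implementation.

```python
from typing import Any, Iterable, Tuple

import jax
import jax.numpy as jnp


def weighted_tree_mean(pytrees_and_weights: Iterable[Tuple[Any, float]]) -> Any:
  """Hypothetical re-implementation: sum(w_i * t_i) / sum(w_i), leaf by leaf."""
  total = None
  total_weight = 0.0
  for tree, weight in pytrees_and_weights:
    # Scale every leaf of this tree by its weight.
    scaled = jax.tree_util.tree_map(lambda leaf: leaf * weight, tree)
    # Accumulate leaf-wise into the running sum.
    total = scaled if total is None else jax.tree_util.tree_map(jnp.add, total, scaled)
    total_weight += weight
  # Normalise by the sum of the weights.
  return jax.tree_util.tree_map(lambda leaf: leaf / total_weight, total)


# Two client updates weighted, e.g., by their example counts.
updates = [({"w": jnp.array([1.0, 2.0])}, 1.0), ({"w": jnp.array([3.0, 6.0])}, 3.0)]
mean = weighted_tree_mean(updates)  # leaf "w" becomes [2.5, 5.0]
```

Passing a weight of 1.0 for every tree reduces this to the plain (unweighted) mean.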

fedjax.tree_util.tree_weight(pytree, weight)

Multiplies every leaf of the tree by weight.

Return type:

Any
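A minimal sketch of leaf-wise weighting in plain JAX; `scale_tree` is a hypothetical stand-in, assuming tree_weight simply multiplies each leaf by the scalar and keeps the tree structure.

```python
import jax
import jax.numpy as jnp


def scale_tree(pytree, weight):
  """Hypothetical stand-in for tree_weight: multiply every leaf by `weight`."""
  return jax.tree_util.tree_map(lambda leaf: leaf * weight, pytree)


grads = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
scaled = scale_tree(grads, 0.1)  # e.g. apply a learning rate of 0.1
```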

fedjax.tree_util.tree_sum(pytrees)

Sums multiple trees together.

Return type:

Any
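A minimal sketch of summing several trees leaf by leaf, assuming all input trees share one structure. `sum_trees` is a hypothetical plain-JAX stand-in, not the fedjax code.

```python
import functools

import jax
import jax.numpy as jnp


def sum_trees(pytrees):
  """Hypothetical stand-in for tree_sum: leaf-wise sum of trees sharing one structure."""
  # reduce folds the trees pairwise; without an initializer it needs >= 1 tree.
  return functools.reduce(
      lambda acc, tree: jax.tree_util.tree_map(jnp.add, acc, tree), pytrees)


trees = [{"a": jnp.array(1.0), "b": jnp.array([1.0, 1.0])},
         {"a": jnp.array(2.0), "b": jnp.array([0.0, 3.0])},
         {"a": jnp.array(3.0), "b": jnp.array([5.0, 0.0])}]
total = sum_trees(trees)  # {"a": 6.0, "b": [6.0, 4.0]}
```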

fedjax.tree_util.tree_add(left, right)

Adds two trees together.

Return type:

Any
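A minimal sketch of adding two trees in plain JAX, assuming both have the same structure; `add_trees` is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp


def add_trees(left, right):
  """Hypothetical stand-in for tree_add: add matching leaves of two trees."""
  # tree_map pairs up corresponding leaves of `left` and `right`.
  return jax.tree_util.tree_map(jnp.add, left, right)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.0)}
delta = {"w": jnp.array([0.5, -0.5]), "b": jnp.array(1.0)}
new_params = add_trees(params, delta)  # {"w": [1.5, 1.5], "b": 1.0}
```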

fedjax.tree_util.tree_zeros_like(pytree)

Creates a tree of zeros with the same structure as the input.

Return type:

Any
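A minimal sketch in plain JAX, assuming each leaf is replaced by `jnp.zeros_like(leaf)` so shapes and dtypes are preserved. `zeros_like_tree` is hypothetical.

```python
import jax
import jax.numpy as jnp


def zeros_like_tree(pytree):
  """Hypothetical stand-in for tree_zeros_like: zero leaves, same shapes and dtypes."""
  return jax.tree_util.tree_map(jnp.zeros_like, pytree)


params = {"w": jnp.ones((2, 2)), "step": jnp.array(3, dtype=jnp.int32)}
# A common use: initialising an accumulator (e.g. a momentum buffer) for params.
momentum = zeros_like_tree(params)
```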

fedjax.tree_util.tree_inverse_weight(pytree, weight)

Divides every leaf of the tree by weight (that is, weights the leaves by 1 / weight).

Return type:

Any
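A minimal sketch in plain JAX, assuming tree_inverse_weight divides each leaf by the scalar. `inverse_scale_tree` is a hypothetical name; the example turns a summed quantity into a mean.

```python
import jax
import jax.numpy as jnp


def inverse_scale_tree(pytree, weight):
  """Hypothetical stand-in for tree_inverse_weight: divide every leaf by `weight`."""
  return jax.tree_util.tree_map(lambda leaf: leaf / weight, pytree)


# Turn a sum of per-example gradients into a mean by dividing by the example count.
grad_sum = {"w": jnp.array([4.0, 8.0]), "b": jnp.array(2.0)}
grad_mean = inverse_scale_tree(grad_sum, 4.0)  # {"w": [1.0, 2.0], "b": 0.5}
```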

fedjax.tree_util.tree_l2_norm(pytree)

Returns the L2 norm of the tree, treating all of its leaves as a single flattened vector.

Return type:

Array
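A minimal sketch of a global L2 norm, assuming it is the square root of the sum of squared entries over every leaf, i.e. ||T||_2 = sqrt(sum over leaves of sum(x^2)). `global_l2_norm` is a hypothetical plain-JAX stand-in.

```python
import jax
import jax.numpy as jnp


def global_l2_norm(pytree):
  """Hypothetical stand-in for tree_l2_norm: L2 norm of all leaves viewed as one vector."""
  squared = [jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(pytree)]
  return jnp.sqrt(sum(squared))


tree = {"a": jnp.array([3.0, 0.0]), "b": jnp.array([[4.0]])}
norm = global_l2_norm(tree)  # sqrt(3^2 + 4^2) = 5.0
```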

fedjax.tree_util.tree_size(pytree)

Returns the total size (number of elements) of all tree leaves.

Return type:

int
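A minimal sketch, assuming "size" means the element count of each array leaf, summed over the tree. `count_elements` is hypothetical; a typical use is counting model parameters.

```python
import jax
import jax.numpy as jnp


def count_elements(pytree) -> int:
  """Hypothetical stand-in for tree_size: total number of elements across leaves."""
  # Each array leaf's `.size` is its element count as a Python int.
  return sum(leaf.size for leaf in jax.tree_util.tree_leaves(pytree))


params = {"dense": {"w": jnp.zeros((784, 10)), "b": jnp.zeros(10)}}
num_params = count_elements(params)  # 784 * 10 + 10 = 7850
```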

fedjax.tree_util.tree_clip_by_global_norm(pytree, max_norm)

Clips a pytree of arrays using their global norm.

References:

[Pascanu et al., 2012](https://arxiv.org/abs/1211.5063)

Parameters:
  • pytree (Any) – A pytree to be potentially clipped.

  • max_norm (float) – The maximum global norm for a pytree.

Return type:

Any

Returns:

A potentially clipped pytree.
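A minimal sketch of global-norm clipping in plain JAX, assuming the usual rule from Pascanu et al.: if the global L2 norm exceeds max_norm, rescale every leaf by max_norm / norm, otherwise leave the tree unchanged. `clip_by_global_norm` and its small zero-division guard are assumptions for illustration, not the fedjax implementation.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(pytree, max_norm):
  """Hypothetical stand-in for tree_clip_by_global_norm.

  Rescales every leaf by min(1, max_norm / global_norm), so the result's
  global L2 norm is at most max_norm and its direction is unchanged.
  """
  leaves = jax.tree_util.tree_leaves(pytree)
  norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
  # Guard against division by zero for an all-zero tree (assumed detail).
  scale = jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-12))
  return jax.tree_util.tree_map(lambda leaf: leaf * scale, pytree)


grads = {"a": jnp.array([3.0, 0.0]), "b": jnp.array([[4.0]])}  # global norm 5.0
clipped = clip_by_global_norm(grads, max_norm=1.0)  # a -> [0.6, 0.0], b -> [[0.8]]
unchanged = clip_by_global_norm(grads, max_norm=10.0)  # already within the limit
```

Scaling all leaves by one shared factor, rather than clipping each leaf separately, preserves the direction of the whole update.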