fedjax.tree_util
Utilities for working with tree-like container data structures.
In JAX, the term pytree refers to a tree-like structure built out of container-like Python objects. For more details, see https://jax.readthedocs.io/en/latest/pytrees.html.
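As a minimal illustration of the pytree concept, a nested dict of arrays can be flattened into its leaves and rebuilt using `jax.tree_util`. This is not part of the fedjax API, and the parameter names below are hypothetical.

```python
import jax
import jax.numpy as jnp

# A nested dict of arrays is a pytree: the dicts are internal nodes and the
# arrays are the leaves. The key names here are purely illustrative.
params = {
    "dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros(3)},
    "scale": jnp.array(2.0),
}

# tree_flatten returns the leaves (dict keys are visited in sorted order)
# together with a treedef that records the container structure.
leaves, treedef = jax.tree_util.tree_flatten(params)
print([leaf.shape for leaf in leaves])  # [(3,), (2, 3), ()]

# tree_unflatten rebuilds the same structure from a list of leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```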
- fedjax.tree_util.tree_mean(pytrees_and_weights)
Returns the (weighted) mean of the input trees, using their associated weights.
- Parameters:
pytrees_and_weights (Iterable[Tuple[Any, float]]) – Iterable of tuples of pytrees and associated weights.
- Return type:
Any
- Returns:
(Weighted) mean of input trees and weights.
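The computation can be sketched as a leaf-wise weighted average, sum_i(w_i * tree_i) / sum_i(w_i). This is a hypothetical re-implementation using plain `jax.tree_util` (the helper name `weighted_tree_mean` is illustrative), not fedjax's own code, and it assumes all input trees share one structure.

```python
import jax
import jax.numpy as jnp


def weighted_tree_mean(pytrees_and_weights):
  """Hypothetical sketch: sum_i(w_i * tree_i) / sum_i(w_i), leaf by leaf."""
  weighted_sum = None
  total_weight = 0.0
  for tree, weight in pytrees_and_weights:
    # Scale every leaf of this tree by its weight.
    scaled = jax.tree_util.tree_map(lambda leaf: leaf * weight, tree)
    if weighted_sum is None:
      weighted_sum = scaled
    else:
      weighted_sum = jax.tree_util.tree_map(jnp.add, weighted_sum, scaled)
    total_weight += weight
  if weighted_sum is None:
    raise ValueError("Expected at least one (pytree, weight) pair.")
  # Divide the accumulated weighted sum by the total weight.
  return jax.tree_util.tree_map(lambda leaf: leaf / total_weight, weighted_sum)


# Two hypothetical client updates weighted by, e.g., their example counts.
updates = [
    ({"w": jnp.array([0.0, 4.0])}, 1.0),
    ({"w": jnp.array([4.0, 0.0])}, 3.0),
]
mean = weighted_tree_mean(updates)  # mean["w"] is [3.0, 1.0]
```

With equal weights this reduces to the plain (unweighted) mean.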
- fedjax.tree_util.tree_weight(pytree, weight)
Multiplies every tree leaf by weight.
- Return type:
Any
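A minimal sketch of leaf-wise weighting with `jax.tree_util.tree_map`; the helper name `scale_tree` is hypothetical and stands in for the behavior described above rather than fedjax's implementation.

```python
import jax
import jax.numpy as jnp


def scale_tree(pytree, weight):
  """Hypothetical sketch: multiplies every leaf of `pytree` by `weight`."""
  return jax.tree_util.tree_map(lambda leaf: leaf * weight, pytree)


grads = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
scaled = scale_tree(grads, 0.1)  # same structure, every leaf scaled by 0.1
```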
- fedjax.tree_util.tree_sum(pytrees)
Sums multiple trees together, leaf by leaf.
- Return type:
Any
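Summing several trees can be sketched as a left fold of leaf-wise addition. The helper `sum_trees` below is hypothetical, assumes a non-empty input whose trees all share one structure, and is not fedjax's own code.

```python
import functools

import jax
import jax.numpy as jnp


def sum_trees(pytrees):
  """Hypothetical sketch: leaf-wise sum of a non-empty list of trees."""
  return functools.reduce(
      lambda acc, tree: jax.tree_util.tree_map(jnp.add, acc, tree), pytrees)


trees = [
    {"w": jnp.array([1.0, 2.0])},
    {"w": jnp.array([3.0, 4.0])},
    {"w": jnp.array([5.0, 6.0])},
]
total = sum_trees(trees)  # total["w"] is [9.0, 12.0]
```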
- fedjax.tree_util.tree_add(left, right)
Adds two trees together, leaf by leaf.
- Return type:
Any
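Adding two trees amounts to applying `jnp.add` to corresponding leaves, which requires both trees to have the same structure. The helper `add_trees` is a hypothetical sketch, not fedjax's implementation; the example uses a tuple containing a dict to show that any nesting of containers works.

```python
import jax
import jax.numpy as jnp


def add_trees(left, right):
  """Hypothetical sketch: adds corresponding leaves of two trees."""
  return jax.tree_util.tree_map(jnp.add, left, right)


params = (jnp.array([1.0, 1.0]), {"b": jnp.array(0.0)})
delta = (jnp.array([0.5, -0.5]), {"b": jnp.array(2.0)})
new_params = add_trees(params, delta)  # ([1.5, 0.5], {"b": 2.0})
```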
- fedjax.tree_util.tree_zeros_like(pytree)
Creates a tree of zeros with the same structure as the input.
- Return type:
Any
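A sketch of the same idea using `jnp.zeros_like` on every leaf, which keeps each leaf's shape and dtype. The helper name `zeros_like_tree` is hypothetical; a tree like this is a natural starting value when accumulating sums of trees.

```python
import jax
import jax.numpy as jnp


def zeros_like_tree(pytree):
  """Hypothetical sketch: zeros with the same structure, shapes and dtypes."""
  return jax.tree_util.tree_map(jnp.zeros_like, pytree)


params = {"kernel": jnp.ones((2, 3)), "bias": jnp.ones(3, dtype=jnp.int32)}
zeros = zeros_like_tree(params)  # kernel: float32 (2, 3), bias: int32 (3,)
```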
- fedjax.tree_util.tree_inverse_weight(pytree, weight)
Weights tree leaves by 1 / weight, that is, divides every leaf by weight.
- Return type:
Any
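A minimal sketch of inverse weighting; the helper `inverse_scale_tree` is hypothetical. A typical use is dividing a weighted sum of trees by the total weight to obtain a weighted mean.

```python
import jax
import jax.numpy as jnp


def inverse_scale_tree(pytree, weight):
  """Hypothetical sketch: divides every leaf by `weight` (scales by 1 / weight)."""
  return jax.tree_util.tree_map(lambda leaf: leaf / weight, pytree)


weighted_sum = {"w": jnp.array([12.0, 4.0])}  # e.g. sum of weighted updates
avg = inverse_scale_tree(weighted_sum, 4.0)    # avg["w"] is [3.0, 1.0]
```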
- fedjax.tree_util.tree_size(pytree)
Returns the total size (number of elements) of all tree leaves.
- Return type:
int
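The total size can be sketched as the sum of the element counts of all leaves, for example to count a model's parameters. The helper `total_size` is hypothetical and assumes every leaf is an array with a `.size` attribute.

```python
import jax
import jax.numpy as jnp


def total_size(pytree):
  """Hypothetical sketch: total number of elements over all array leaves."""
  return sum(leaf.size for leaf in jax.tree_util.tree_leaves(pytree))


params = {
    "dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros(3)},
    "scale": jnp.array(2.0),
}
n = total_size(params)  # 6 + 3 + 1 = 10
```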
- fedjax.tree_util.tree_clip_by_global_norm(pytree, max_norm)
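Clips a pytree of arrays using their global norm, the L2 norm of all leaf entries taken together. If the global norm exceeds max_norm, all leaves are scaled down by a common factor so that the global norm is at most max_norm; otherwise the pytree is returned unchanged.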
References
[Pascanu et al., 2012](https://arxiv.org/abs/1211.5063)
- Parameters:
pytree (Any) – A pytree to be potentially clipped.
max_norm (float) – The maximum global norm for a pytree.
- Return type:
Any
- Returns:
A potentially clipped pytree.
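Global-norm clipping can be sketched as follows: compute the L2 norm over every entry of every leaf, then rescale all leaves by max_norm / global_norm only when that norm exceeds max_norm. The helper `clip_tree_by_global_norm` is hypothetical, not fedjax's implementation, which may differ in numerical details.

```python
import jax
import jax.numpy as jnp


def clip_tree_by_global_norm(pytree, max_norm):
  """Hypothetical sketch of global-norm clipping (Pascanu et al., 2012)."""
  leaves = jax.tree_util.tree_leaves(pytree)
  # Global norm: L2 norm of all leaf entries taken together.
  global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
  # Rescale only when the bound is exceeded; otherwise keep the scale at 1.
  scale = jnp.where(global_norm > max_norm, max_norm / global_norm, 1.0)
  return jax.tree_util.tree_map(lambda leaf: leaf * scale, pytree)


grads = {"a": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}  # global norm 5
clipped = clip_tree_by_global_norm(grads, 1.0)  # a -> [0.6, 0.0], b -> [0.8]
unchanged = clip_tree_by_global_norm(grads, 10.0)  # already within the bound
```

Because every leaf is multiplied by the same factor, clipping preserves the direction of the combined update and only shrinks its magnitude.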