# fedjax.metrics

A small library for working with evaluation metrics such as accuracy.

## Stats

• fedjax.metrics.Stat: Stat keeps some statistic, along with operations over them.

• fedjax.metrics.MeanStat: Statistic for weighted mean calculation.

• fedjax.metrics.SumStat: Statistic for summing values.

## Metrics

• fedjax.metrics.Metric: Metric is the conceptual metric (like accuracy).

• fedjax.metrics.CrossEntropyLoss: Metric for cross entropy loss.

• fedjax.metrics.Accuracy: Metric for accuracy.

• fedjax.metrics.TopKAccuracy: Metric for top k accuracy.

• fedjax.metrics.SequenceTokenCrossEntropyLoss: Metric for token cross entropy loss for a sequence example.

• fedjax.metrics.SequenceCrossEntropyLoss: Metric for total cross entropy loss for a sequence example.

• fedjax.metrics.SequenceTokenAccuracy: Metric for token accuracy for a sequence example.

• fedjax.metrics.SequenceTokenTopKAccuracy: Metric for token top k accuracy for a sequence example.

• fedjax.metrics.SequenceTokenCount: Metric for count of non masked tokens for a sequence example.

• fedjax.metrics.SequenceCount: Metric for count of non masked sequences.

• fedjax.metrics.SequenceTruncationRate: Metric for truncation rate for a sequence example.

• fedjax.metrics.SequenceTokenOOVRate: Metric for out-of-vocabulary (OOV) rate for a sequence example.

• fedjax.metrics.SequenceLength: Metric for length for a sequence example.

• fedjax.metrics.PerDomainMetric: Turns a base metric into one that groups results by domain.

• fedjax.metrics.ConfusionMatrix: Metric for making a Confusion Matrix.

## Miscellaneous

• fedjax.metrics.unreduced_cross_entropy_loss: Returns unreduced cross entropy loss.

• fedjax.metrics.evaluate_batch: Evaluates a batch using a metric.

## Quick Overview

To evaluate model predictions, use a Metric object such as Accuracy. In most scenarios we recommend fedjax.core.models.evaluate_model(), which runs model prediction and evaluation on batches of N examples at a time for greater computational efficiency:

import fedjax
import numpy as np

# Mock out Model.
model = fedjax.Model(
    init=lambda _: None,  # Unused.
    # `lambda _, _, _` is a syntax error (duplicate argument names),
    # so the unused callables accept any arguments instead.
    apply_for_train=lambda *_: None,  # Unused.
    apply_for_eval=lambda _, batch: batch.get('pred'),
    train_loss=lambda *_: None,  # Unused.
    eval_metrics={'accuracy': fedjax.metrics.Accuracy()})
params = None  # Unused.
batches = [{'y': np.array([1, 0]),
            'pred': np.array([[1.2, 0.4], [2.3, 0.1]])},
           {'y': np.array([1, 1]),
            'pred': np.array([[0.3, 3.2], [2.1, 4.3]])}]
results = fedjax.evaluate_model(model, params, batches)
print(results)
# {'accuracy': 0.75}


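A Metric object has 2 methods:

• zero() returns a Stat such that merging with it is an identity operation.

• evaluate_example() evaluates a single example, given the example and the model prediction, and returns a Stat.

Most Metric subclasses follow this convention for convenience: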

• example is a dict-like object from str to jnp.ndarray.

• prediction is either a single jnp.ndarray, or a dict-like object from str to jnp.ndarray.

Conceptually, we can also use a simple for loop to evaluate a collection of examples and model predictions:

import jax.numpy as jnp
from fedjax.metrics import Accuracy

# By default, the Accuracy metric treats example['y'] as the true label,
# and prediction as a single jnp.ndarray of class scores.
metric = Accuracy()
stat = metric.zero()
# We are iterating over individual examples, not batches.
for example, prediction in [({'y': jnp.array(1)}, jnp.array([0., 1.])),
                            ({'y': jnp.array(0)}, jnp.array([1., 0.])),
                            ({'y': jnp.array(1)}, jnp.array([1., 0.])),
                            ({'y': jnp.array(0)}, jnp.array([2., 0.]))]:
    stat = stat.merge(metric.evaluate_example(example, prediction))
print(stat.result())
# 0.75


In practice, for greater computational efficiency, we run model prediction not on a single example, but on a batch of N examples at a time. fedjax.core.models.evaluate_model() provides a simple way to do so. Under the hood, it calls evaluate_batch():

import jax.numpy as jnp
from fedjax.metrics import Accuracy, evaluate_batch

metric = Accuracy()
stat = metric.zero()
# We are iterating over batches.
for batch_example, batch_prediction in [
    ({'y': jnp.array([1, 0])}, jnp.array([[0., 1.], [1., 0.]])),
    ({'y': jnp.array([1, 0])}, jnp.array([[1., 0.], [2., 0.]]))]:
    stat = stat.merge(evaluate_batch(metric, batch_example, batch_prediction))
print(stat.result())
# 0.75


## Under the hood

For most users, it is sufficient to know how to use existing Metric subclasses such as Accuracy with fedjax.core.models.evaluate_model(). This section is intended for those who would like to write new metrics.

### From algebraic structures to Metric and Stat

There are 2 abstractions in this library, Metric and Stat. Before going into the details of these classes, let’s first consider a few abstract properties related to evaluation metrics, using accuracy as an example.

When evaluating accuracy on a dataset, we wish to know the proportion of examples that are correctly predicted by the model. Because a dataset might be too large to fit into memory, we need to divide the work by partitioning the dataset into smaller subsets, evaluate each separately, and finally somehow combine the results. Assuming the subsets can be of different sizes, although the accuracy value is a single number, we cannot just average the accuracy values from each partition to arrive at the overall accuracy. Instead, we need 2 numbers from each subset:

• The number of examples in this subset,

• How many of them are correctly predicted.

We call these two numbers from each subset a statistic. Writing $$a$$ for the number of correct predictions and $$b$$ for the number of examples, the domain (the set of possible values) of the statistic in the case of accuracy is

$\{(0, 0)\} \cup \{(a, b) \mid a \geq 0, b > 0\}$

With the numbers of examples and correct predictions from 2 disjoint subsets, we add the numbers up to get the number of examples and correct predictions for the union of the 2 subsets. We call this operation from 2 statistics into 1 a $$merge$$ operation.
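The point about not averaging per-subset accuracies can be checked with a small sketch in plain Python. This is not the fedjax API: the `(num_correct, num_examples)` tuple layout and the helper names `merge` and `result` are assumptions made purely for illustration.

```python
# Each statistic is a (num_correct, num_examples) pair; names are illustrative.
def merge(s, t):
    """Combines the statistics of two disjoint subsets."""
    return (s[0] + t[0], s[1] + t[1])


def result(s):
    """Turns a statistic into an accuracy value; the empty statistic gives 0."""
    return s[0] / s[1] if s[1] > 0 else 0.0


subset_a = (1, 1)  # 1 example, 1 correct  -> accuracy 1.0
subset_b = (1, 3)  # 3 examples, 1 correct -> accuracy 1/3

naive = (result(subset_a) + result(subset_b)) / 2
exact = result(merge(subset_a, subset_b))
print(naive)  # about 0.667, which overweights the small subset
print(exact)  # 0.5, i.e. 2 correct out of 4 examples
```

Merging the statistics first and computing the ratio last is what keeps the result independent of how the dataset was partitioned.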

Let $$f(S)$$ be the function that gives us the statistic from a subset of examples. It is easy to see that for two disjoint subsets $$A$$ and $$B$$, $$merge(f(A), f(B))$$ should be equal to $$f(A ∪ B)$$. If no such $$merge$$ exists, we cannot evaluate the dataset by partitioning the work. This requirement alone implies that the domain of a statistic, together with the $$merge$$ operation, forms a specific algebraic structure (a commutative monoid):

• $$I := f(empty set)$$ is the one and only identity element w.r.t. $$merge$$ (i.e. $$merge(I, x) == merge(x, I) == x$$).

• $$merge()$$ is commutative and associative.
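As a rough sanity check, the monoid properties above can be tested exhaustively on a handful of sample statistics. The sketch below uses plain tuples rather than fedjax classes; `IDENTITY`, `merge`, and `samples` are illustrative names, not part of the library.

```python
import itertools

IDENTITY = (0, 0)  # Plays the role of f(empty set).


def merge(s, t):
    """Adds two (accum, weight) statistics componentwise."""
    return (s[0] + t[0], s[1] + t[1])


samples = [IDENTITY, (0, 1), (1, 1), (2, 5)]

# Identity: merging with IDENTITY on either side changes nothing.
for x in samples:
    assert merge(IDENTITY, x) == x == merge(x, IDENTITY)

# Commutativity: the order of the two arguments does not matter.
for x, y in itertools.product(samples, repeat=2):
    assert merge(x, y) == merge(y, x)

# Associativity: the grouping of merges does not matter.
for x, y, z in itertools.product(samples, repeat=3):
    assert merge(merge(x, y), z) == merge(x, merge(y, z))

print('identity, commutativity and associativity hold on the samples')
```

A check like this does not prove the properties in general, but it is a cheap test to run against any new merge implementation.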

Further, we can see $$f(S)$$ can be defined just knowing two types of values:

• $$f(empty set)$$, i.e. $$I$$ ;

• $$f({x})$$ for any single example $$x$$ .

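For any other subset $$S$$, we can derive the value of $$f(S)$$ using these values and $$merge$$. Metric is simply the $$f(S)$$ function above, defined in 2 corresponding parts:

• Metric.zero() is $$f(empty set)$$.

• Metric.evaluate_example() is $$f({x})$$ for a single example $$x$$.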

On the other hand, Stat stores a single statistic, a merge() method for combining two, and a result() method for producing the final metric value.
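To make the division of labor concrete, here is a minimal pure-Python sketch of how $$f(S)$$ can be built from the identity, the per-example statistic, and merge. `MeanStatSketch`, `AccuracySketch`, and `evaluate` are hypothetical stand-ins written for this illustration; they mimic the roles described above but are not the fedjax classes.

```python
import dataclasses
import functools


@dataclasses.dataclass(frozen=True)
class MeanStatSketch:
    accum: float   # Weighted sum, here the number of correct predictions.
    weight: float  # Sum of weights, here the number of examples.

    def merge(self, other):
        return MeanStatSketch(self.accum + other.accum,
                              self.weight + other.weight)

    def result(self):
        return self.accum / self.weight if self.weight > 0 else 0.0


class AccuracySketch:
    def zero(self):
        return MeanStatSketch(0.0, 0.0)  # f(empty set)

    def evaluate_example(self, example, prediction):
        # f({x}): prediction is a list of class scores; take the arg max.
        predicted = max(range(len(prediction)), key=prediction.__getitem__)
        return MeanStatSketch(float(predicted == example['y']), 1.0)


def evaluate(metric, pairs):
    """Computes f(S) by folding merge over the per-example statistics."""
    stats = (metric.evaluate_example(e, p) for e, p in pairs)
    return functools.reduce(lambda s, t: s.merge(t), stats, metric.zero())


pairs = [({'y': 1}, [0., 1.]), ({'y': 0}, [1., 0.]),
         ({'y': 1}, [1., 0.]), ({'y': 0}, [2., 0.])]
stat = evaluate(AccuracySketch(), pairs)
print(stat.result())  # 0.75
```

Because merge is associative and commutative, the fold may visit the examples in any order or grouping and still produce the same statistic.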

To implement Accuracy as a subclass of Metric, we first need to know what Stat to use. In this case, the statistic domain and merge are implemented by a MeanStat. A MeanStat holds two values:

• accum is the weighted sum of values, i.e. the number of correct predictions in the case of accuracy.

• weight is the sum of weights, i.e. the number of examples in the case of accuracy.

merge() adds up the respective accum and weight from two MeanStat objects.

Sometimes, a new Stat subclass is necessary. In that case, it is very important to make sure the implementation has a clear definition of the domain, and the merge() operation adheres to the properties regarding identity element, commutativity, and associativity (e.g. if we unknowingly allow pairs of $$(x, 0)$$ for $$x != 0$$ into the domain of a MeanStat , $$merge((x, 0), (a, b))$$ will produce a statistic that leads to incorrect final metric values, i.e. $$(a+x)/b$$ , instead of $$a/b$$ ).
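The failure mode described above is easy to reproduce. The following sketch uses plain `(accum, weight)` tuples and illustrative `merge` and `result` helpers (assumptions for this example, not fedjax code) to show what happens when a pair with zero weight but nonzero accum slips into the domain.

```python
def merge(s, t):
    """Adds two (accum, weight) statistics componentwise."""
    return (s[0] + t[0], s[1] + t[1])


def result(s):
    """Weighted mean, with the empty statistic mapped to 0."""
    return s[0] / s[1] if s[1] > 0 else 0.0


good = (3.0, 4.0)  # 3 correct out of 4 examples -> 0.75
bad = (2.0, 0.0)   # Outside the domain: nonzero accum with zero weight.

print(result(merge(bad, good)))         # 1.25, not a valid accuracy
print(result(merge((0.0, 0.0), good)))  # 0.75, as expected with the identity
```

The bad pair acts like an identity for the weight but not for the accum, which is why constructing statistics through a sanitizing factory method is recommended.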

### Batching Stats

In most cases, the final value of an evaluation is simply a scalar, and the corresponding statistic is also a tuple of a few scalar values. However, for the same reason that jax.vmap() is much more efficient than a for loop, it is much more efficient to store multiple Stat values as a Stat of arrays instead of a list of Stat objects. Thus, instead of a list [MeanStat(1, 2), MeanStat(3, 4), MeanStat(5, 6)], we can batch the 3 statistics as a single MeanStat(jnp.array([1, 3, 5]), jnp.array([2, 4, 6])). A Stat object holding a single statistic is a “rank 0” Stat. A Stat object holding a vector of statistics is a “rank 1” Stat. Similarly, a Stat object may also hold a matrix, a 3D array, etc., of statistics. These are higher rank Stats. In most cases, the rank 1 implementation of merge() and Stat.result() automatically generalizes to higher ranks as elementwise operations.

In the end, we want just 1 final metric value instead of a length 3 vector, or a 2x2 matrix, of metric values. To finally go back to a single statistic (rank 0), we need to merge() the statistics stored in these arrays. Each Stat subclass provides a reduce() method to do just that. The combination of jax.vmap() over Metric.evaluate_example() and Stat.reduce() is how we get an efficient evaluate_batch() function (of course, the real evaluate_batch() is jax.jit()’d so that the same jax.vmap() transformation etc. does not need to happen over and over):

import jax
import jax.numpy as jnp
from fedjax.metrics import Accuracy

metric = Accuracy()
stat = metric.zero()
# We are iterating over batches.
for batch_example, batch_prediction in [
    ({'y': jnp.array([1, 0])}, jnp.array([[0., 1.], [1., 0.]])),
    ({'y': jnp.array([1, 0])}, jnp.array([[1., 0.], [2., 0.]]))]:
    # Get a batch of statistics as a single Stat object.
    batch_stat = jax.vmap(metric.evaluate_example)(batch_example,
                                                   batch_prediction)
    # Merge the reduced single statistic onto the accumulator.
    stat = stat.merge(batch_stat.reduce())
print(stat.result())
# 0.75
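Conceptually, reducing a rank 1 statistic should give the same answer as merging its entries one by one. The sketch below illustrates that equivalence with plain Python lists standing in for arrays; `merge` and `reduce_stat` are illustrative helpers, not the fedjax methods, and the numbers reuse the three MeanStat values from the batching discussion.

```python
import functools


def merge(s, t):
    """Merges two rank 0 statistics, each an (accum, weight) pair."""
    return (s[0] + t[0], s[1] + t[1])


def reduce_stat(batched):
    """Collapses a rank 1 statistic (two parallel lists) into a rank 0 one."""
    accum, weight = batched
    return (sum(accum), sum(weight))


# Three rank 0 statistics, and the same values stored as one rank 1 statistic.
separate = [(1, 2), (3, 4), (5, 6)]
batched = ([1, 3, 5], [2, 4, 6])

folded = functools.reduce(merge, separate, (0, 0))
print(folded)                # (9, 12)
print(reduce_stat(batched))  # (9, 12)
```

With real arrays, the batched form replaces a Python-level loop of merges with a single sum along an axis.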


Being able to batch Stats also allows us to do other interesting things. For example, PerDomainMetric turns any base metric into one that groups results by domain.

### Creating a new Metric

Most likely, a new metric will just return a MeanStat or a SumStat. If that’s the case, simply follow the guidelines in Metric’s class docstring.

If a new Stat is necessary, follow the guidelines in Stat’s docstring.
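As an illustration of what a new statistic involves, here is a pure-Python sketch of a hypothetical "maximum so far" statistic. `MaxStatSketch` does not exist in fedjax; it is only meant to show a clearly stated domain (any non-NaN float), a factory method that enforces it, an identity element, and a merge that is commutative and associative. A real Stat would additionally need to be a PyTree and provide reduce().

```python
import dataclasses
import math


@dataclasses.dataclass(frozen=True)
class MaxStatSketch:
    """Hypothetical statistic tracking the largest value seen so far."""
    value: float

    @classmethod
    def new(cls, value):
        # Factory method: reject values outside the domain (NaN here).
        if math.isnan(value):
            raise ValueError('value must not be NaN')
        return cls(float(value))

    @classmethod
    def identity(cls):
        return cls(-math.inf)  # max(-inf, x) == x for every x.

    def merge(self, other):
        return MaxStatSketch(max(self.value, other.value))

    def result(self):
        return self.value


stats = [MaxStatSketch.new(v) for v in (2.0, 7.5, -1.0)]
merged = MaxStatSketch.identity()
for s in stats:
    merged = merged.merge(s)
print(merged.result())  # 7.5
```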

class fedjax.metrics.Stat

Stat keeps some statistic, along with operations over them.

Most users will only need to interact with a Stat object via result().

For those who need to create new metrics, please first read the Under the hood section of the module docstring.

Most Stat domains (the set of possible statistic values) have constraints, so it is usually good practice to offer and use factory methods to construct new Stat objects instead of directly assigning the fields.

To work with various jax constructs, a concrete Stat should be a PyTree. This is easily achieved with fedjax.dataclass.

A Stat may hold either a single statistic (a rank 0 Stat), or an array of statistics (a higher rank Stat):

• result() and merge() only need to work on a rank 0 Stat.

• reduce() only needs to work on a higher rank Stat.

abstract merge(other)

Merges two Stat objects into a new Stat with merged statistics.

Parameters:

other (Stat) – Another Stat object of the same type.

Return type:

Stat

Returns:

A new Stat object of the same type with merged statistics.

abstract reduce(axis=0)

Reduces a higher rank statistic along a given axis.

See the class docstring for details.

Parameters:

axis (Optional[int]) – An integer axis index, or None.

Return type:

Stat

Returns:

A new Stat object of the same type.

abstract result()

Calculates the metric value from the statistic value.

For example, MeanStat.result() calculates a weighted average.

Return type:

Array

Returns:

The return value must be a jnp.ndarray.

class fedjax.metrics.MeanStat(accum, weight)

Bases: Stat

Statistic for weighted mean calculation.

Prefer using the MeanStat.new() factory method instead of directly assigning to fields.

Example:

stat_0 = MeanStat.new(accum=1, weight=2)
stat_1 = MeanStat.new(accum=2, weight=3)
merged_stat = stat_0.merge(stat_1)
print(merged_stat)
# MeanStat(accum=3, weight=5) => 0.6

stat = MeanStat.new(jnp.array([1, 2, 4]), jnp.array([1, 1, 0]))
reduced_stat = stat.reduce()
print(reduced_stat)
# MeanStat(accum=3, weight=2) => 1.5

accum

The weighted sum.

Type:

jax.Array

weight

The sum of weights.

Type:

jax.Array

classmethod new(accum, weight)

Creates a sanitized MeanStat.

The domain of a weighted mean statistic is:

$\{(0, 0)\} \cup \{(a, b) \mid a \geq 0, b > 0\}$

new() sanitizes values outside the domain into the identity (zeros).

Parameters:
• accum – A value convertible to jnp.ndarray.

• weight – A value convertible to jnp.ndarray.

Return type:

MeanStat

Returns:

The sanitized MeanStat.
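The sanitization step can be pictured with a small pure-Python sketch. It follows the domain exactly as stated above, which is an assumption about the behavior: the library's actual checks may differ in detail. `sanitize` and `reduce_pairs` are illustrative helpers, not fedjax functions, and the inputs reuse the values from the MeanStat example.

```python
def sanitize(accum, weight):
    """Maps each (accum, weight) pair outside the domain to the identity.

    Assumption: a pair is in the domain when it is (0, 0), or when
    weight > 0 and accum >= 0, as in the domain stated in the text.
    """
    pairs = []
    for a, w in zip(accum, weight):
        in_domain = (a == 0 and w == 0) or (w > 0 and a >= 0)
        pairs.append((a, w) if in_domain else (0, 0))
    return pairs


def reduce_pairs(pairs):
    """Sums the sanitized pairs into a single (accum, weight) statistic."""
    return (sum(a for a, _ in pairs), sum(w for _, w in pairs))


pairs = sanitize([1, 2, 4], [1, 1, 0])
print(pairs)                # [(1, 1), (2, 1), (0, 0)]
print(reduce_pairs(pairs))  # (3, 2), whose mean is 1.5
```

The pair (4, 0) is replaced by the identity, which is why the reduced statistic has accum 3 rather than 7.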

class fedjax.metrics.SumStat(accum)

Bases: Stat

Statistic for summing values.

Example:

stat_0 = SumStat.new(accum=1)
stat_1 = SumStat.new(accum=2)
merged_stat = stat_0.merge(stat_1)
print(merged_stat)
# SumStat(accum=3) => 3

stat = SumStat.new(jnp.array([1, 2, 1]))
reduced_stat = stat.reduce()
print(reduced_stat)
# SumStat(accum=4) => 4

accum

Sum of values.

Type:

jax.Array

classmethod new(accum)

Creates a sanitized SumStat.

Return type:

SumStat

class fedjax.metrics.Metric

Metric is the conceptual metric (like accuracy).

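It defines two methods:

• zero() returns the identity Stat.

• evaluate_example() evaluates a single example and returns its Stat.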

Given a Metric object m, let

• u = m.zero()

• v = m.evaluate_example(...)

We require that

• type(u) == type(v).

• u.merge(v) == v.merge(u) == v.

• Components of u have the same shape as their counterparts in v.

abstract evaluate_example(example, prediction)

Evaluates a single example.

e.g. for accuracy: MeanStat.new(num_correct, num_total)

Parameters:
• example (Mapping[str, Array]) – A single input example (e.g. one sentence for language).

• prediction (Union[Array, Mapping[str, Array]]) – Output for example from fedjax.core.models.Model.apply_for_eval().

Return type:

Stat

Returns:

Stat value.

abstract zero()

Returns a Stat such that merging with it is an identity operation.

e.g. for accuracy: MeanStat.new(0., 0.)

Return type:

Stat

Returns:

Stat identity value.

class fedjax.metrics.CrossEntropyLoss(target_key='y', pred_key=None)

Bases: Metric

Metric for cross entropy loss.

Example:

example = {'y': jnp.array(1)}
prediction = jnp.array([1.2, 0.4])
metric = CrossEntropyLoss()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1.1711007, weight=1) => 1.1711007

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

class fedjax.metrics.Accuracy(target_key='y', pred_key=None)

Bases: Metric

Metric for accuracy.

Example:

example = {'y': jnp.array(2)}
prediction = jnp.array([0, 0, 1])
metric = Accuracy()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1, weight=1) => 1

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

class fedjax.metrics.TopKAccuracy(k, target_key='y', pred_key=None)

Bases: Metric

Metric for top k accuracy.

This metric computes the fraction of examples for which the correct class is among the top k classes predicted.

Example: top 3 accuracy

• Dog => [Dog, Cat, Bird, Mouse, Penguin] ✓

• Cat => [Bird, Mouse, Cat, Penguin, Dog] ✓

• Dog => [Dog, Cat, Bird, Penguin, Mouse] ✓

• Bird => [Bird, Cat, Mouse, Penguin, Dog] ✓

• Cat => [Cat, Bird, Mouse, Dog, Penguin] ✓

• Cat => [Cat, Mouse, Dog, Penguin, Bird] ✓

• Mouse => [Penguin, Cat, Dog, Mouse, Bird] x

• Penguin => [Dog, Mouse, Cat, Penguin, Bird] x

6 correct predictions in top 3 predicted classes / 8 total examples = .75 top 3 accuracy

Top k accuracy, also known as top n accuracy, is a useful metric when it comes to recommendations. One example would be the word recommendations on a virtual keyboard where three suggested words are displayed.

For k=1, we strongly recommend using Accuracy to avoid an unnecessary argsort. k < 1 will return 0, and k >= num_classes will return 1.

If two or more classes have the same prediction, the classes will be considered in order of lowest to highest indices.

Example:

example = {'y': jnp.array(2)}
prediction = jnp.array([0, 0.5, 0.2])
metric = TopKAccuracy(k=2)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1, weight=1) => 1
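The tie-breaking rule and the edge cases for k described above can be sketched in plain Python. The helper `in_top_k` is a hypothetical illustration of the stated behavior, not the library's implementation.

```python
def in_top_k(scores, target, k):
    """Returns True if target is among the k highest-scoring classes.

    Ties are broken in favor of the lower class index.
    """
    ranked = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    return target in ranked[:max(k, 0)]


print(in_top_k([0, 0.5, 0.2], target=2, k=2))    # True: ranking is [1, 2, 0]
print(in_top_k([0.5, 0.5, 0.5], target=2, k=2))  # False: ties keep classes 0 and 1
print(in_top_k([0, 0.5, 0.2], target=0, k=0))    # False: k < 1 never matches
print(in_top_k([0, 0.5, 0.2], target=0, k=3))    # True: k >= num_classes always matches
```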

k

Number of top elements to look at for computing accuracy.

Type:

int

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

class fedjax.metrics.SequenceTokenCrossEntropyLoss

Bases: Metric

Metric for token cross entropy loss for a sequence example.

Example:

example = {'y': jnp.array([1, 0, 1])}
prediction = jnp.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
metric = SequenceTokenCrossEntropyLoss()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1.2246635, weight=2) => 0.61233175

per_position_metric = SequenceTokenCrossEntropyLoss(per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[1.1711007, 0., 0.05356275], weight=[1., 0., 1.]) => [1.1711007, 0., 0.05356275]

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

per_position

Whether to keep output statistic per position or sum across positions for the entire sequence.

Type:

bool
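The numbers in the SequenceTokenCrossEntropyLoss example can be reproduced by hand. The sketch below is plain Python rather than fedjax code, and it assumes that target value 0 is the one being ignored, which is what the weight of 2 in the example suggests. `token_cross_entropy` and `sequence_token_loss` are illustrative names.

```python
import math


def token_cross_entropy(logits, target):
    """Cross entropy of one position from unnormalized scores."""
    log_norm = math.log(sum(math.exp(x) for x in logits))
    return log_norm - logits[target]


def sequence_token_loss(targets, predictions, masked=(0,)):
    """Returns (accum, weight): summed loss and count of non-masked positions."""
    accum, weight = 0.0, 0.0
    for target, logits in zip(targets, predictions):
        if target in masked:
            continue  # Skip ignored positions, e.g. padding.
        accum += token_cross_entropy(logits, target)
        weight += 1.0
    return accum, weight


accum, weight = sequence_token_loss(
    [1, 0, 1], [[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
print(accum, weight)   # approximately 1.2246635 and 2.0
print(accum / weight)  # approximately 0.61233175
```

The total loss for the sequence is the same accum with a weight of 1, which is the distinction between the token-level and sequence-level loss metrics.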

class fedjax.metrics.SequenceCrossEntropyLoss

Bases: Metric

Metric for total cross entropy loss for a sequence example.

Example:

example = {'y': jnp.array([1, 0, 1])}
prediction = jnp.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
metric = SequenceCrossEntropyLoss()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1.2246635, weight=1) => 1.2246635

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.SequenceTokenAccuracy

Bases: Metric

Metric for token accuracy for a sequence example.

Example:

example = {'y': jnp.array([1, 2, 2, 1, 3, 0])}
# prediction = [1, 0, 2, 1, 3, 0].
prediction = jnp.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0],
                        [0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]])
# Discount predictions of class 3 (-jnp.inf replaces the removed jnp.NINF).
logits_mask = (0., 0., 0., -jnp.inf)
metric = SequenceTokenAccuracy(logits_mask=logits_mask)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=3, weight=5) => 0.6

per_position_metric = SequenceTokenAccuracy(logits_mask=logits_mask, per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[1., 0., 1., 1., 0., 0.], weight=[1., 1., 1., 1., 1., 0.]) => [1., 0., 1., 1., 0., 0.]

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

logits_mask

Mask of shape [num_classes] to be applied for preds. This is typically used to discount predictions for out-of-vocabulary tokens.

Type:

Optional[Tuple[float, …]]

per_position

Whether to keep output statistic per position or sum across positions for the entire sequence.

Type:

bool

class fedjax.metrics.SequenceTokenTopKAccuracy

Bases: Metric

Metric for token top k accuracy for a sequence example.

For more information on the top k accuracy metric, refer to the TopKAccuracy docstring.

Example:

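example = {'y': jnp.array([1, 2, 2, 1, 3, 0])}
prediction = jnp.array([[0, 1, 0.5, 0], [1, 0.5, 0, 0], [0.8, 0, 0.7, 0],
                        [0.5, 1, 0, 0], [0, 0.5, 0, 1], [0.5, 0, 0.9, 0]])
# Discount predictions of class 3 (-jnp.inf replaces the removed jnp.NINF).
logits_mask = (0., 0., 0., -jnp.inf)
# k=2 is the value consistent with the outputs shown below.
metric = SequenceTokenTopKAccuracy(k=2, logits_mask=logits_mask)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=3, weight=5) => 0.6

per_position_metric = SequenceTokenTopKAccuracy(k=2, logits_mask=logits_mask, per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[1., 0., 1., 1., 0., 0.], weight=[1., 1., 1., 1., 1., 0.]) => [1., 0., 1., 1., 0., 0.]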

k

Number of top elements to look at for computing accuracy.

Type:

int

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

logits_mask

Mask of shape [num_classes] to be applied for preds. This is typically used to discount predictions for out-of-vocabulary tokens.

Type:

Optional[Tuple[float, …]]

per_position

Whether to keep output statistic per position or sum across positions for the entire sequence.

Type:

bool

class fedjax.metrics.SequenceTokenCount

Bases: Metric

Metric for count of non masked tokens for a sequence example.

Example:

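example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])}
prediction = jnp.array([])  # Unused.
# Ignoring target values 0 and 2 leaves the 3 tokens counted below.
metric = SequenceTokenCount(masked_target_values=(0, 2))
print(metric.evaluate_example(example, prediction))
# SumStat(accum=3) => 3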

target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.SequenceCount

Bases: Metric

Metric for count of non masked sequences.

Example:

example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])}
empty_example = {'y': jnp.array([0, 0, 0, 0, 0, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceCount(masked_target_values=(0, 2))
print(metric.evaluate_example(example, prediction))
# SumStat(accum=1)
print(metric.evaluate_example(empty_example, prediction))
# SumStat(accum=0)

target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.SequenceTruncationRate

Bases: Metric

Metric for truncation rate for a sequence example.

Example:

example = {'y': jnp.array([1, 2, 2, 3, 3, 3, 4])}
truncated_example = {'y': jnp.array([1, 2, 2, 3, 3, 3, 3])}
prediction = jnp.array([])  # Unused.
metric = SequenceTruncationRate(eos_target_value=4)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=0, weight=1) => 0
print(metric.evaluate_example(truncated_example, prediction))
# MeanStat(accum=1, weight=1) => 1

eos_target_value

Target value denoting end of sequence. Truncated sequences will not have this value.

Type:

int

target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.SequenceTokenOOVRate

Bases: Metric

Metric for out-of-vocabulary (OOV) rate for a sequence example.

Example:

example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceTokenOOVRate(oov_target_values=(2,))
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=2, weight=5) => 0.4

per_position_metric = SequenceTokenOOVRate(oov_target_values=(2,), per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[0., 1., 1., 0., 0., 0., 0.], weight=[1., 1., 1., 1., 1., 0., 0.]) => [0. 1. 1. 0. 0. 0. 0.]

oov_target_values

Target values denoting out-of-vocabulary values.

Type:

Tuple[int, …]

target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

per_position

Whether to keep output statistic per position or sum across positions for the entire sequence.

Type:

bool
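The OOV rate in the example above is a weighted mean like the others: the accum counts out-of-vocabulary tokens and the weight counts tokens that are not ignored. A pure-Python sketch, assuming target value 0 is the ignored padding value as the example's weight of 5 suggests (`oov_rate_stat` is an illustrative helper, not a fedjax function):

```python
def oov_rate_stat(targets, oov_values, masked=(0,)):
    """Returns (accum, weight): OOV token count and non-masked token count."""
    kept = [t for t in targets if t not in masked]
    accum = sum(1 for t in kept if t in oov_values)
    return accum, len(kept)


accum, weight = oov_rate_stat([1, 2, 2, 3, 4, 0, 0], oov_values=(2,))
print(accum, weight, accum / weight)  # 2 5 0.4
```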

class fedjax.metrics.SequenceLength

Bases: Metric

Metric for length for a sequence example.

Example:

example = {'y': jnp.array([1, 2, 3, 4, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceLength()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=4, weight=1) => 4

target_key

Key name in example for target.

Type:

str

masked_target_values

Target values that should be ignored in computation. This is typically used to ignore padding values in computation.

Type:

Tuple[int, …]

class fedjax.metrics.PerDomainMetric(base, num_domains, domain_id_key='domain_id')

Bases: Metric

Turns a base metric into one that groups results by domain.

This is useful in algorithms such as AgnosticFedAvg.

example is expected to contain a feature named domain_id_key, which stores the integer domain id in [0, num_domains). PerDomainMetric accumulates base’s Stat within each domain. If the base Metric returns a Stat whose result is of shape X, then the Stat returned by PerDomainMetric will produce a result of shape (num_domains,) + X. See Batching Stats for the higher rank Stat mechanism enabling this.

Example:

per_domain_accuracy = PerDomainMetric(metrics.Accuracy(), num_domains=3)
batch_example = {
    'domain_id': jnp.array([0, 0, 1, 2]),
    'y': jnp.array([0, 1, 0, 1])
}
batch_prediction = jnp.array([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])
print(
    evaluate_batch(per_domain_accuracy, batch_example,
                   batch_prediction).result())
# [0.5 0.  1. ]
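What the per-domain grouping computes can be sketched without fedjax: keep one (accum, weight) pair per domain and route each example to its own slot. The function `per_domain_accuracy` below is a hypothetical plain-Python stand-in that reproduces the numbers from the example.

```python
def per_domain_accuracy(domain_ids, targets, predictions, num_domains):
    """Returns one accuracy value per domain id in [0, num_domains)."""
    accum = [0.0] * num_domains   # Correct predictions per domain.
    weight = [0.0] * num_domains  # Examples per domain.
    for d, y, scores in zip(domain_ids, targets, predictions):
        predicted = max(range(len(scores)), key=scores.__getitem__)
        accum[d] += float(predicted == y)
        weight[d] += 1.0
    return [a / w if w > 0 else 0.0 for a, w in zip(accum, weight)]


print(per_domain_accuracy(
    [0, 0, 1, 2], [0, 1, 0, 1],
    [[0., 1.], [2., 3.], [4., 5.], [6., 7.]], num_domains=3))
# [0.5, 0.0, 1.0]
```

The two lists play the role of a rank 1 statistic with one entry per domain.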

class fedjax.metrics.ConfusionMatrix(num_classes, target_key='y', pred_key=None)

Bases: Metric

Metric for making a Confusion Matrix.

A confusion matrix is an n x n matrix often used to describe the performance of a classification model on a set of test data for which the true values are known. The model’s predictions are represented through the columns, and the known data values through the rows. This allows one to see where the model is doing well and where there is room for improvement. For each row in the confusion matrix, large counts outside the main diagonal mean the model often fails to predict that row’s class when it should.

Theoretical Example:

            Predicted P     Predicted N

Actual P       TP               FN

Actual N       FP               TN

This is for a binary classification model, but the same concept applies to any model with n outputs. Notice that the TPs and TNs will always lie in the main diagonal of the matrix.


Example:

example = {'y': jnp.array(2)}
prediction = jnp.array([0., 1., 0.])
metric = ConfusionMatrix(num_classes=3)
print(metric.evaluate_example(example, prediction))
# SumStat(accum=DeviceArray([[0., 0., 0.],
#                            [0., 0., 0.],
#                            [0., 1., 0.]], dtype=float32)) => [[0. 0. 0.]
#                                                               [0. 0. 0.]
#                                                               [0. 1. 0.]]
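Summing the per-example matrices over a dataset gives the full confusion matrix. The sketch below shows that accumulation in plain Python with rows for true classes and columns for predicted classes, as described above; `confusion_matrix` and the four sample examples are illustrative and not taken from the library.

```python
def confusion_matrix(targets, predictions, num_classes):
    """Rows index the true class, columns the predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for y, scores in zip(targets, predictions):
        predicted = max(range(len(scores)), key=scores.__getitem__)
        matrix[y][predicted] += 1
    return matrix


matrix = confusion_matrix(
    targets=[2, 1, 0, 2],
    predictions=[[0., 1., 0.], [0., 1., 0.], [1., 0., 0.], [0., 0., 1.]],
    num_classes=3)
for row in matrix:
    print(row)
# [1, 0, 0]
# [0, 1, 0]
# [0, 1, 1]
```

The off-diagonal 1 in the last row records the single example of class 2 that was predicted as class 1.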

target_key

Key name in example for target.

Type:

str

pred_key

Key name in prediction for unnormalized model output pred.

Type:

Optional[str]

num_classes

Number of output classes of the model. Used to generate a matrix of shape [num_classes, num_classes].

Type:

int

fedjax.metrics.unreduced_cross_entropy_loss(targets, preds, is_sparse_targets=True)

Returns unreduced cross entropy loss.

Return type:

Array

fedjax.metrics.evaluate_batch

Evaluates a batch using a metric.

Return type:

Stat