fedjax.metrics
A small library for working with evaluation metrics such as accuracy.
Stats
Stat | Stat keeps some statistic, along with operations over them.
MeanStat | Statistic for weighted mean calculation.
SumStat | Statistic for summing values.
Metrics
Metric | Metric is the conceptual metric (like accuracy).
CrossEntropyLoss | Metric for cross entropy loss.
Accuracy | Metric for accuracy.
TopKAccuracy | Metric for top k accuracy.
SequenceTokenCrossEntropyLoss | Metric for token cross entropy loss for a sequence example.
SequenceCrossEntropyLoss | Metric for total cross entropy loss for a sequence example.
SequenceTokenAccuracy | Metric for token accuracy for a sequence example.
SequenceTokenTopKAccuracy | Metric for token top k accuracy for a sequence example.
SequenceTokenCount | Metric for count of non masked tokens for a sequence example.
SequenceCount | Metric for count of non masked sequences.
SequenceTruncationRate | Metric for truncation rate for a sequence example.
SequenceTokenOOVRate | Metric for out-of-vocabulary (OOV) rate for a sequence example.
SequenceLength | Metric for length for a sequence example.
PerDomainMetric | Turns a base metric into one that groups results by domain.
ConfusionMatrix | Metric for making a Confusion Matrix.
Miscellaneous
unreduced_cross_entropy_loss | Returns unreduced cross entropy loss.
evaluate_batch | Evaluates a batch using a metric.
Quick Overview
To evaluate model predictions, use a Metric object such as Accuracy.
We recommend fedjax.core.models.evaluate_model() in most scenarios, which runs model prediction and evaluation on batches of N examples at a time for greater computational efficiency:
import fedjax
import numpy as np
from fedjax import metrics
# Mock out Model.
model = fedjax.Model(
    init=lambda _: None,  # Unused.
    # `lambda _, _, _` is a syntax error (duplicate argument names), so
    # the unused callables accept any positional arguments instead.
    apply_for_train=lambda *_: None,  # Unused.
    apply_for_eval=lambda _, batch: batch.get('pred'),
    train_loss=lambda *_: None,  # Unused.
    eval_metrics={'accuracy': metrics.Accuracy()})
params = None  # Unused.
batches = [{'y': np.array([1, 0]),
            'pred': np.array([[1.2, 0.4], [2.3, 0.1]])},
           {'y': np.array([1, 1]),
            'pred': np.array([[0.3, 3.2], [2.1, 4.3]])}]
results = fedjax.evaluate_model(model, params, batches)
print(results)
# {'accuracy': 0.75}
A Metric object has 2 methods:
- zero(): Initial value for accumulating the statistic for this metric.
- evaluate_example(): Returns the statistic from evaluating a single example, given the training example and the model prediction.
Most Metric subclasses follow this convention for convenience:
- example is a dict-like object from str to jnp.ndarray.
- prediction is either a single jnp.ndarray, or a dict-like object from str to jnp.ndarray.
Conceptually, we can also use a simple for loop to evaluate a collection of examples and model predictions:
import jax.numpy as jnp
from fedjax.metrics import Accuracy
# By default, the `Accuracy` metric treats `example['y']` as the true label,
# and `prediction` as a single `jnp.ndarray` of class scores.
metric = Accuracy()
stat = metric.zero()
# We are iterating over individual examples, not batches.
for example, prediction in [({'y': jnp.array(1)}, jnp.array([0., 1.])),
                            ({'y': jnp.array(0)}, jnp.array([1., 0.])),
                            ({'y': jnp.array(1)}, jnp.array([1., 0.])),
                            ({'y': jnp.array(0)}, jnp.array([2., 0.]))]:
    stat = stat.merge(metric.evaluate_example(example, prediction))
print(stat.result())
# 0.75
In practice, for greater computational efficiency, we run model prediction not on a single example, but on a batch of N examples at a time.
fedjax.core.models.evaluate_model() provides a simple way to do so.
Under the hood, it calls evaluate_batch():
import jax.numpy as jnp
from fedjax.metrics import Accuracy, evaluate_batch
metric = Accuracy()
stat = metric.zero()
# We are iterating over batches.
for batch_example, batch_prediction in [
    ({'y': jnp.array([1, 0])}, jnp.array([[0., 1.], [1., 0.]])),
    ({'y': jnp.array([1, 0])}, jnp.array([[1., 0.], [2., 0.]]))]:
    stat = stat.merge(evaluate_batch(metric, batch_example, batch_prediction))
print(stat.result())
# 0.75
Under the hood
For most users, it is sufficient to know how to use existing Metric subclasses such as Accuracy with fedjax.core.models.evaluate_model().
This section is intended for those who would like to write new metrics.
From algebraic structures to Metric and Stat
There are two abstractions in this library, Metric and Stat.
Before going into details of these classes, let’s first consider a few abstract
properties related to evaluation metrics, using accuracy as an example.
When evaluating accuracy on a dataset, we wish to know the proportion of examples that are correctly predicted by the model. Because a dataset might be too large to fit into memory, we need to divide the work by partitioning the dataset into smaller subsets, evaluate each separately, and finally somehow combine the results. Assuming the subsets can be of different sizes, although the accuracy value is a single number, we cannot just average the accuracy values from each partition to arrive at the overall accuracy. Instead, we need 2 numbers from each subset:
The number of examples in this subset,
How many of them are correctly predicted.
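We call these two numbers from each subset a statistic. The domain (the set of possible values) of the statistic in the case of accuracy is
\[\{(0, 0)\} ∪ \{(a, b) | a >= 0, b > 0\}\]
where \(a\) is the number of correct predictions and \(b\) is the number of examples.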
With the numbers of examples and correct predictions from 2 disjoint subsets, we add the numbers up to get the number of examples and correct predictions for the union of the 2 subsets. We call this operation from 2 statistics into 1 a \(merge\) operation.
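To see why per-subset accuracy values cannot simply be averaged, here is a minimal sketch in plain Python. The tuples and the helper names merge and result are illustrative assumptions, not part of the fedjax API.

```python
# Each subset is summarized as
# (number of correct predictions, number of examples).
subset_a = (9, 10)  # 90% accurate on 10 examples
subset_b = (1, 90)  # roughly 1% accurate on 90 examples


def merge(x, y):
    """Adds up the counts from two disjoint subsets."""
    return (x[0] + y[0], x[1] + y[1])


def result(stat):
    """Converts a (correct, total) statistic into an accuracy value."""
    correct, total = stat
    return correct / total if total else 0.0


naive_average = (result(subset_a) + result(subset_b)) / 2
overall = result(merge(subset_a, subset_b))
print(naive_average > 0.45)  # True: averaging overweights the small subset
print(overall)  # 0.1
```

Merging the counts first and dividing once at the end gives the accuracy over the union, regardless of how the dataset was partitioned.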
Let \(f(S)\) be the function that gives us the statistic from a subset of examples. It is easy to see that for two disjoint subsets \(A\) and \(B\), \(merge(f(A), f(B))\) should be equal to \(f(A ∪ B)\). If no such \(merge\) exists, we cannot evaluate the dataset by partitioning the work. This requirement alone implies that the domain of a statistic, together with the \(merge\) operation, forms a specific algebraic structure (a commutative monoid):
- \(I := f(empty set)\) is the one and only identity element w.r.t. \(merge\) (i.e. \(merge(I, x) == merge(x, I) == x\)).
- \(merge()\) is commutative and associative.
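The monoid properties can be spot-checked on a handful of sample statistics. The following is a sketch only: is_commutative_monoid is a hypothetical helper, and checking finitely many samples is a sanity test, not a proof.

```python
import itertools

IDENTITY = (0, 0)  # f(empty set): nothing correct out of nothing seen


def merge(x, y):
    """Component-wise addition of two (correct, total) statistics."""
    return (x[0] + y[0], x[1] + y[1])


def is_commutative_monoid(merge_fn, identity, samples):
    """Checks identity, commutativity, and associativity on the samples."""
    for x in samples:
        if not (merge_fn(identity, x) == x == merge_fn(x, identity)):
            return False
    for x, y in itertools.product(samples, repeat=2):
        if merge_fn(x, y) != merge_fn(y, x):
            return False
    for x, y, z in itertools.product(samples, repeat=3):
        if merge_fn(merge_fn(x, y), z) != merge_fn(x, merge_fn(y, z)):
            return False
    return True


def subtract(x, y):
    """A counterexample: subtraction has no two-sided identity."""
    return (x[0] - y[0], x[1] - y[1])


samples = [IDENTITY, (1, 1), (0, 1), (3, 5)]
print(is_commutative_monoid(merge, IDENTITY, samples))  # True
print(is_commutative_monoid(subtract, IDENTITY, samples))  # False
```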
Further, we can see that \(f(S)\) can be defined just knowing two types of values:
- \(f(empty set)\), i.e. \(I\);
- \(f(\{x\})\) for any single example \(x\).
For any other subset \(S\), we can derive the value of \(f(S)\) using these values and \(merge\). Metric is simply the \(f(S)\) function above, defined in 2 corresponding parts:
- Metric.zero() is \(f(empty set)\).
- Metric.evaluate_example() is \(f(\{x\})\) for a single example.
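The two-part definition can be sketched in plain Python by folding merge over single-example statistics. The function names mirror the Metric methods, but this is a toy illustration on (label, predicted label) pairs, not the fedjax implementation.

```python
import functools


def zero():
    """f(empty set): the identity statistic."""
    return (0, 0)


def evaluate_example(label, predicted_label):
    """f({x}): (1, 1) for a correct prediction, (0, 1) otherwise."""
    return (int(label == predicted_label), 1)


def merge(x, y):
    """Component-wise addition of two (correct, total) statistics."""
    return (x[0] + y[0], x[1] + y[1])


def f(examples):
    """Derives f(S) for any subset by folding merge over single examples."""
    stats = (evaluate_example(label, pred) for label, pred in examples)
    return functools.reduce(merge, stats, zero())


subset_a = [(1, 1), (0, 0)]  # (label, predicted label) pairs
subset_b = [(1, 0), (0, 0)]
print(f(subset_a + subset_b))  # (3, 4)
print(merge(f(subset_a), f(subset_b)) == f(subset_a + subset_b))  # True
```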
On the other hand, Stat stores a single statistic, and provides a merge() method for combining two statistics and a result() method for producing the final metric value.
To implement Accuracy as a subclass of Metric, we first need to know what Stat to use. In this case, the statistic domain and merge are implemented by a MeanStat. A MeanStat holds two values:
- accum is the weighted sum of values, i.e. the number of correct predictions in the case of accuracy.
- weight is the sum of weights, i.e. the number of examples in the case of accuracy.
merge() adds up the respective accum and weight from two MeanStat objects.
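A minimal plain-Python stand-in for this behavior might look as follows. ToyMeanStat is a hypothetical class for illustration; the real MeanStat works on jnp.ndarray values, and returning 0.0 for a zero weight is an assumption of this sketch.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyMeanStat:
    """Toy weighted mean statistic holding scalar accum and weight."""
    accum: float
    weight: float

    def merge(self, other):
        """Adds up the respective accum and weight values."""
        return ToyMeanStat(self.accum + other.accum,
                           self.weight + other.weight)

    def result(self):
        """Returns the weighted mean (0.0 for the identity, by assumption)."""
        return self.accum / self.weight if self.weight else 0.0


stat_0 = ToyMeanStat(accum=1, weight=2)
stat_1 = ToyMeanStat(accum=2, weight=3)
merged_stat = stat_0.merge(stat_1)
print(merged_stat)  # ToyMeanStat(accum=3, weight=5)
print(merged_stat.result())  # 0.6
```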
Sometimes, a new Stat subclass is necessary. In that case, it is very important to make sure the implementation has a clear definition of the domain, and that the merge() operation adheres to the properties regarding identity element, commutativity, and associativity (e.g. if we unknowingly allow pairs of \((x, 0)\) for \(x != 0\) into the domain of a MeanStat, \(merge((x, 0), (a, b))\) will produce a statistic that leads to incorrect final metric values, i.e. \((a+x)/b\), instead of \(a/b\)).
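The effect of an out-of-domain pair can be reproduced with plain tuples. In this sketch, sanitized is a hypothetical factory that only handles the zero-weight case described above; it is not the actual logic of MeanStat.new().

```python
def merge(x, y):
    """Component-wise addition of two (accum, weight) statistics."""
    return (x[0] + y[0], x[1] + y[1])


def result(stat):
    """Returns accum / weight, or 0.0 when the weight is zero."""
    accum, weight = stat
    return accum / weight if weight else 0.0


def sanitized(accum, weight):
    """Hypothetical factory: maps zero-weight pairs to the identity (0, 0)."""
    return (accum, weight) if weight > 0 else (0, 0)


good = (1, 2)  # result is 0.5
bad = (4, 0)  # outside the domain: nonzero accum with zero weight

print(result(merge(bad, good)))  # 2.5, i.e. (a + x) / b instead of a / b
print(result(merge(sanitized(*bad), good)))  # 0.5
```

Routing construction through a factory keeps such pairs from ever entering a merge, which is why the identity and associativity properties continue to hold.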
Batching Stats
In most cases, the final value of an evaluation is simply a scalar, and the corresponding statistic is also a tuple of a few scalar values. However, for the same reason why jax.vmap() is a lot more efficient than a for loop, it is a lot more efficient to store multiple Stat values as a Stat of arrays, instead of a list of Stat objects. Thus instead of a list [MeanStat(1, 2), MeanStat(3, 4), MeanStat(5, 6)] (call these "rank 0" Stats), we can batch the 3 statistics as a single MeanStat(jnp.array([1, 3, 5]), jnp.array([2, 4, 6])).
A Stat object holding a single statistic is a "rank 0" Stat.
A Stat object holding a vector of statistics is a "rank 1" Stat.
Similarly, a Stat object may also hold a matrix, a 3D array, etc., of statistics.
These are higher rank Stats. In most cases, the rank 1 implementation of merge() and result() automatically generalizes to higher ranks as elementwise operations.
In the end, we want just 1 final metric value instead of a length 3 vector, or a 2x2 matrix, of metric values. To finally go back to a single statistic (a rank 0 Stat), we need to merge() the statistics stored in these arrays. Each Stat subclass provides a reduce() method to do just that. The combination of jax.vmap() over Metric.evaluate_example(), and Stat.reduce(), is how we get an efficient evaluate_batch() function (of course, the real evaluate_batch() is jax.jit()'d so that the same jax.vmap() transformation etc. does not need to happen over and over):
import jax
import jax.numpy as jnp
from fedjax.metrics import Accuracy
metric = Accuracy()
stat = metric.zero()
# We are iterating over batches.
for batch_example, batch_prediction in [
    ({'y': jnp.array([1, 0])}, jnp.array([[0., 1.], [1., 0.]])),
    ({'y': jnp.array([1, 0])}, jnp.array([[1., 0.], [2., 0.]]))]:
    # Get a batch of statistics as a single Stat object.
    batch_stat = jax.vmap(metric.evaluate_example)(batch_example,
                                                   batch_prediction)
    # Merge the reduced single statistic onto the accumulator.
    stat = stat.merge(batch_stat.reduce())
print(stat.result())
# 0.75
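The relationship between a rank 1 statistic and reduce() can also be sketched without JAX. ToyBatchedMeanStat below is a hypothetical class that uses Python lists where the real MeanStat uses arrays; it only illustrates that merge() is elementwise while reduce() merges the entries together.

```python
import dataclasses
from typing import List, Tuple


@dataclasses.dataclass(frozen=True)
class ToyBatchedMeanStat:
    """Toy rank 1 statistic: parallel lists of accum and weight values."""
    accum: List[float]
    weight: List[float]

    def merge(self, other):
        """Elementwise merge of two rank 1 statistics of equal length."""
        return ToyBatchedMeanStat(
            [a + b for a, b in zip(self.accum, other.accum)],
            [a + b for a, b in zip(self.weight, other.weight)])

    def reduce(self) -> Tuple[float, float]:
        """Merges all entries into one rank 0 (accum, weight) pair."""
        return (sum(self.accum), sum(self.weight))


# The three rank 0 statistics (1, 2), (3, 4), (5, 6) stored as one object.
batched = ToyBatchedMeanStat([1, 3, 5], [2, 4, 6])
reduced = batched.reduce()
print(reduced)  # (9, 12)
print(reduced[0] / reduced[1])  # 0.75
```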
Being able to batch Stats also allows us to do other interesting things, for example,
- evaluate_batch() accepts an optional per-example mask so it can work on padded batches.
- We can define a PerDomainMetric metric for any base metric so that we can get accuracy where examples are partitioned by a domain id.
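As a rough illustration of the first point, a per-example mask can be thought of as giving padded entries zero weight before the batch is reduced. The helper below is a plain-Python assumption about the idea, not a description of how evaluate_batch() applies its mask internally.

```python
def evaluate_padded_batch(labels, predicted_labels, mask):
    """Returns (correct, total), counting only entries where mask is 1."""
    correct = sum(m * int(y == p)
                  for y, p, m in zip(labels, predicted_labels, mask))
    return (correct, sum(mask))


# The last entry is padding added to reach a fixed batch size of 4.
labels = [1, 0, 1, 0]
predicted_labels = [1, 0, 0, 0]
mask = [1, 1, 1, 0]

print(evaluate_padded_batch(labels, predicted_labels, mask))  # (2, 3)
# Without the mask, the padded entry is counted as a correct prediction.
print(evaluate_padded_batch(labels, predicted_labels, [1, 1, 1, 1]))  # (3, 4)
```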
Creating a new Metric
Most likely, a new metric will just return a MeanStat or a SumStat.
If that's the case, simply follow the guidelines in Metric's class docstring.
If a new Stat is necessary, follow the guidelines in Stat's docstring.
- class fedjax.metrics.Stat
Stat keeps some statistic, along with operations over them.
Most users will only need to interact with a Stat object via result().
For those who need to create new metrics, please first read the Under the hood section of the module docstring.
The domain (the set of possible statistic values) of most Stat subclasses has constraints, so it is usually good practice to offer and use factory methods to construct new Stat objects instead of directly assigning the fields.
To work with various jax constructs, a concrete Stat should be a PyTree. This is easily achieved with fedjax.dataclass.
A Stat may hold either a single statistic (a rank 0 Stat), or an array of statistics (a higher rank Stat).
- result() and merge() only need to work on a rank 0 Stat.
- reduce() only needs to work on a higher rank Stat.
- abstract merge(other)
Merges two Stat objects into a new Stat with merged statistics.
- abstract reduce(axis=0)
Reduces a higher rank statistic along a given axis.
See the class docstring for details.
- Parameters:
axis (Optional[int]) – An integer axis index, or None.
- Return type:
Stat
- Returns:
A new Stat object of the same type.
- class fedjax.metrics.MeanStat(accum, weight)
Bases: Stat
Statistic for weighted mean calculation.
Prefer using the MeanStat.new() factory method instead of directly assigning to fields.
Example:
stat_0 = MeanStat.new(accum=1, weight=2)
stat_1 = MeanStat.new(accum=2, weight=3)
merged_stat = stat_0.merge(stat_1)
print(merged_stat)
# MeanStat(accum=3, weight=5) => 0.6
stat = MeanStat.new(jnp.array([1, 2, 4]), jnp.array([1, 1, 0]))
reduced_stat = stat.reduce()
print(reduced_stat)
# MeanStat(accum=3, weight=2) => 1.5
- classmethod new(accum, weight)
Creates a sanitized MeanStat.
The domain of a weighted mean statistic is:
\[\{(0, 0)\} ∪ \{(a, b) | a >= 0, b > 0\}\]
new() sanitizes values outside the domain into the identity (zeros).
- Parameters:
accum – A value convertible to jnp.ndarray.
weight – A value convertible to jnp.ndarray.
- Return type:
MeanStat
- Returns:
The sanitized MeanStat.
- class fedjax.metrics.SumStat(accum)
Bases: Stat
Statistic for summing values.
Example:
stat_0 = SumStat.new(accum=1)
stat_1 = SumStat.new(accum=2)
merged_stat = stat_0.merge(stat_1)
print(merged_stat)
# SumStat(accum=3) => 3
stat = SumStat.new(jnp.array([1, 2, 1]))
reduced_stat = stat.reduce()
print(reduced_stat)
# SumStat(accum=4) => 4
- class fedjax.metrics.Metric
Metric is the conceptual metric (like accuracy).
It defines two methods:
- evaluate_example() evaluates a single example, and returns a Stat object.
- zero() returns the identity value for what evaluate_example() returns.
Given a Metric object m, let
- u = m.zero()
- v = m.evaluate_example(...)
We require that
- type(u) == type(v).
- u.merge(v) == v.merge(u) == v.
- Components of u have the same shape as their counterparts in v.
- abstract evaluate_example(example, prediction)
Evaluates a single example.
e.g. for accuracy:
MeanStat.new(num_correct, num_total)
- class fedjax.metrics.CrossEntropyLoss(target_key='y', pred_key=None)
Bases: Metric
Metric for cross entropy loss.
Example:
example = {'y': jnp.array(1)}
prediction = jnp.array([1.2, 0.4])
metric = CrossEntropyLoss()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1.1711007, weight=1) => 1.1711007
- target_key
Key name in example for target.
- Type: str
- pred_key
Key name in prediction for unnormalized model output pred.
- Type: Optional[str]
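The accum value in the example above can be reproduced by hand. The sketch below assumes the standard softmax cross entropy over unnormalized scores with an integer target; cross_entropy is a hypothetical helper written with the math module, not a fedjax function.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross entropy for one example with an integer target."""
    # Subtract the max logit before exponentiating for numerical stability.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(
        sum(math.exp(logit - max_logit) for logit in logits))
    return log_sum_exp - logits[target]


loss = cross_entropy([1.2, 0.4], target=1)
print(round(loss, 4))  # 1.1711
```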
- class fedjax.metrics.Accuracy(target_key='y', pred_key=None)
Bases: Metric
Metric for accuracy.
Example:
example = {'y': jnp.array(2)}
prediction = jnp.array([0, 0, 1])
metric = Accuracy()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1, weight=1) => 1
- target_key
Key name in example for target.
- Type: str
- pred_key
Key name in prediction for unnormalized model output pred.
- Type: Optional[str]
- class fedjax.metrics.TopKAccuracy(k, target_key='y', pred_key=None)
Bases: Metric
Metric for top k accuracy.
This metric computes the number of times where the correct class is among the top k classes predicted.
Example: top 3 accuracy
Dog => [Dog, Cat, Bird, Mouse, Penguin] ✓
Cat => [Bird, Mouse, Cat, Penguin, Dog] ✓
Dog => [Dog, Cat, Bird, Penguin, Mouse] ✓
Bird => [Bird, Cat, Mouse, Penguin, Dog] ✓
Cat => [Cat, Bird, Mouse, Dog, Penguin] ✓
Cat => [Cat, Mouse, Dog, Penguin, Bird] ✓
Mouse => [Penguin, Cat, Dog, Mouse, Bird] x
Penguin => [Dog, Mouse, Cat, Penguin, Bird] x
6 correct predictions in top 3 predicted classes / 8 total examples = .75 top 3 accuracy
Top k accuracy, also known as top n accuracy, is a useful metric when it comes to recommendations. One example would be the word recommendations on a virtual keyboard where three suggested words are displayed.
For k=1, we strongly recommend using Accuracy to avoid an unnecessary argsort. k < 1 will return 0. and k >= num_classes will return 1.
If two or more classes have the same prediction, the classes will be considered in order of lowest to highest indices.
Example:
example = {'y': jnp.array(2)}
prediction = jnp.array([0, 0.5, 0.2])
metric = TopKAccuracy(k=2)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1, weight=1) => 1
- k
Number of top elements to look at for computing accuracy.
- Type: int
- target_key
Key name in example for target.
- Type: str
- pred_key
Key name in prediction for unnormalized model output pred.
- Type: Optional[str]
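The rules above (k < 1, k >= num_classes, and tie-breaking by lowest index) can be sketched in plain Python. top_k_correct is a hypothetical helper that relies on the stability of sorted(); it is not how TopKAccuracy is implemented.

```python
def top_k_correct(scores, target, k):
    """Returns 1 if target is among the k highest scoring classes, else 0."""
    # sorted() is stable, so classes with equal scores stay in index order.
    ranked = sorted(range(len(scores)), key=lambda i: -scores[i])
    # max(k, 0) makes any k < 1 select no classes at all.
    return int(target in ranked[:max(k, 0)])


print(top_k_correct([0, 0.5, 0.2], target=2, k=2))  # 1
print(top_k_correct([0, 0.5, 0.2], target=2, k=1))  # 0
print(top_k_correct([0.5, 0.5, 0.1], target=1, k=1))  # 0: class 0 wins the tie
print(top_k_correct([0, 0.5, 0.2], target=0, k=3))  # 1: k >= num_classes
print(top_k_correct([0, 0.5, 0.2], target=1, k=0))  # 0: k < 1
```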
- class fedjax.metrics.SequenceTokenCrossEntropyLoss(target_key='y', pred_key=None, masked_target_values=(0,), per_position=False)
Bases: Metric
Metric for token cross entropy loss for a sequence example.
Example:
example = {'y': jnp.array([1, 0, 1])}
prediction = jnp.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
metric = SequenceTokenCrossEntropyLoss()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1.2246635, weight=2) => 0.61233175
per_position_metric = SequenceTokenCrossEntropyLoss(per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[1.1711007, 0., 0.05356275], weight=[1., 0., 1.])
# => [1.1711007, 0., 0.05356275]
- target_key
Key name in example for target.
- Type: str
- pred_key
Key name in prediction for unnormalized model output pred.
- Type: Optional[str]
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type: Tuple[int, …]
- per_position
Whether to keep output statistic per position or sum across positions for the entire sequence.
- Type: bool
- class fedjax.metrics.SequenceCrossEntropyLoss(target_key='y', pred_key=None, masked_target_values=(0,))
Bases: Metric
Metric for total cross entropy loss for a sequence example.
Example:
example = {'y': jnp.array([1, 0, 1])}
prediction = jnp.array([[1.2, 0.4], [2.3, 0.1], [0.3, 3.2]])
metric = SequenceCrossEntropyLoss()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=1.2246635, weight=1) => 1.2246635
- target_key
Key name in example for target.
- Type: str
- pred_key
Key name in prediction for unnormalized model output pred.
- Type: Optional[str]
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type: Tuple[int, …]
- class fedjax.metrics.SequenceTokenAccuracy(target_key='y', pred_key=None, masked_target_values=(0,), logits_mask=None, per_position=False)
Bases: Metric
Metric for token accuracy for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 2, 1, 3, 0])}
# prediction = [1, 0, 2, 1, 3, 0].
prediction = jnp.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0],
                        [0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]])
logits_mask = (0., 0., 0., -jnp.inf)
metric = SequenceTokenAccuracy(logits_mask=logits_mask)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=3, weight=5) => 0.6
per_position_metric = SequenceTokenAccuracy(logits_mask=logits_mask,
                                            per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[1., 0., 1., 1., 0., 0.], weight=[1., 1., 1., 1., 1., 0.])
# => [1., 0., 1., 1., 0., 0.]
- target_key
Key name in example for target.
- Type: str
- pred_key
Key name in prediction for unnormalized model output pred.
- Type: Optional[str]
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type: Tuple[int, …]
- logits_mask
Mask of shape [num_classes] to be applied for preds. This is typically used to discount predictions for out-of-vocabulary tokens.
- Type: Optional[Tuple[float, …]]
- per_position
Whether to keep output statistic per position or sum across positions for the entire sequence.
- Type: bool
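The interaction of masked_target_values and logits_mask in the example above can be traced with a small plain-Python sketch. The function below is a hypothetical re-implementation on lists, assuming the logits mask is added to the scores before taking the argmax; it is meant only to show where accum=3 and weight=5 come from.

```python
NEG_INF = float('-inf')


def sequence_token_accuracy(targets, logits, masked_target_values=(0,),
                            logits_mask=None):
    """Returns (accum, weight) over the non-masked positions of a sequence."""
    accum, weight = 0, 0
    for target, row in zip(targets, logits):
        if target in masked_target_values:
            continue  # Padding positions contribute nothing.
        if logits_mask is not None:
            row = [score + mask for score, mask in zip(row, logits_mask)]
        # max() returns the first maximum, i.e. the lowest index on ties.
        predicted = max(range(len(row)), key=lambda i: row[i])
        accum += int(predicted == target)
        weight += 1
    return accum, weight


targets = [1, 2, 2, 1, 3, 0]
logits = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0],
          [0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0]]
logits_mask = (0., 0., 0., NEG_INF)

print(sequence_token_accuracy(targets, logits, logits_mask=logits_mask))
# (3, 5)
print(sequence_token_accuracy(targets, logits))
# (4, 5)
```

With the mask, class 3 can never be predicted, so the position whose target is 3 counts as incorrect; the final position is skipped entirely because its target is a masked value.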
- class fedjax.metrics.SequenceTokenTopKAccuracy(k, target_key='y', pred_key=None, masked_target_values=(0,), logits_mask=None, per_position=False)
Bases: Metric
Metric for token top k accuracy for a sequence example.
For more information on the top k accuracy metric, refer to the TopKAccuracy docstring.
Example:
example = {'y': jnp.array([1, 2, 2, 1, 3, 0])}
prediction = jnp.array([[0, 1, 0.5, 0], [1, 0.5, 0, 0], [0.8, 0, 0.7, 0],
                        [0.5, 1, 0, 0], [0, 0.5, 0, 1], [0.5, 0, 0.9, 0]])
logits_mask = (0., 0., 0., -jnp.inf)
metric = SequenceTokenTopKAccuracy(k=2, logits_mask=logits_mask)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=3, weight=5) => 0.6
per_position_metric = SequenceTokenTopKAccuracy(k=2,
                                                logits_mask=logits_mask,
                                                per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[1., 0., 1., 1., 0., 0.], weight=[1., 1., 1., 1., 1., 0.])
# => [1., 0., 1., 1., 0., 0.]
- k
Number of top elements to look at for computing accuracy.
- Type: int
- target_key
Key name in example for target.
- Type: str
- pred_key
Key name in prediction for unnormalized model output pred.
- Type: Optional[str]
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type: Tuple[int, …]
- logits_mask
Mask of shape [num_classes] to be applied for preds. This is typically used to discount predictions for out-of-vocabulary tokens.
- Type: Optional[Tuple[float, …]]
- per_position
Whether to keep output statistic per position or sum across positions for the entire sequence.
- Type: bool
- class fedjax.metrics.SequenceTokenCount(target_key='y', masked_target_values=(0,))
Bases: Metric
Metric for count of non masked tokens for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceTokenCount(masked_target_values=(0, 2))
print(metric.evaluate_example(example, prediction))
# SumStat(accum=3) => 3
- target_key
Key name in example for target.
- Type: str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type: Tuple[int, …]
- class fedjax.metrics.SequenceCount(target_key='y', masked_target_values=(0,))
Bases: Metric
Metric for count of non masked sequences.
Example:
example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])}
empty_example = {'y': jnp.array([0, 0, 0, 0, 0, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceCount(masked_target_values=(0, 2))
print(metric.evaluate_example(example, prediction))
# SumStat(accum=1)
print(metric.evaluate_example(empty_example, prediction))
# SumStat(accum=0)
- target_key
Key name in example for target.
- Type: str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type: Tuple[int, …]
- class fedjax.metrics.SequenceTruncationRate(eos_target_value, target_key='y', masked_target_values=(0,))
Bases: Metric
Metric for truncation rate for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 2, 3, 3, 3, 4])}
truncated_example = {'y': jnp.array([1, 2, 2, 3, 3, 3, 3])}
prediction = jnp.array([])  # Unused.
metric = SequenceTruncationRate(eos_target_value=4)
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=0, weight=1) => 0
print(metric.evaluate_example(truncated_example, prediction))
# MeanStat(accum=1, weight=1) => 1
- eos_target_value
Target value denoting end of sequence. Truncated sequences will not have this value.
- Type: int
- target_key
Key name in example for target.
- Type: str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type: Tuple[int, …]
- class fedjax.metrics.SequenceTokenOOVRate(oov_target_values, target_key='y', masked_target_values=(0,), per_position=False)
Bases: Metric
Metric for out-of-vocabulary (OOV) rate for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 2, 3, 4, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceTokenOOVRate(oov_target_values=(2,))
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=2, weight=5) => 0.4
per_position_metric = SequenceTokenOOVRate(oov_target_values=(2,),
                                           per_position=True)
print(per_position_metric.evaluate_example(example, prediction))
# MeanStat(accum=[0., 1., 1., 0., 0., 0., 0.],
#          weight=[1., 1., 1., 1., 1., 0., 0.])
# => [0. 1. 1. 0. 0. 0. 0.]
- oov_target_values
Target values denoting out-of-vocabulary values.
- Type: Tuple[int, …]
- target_key
Key name in example for target.
- Type: str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type: Tuple[int, …]
- per_position
Whether to keep output statistic per position or sum across positions for the entire sequence.
- Type: bool
- class fedjax.metrics.SequenceLength(target_key='y', masked_target_values=(0,))
Bases: Metric
Metric for length for a sequence example.
Example:
example = {'y': jnp.array([1, 2, 3, 4, 0, 0])}
prediction = jnp.array([])  # Unused.
metric = SequenceLength()
print(metric.evaluate_example(example, prediction))
# MeanStat(accum=4, weight=1) => 4
- target_key
Key name in example for target.
- Type: str
- masked_target_values
Target values that should be ignored in computation. This is typically used to ignore padding values in computation.
- Type: Tuple[int, …]
- class fedjax.metrics.PerDomainMetric(base, num_domains, domain_id_key='domain_id')
Bases: Metric
Turns a base metric into one that groups results by domain.
This is useful in algorithms such as AgnosticFedAvg.
example is expected to contain a feature named domain_id_key, which stores the integer domain id in [0, num_domains). PerDomainMetric accumulates base's Stat within each domain. If the base Metric returns a Stat whose result is of shape X, then the Stat returned by PerDomainMetric will produce a result of shape (num_domains,) + X. See Batching Stats for the higher rank Stat mechanism enabling this.
Example:
per_domain_accuracy = PerDomainMetric(Accuracy(), num_domains=3)
batch_example = {
    'domain_id': jnp.array([0, 0, 1, 2]),
    'y': jnp.array([0, 1, 0, 1])
}
batch_prediction = jnp.array([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])
print(
    evaluate_batch(per_domain_accuracy, batch_example,
                   batch_prediction).result())
# [0.5 0.  1. ]
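The grouping behavior can be pictured as keeping one (accum, weight) pair per domain. The sketch below is a plain-Python assumption about that idea, using already-decoded predicted labels in place of class scores; per_domain_accuracy here is a hypothetical function, not the PerDomainMetric implementation.

```python
def per_domain_accuracy(domain_ids, labels, predicted_labels, num_domains):
    """Returns one accuracy value per domain id in [0, num_domains)."""
    accum = [0] * num_domains
    weight = [0] * num_domains
    for domain_id, label, predicted in zip(domain_ids, labels,
                                           predicted_labels):
        accum[domain_id] += int(label == predicted)
        weight[domain_id] += 1
    # Domains without any example report 0.0 (an assumption of this sketch).
    return [a / w if w else 0.0 for a, w in zip(accum, weight)]


# The argmax of every row of batch_prediction in the example above is 1.
print(per_domain_accuracy(domain_ids=[0, 0, 1, 2],
                          labels=[0, 1, 0, 1],
                          predicted_labels=[1, 1, 1, 1],
                          num_domains=3))
# [0.5, 0.0, 1.0]
```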
- class fedjax.metrics.ConfusionMatrix(num_classes, target_key='y', pred_key=None)
Bases: Metric
Metric for making a Confusion Matrix.
A confusion matrix is an n x n matrix often used to describe the performance of a classification model on a set of test data for which the true values are known. The model's predictions are represented through the columns, and the known data values through the rows. This allows one to view in which areas the model is doing well, as well as where there is room for improvement. For each row in the confusion matrix, if there are a lot of numbers outside of the main diagonal, the model is not doing so well on examples whose true class is that row's class.
Theoretical Example:
         | Predicted P | Predicted N
Actual P | TP          | FN
Actual N | FP          | TN
**This is for a binary classification model but the same concept applies to any model with n outputs. Notice that the TPs and TNs will always lie in the main diagonal of the matrix.
Example:
example = {'y': jnp.array(2)}
prediction = jnp.array([0., 1., 0.])
metric = ConfusionMatrix(num_classes=3)
print(metric.evaluate_example(example, prediction))
# SumStat(accum=DeviceArray([[0., 0., 0.],
#                            [0., 0., 0.],
#                            [0., 1., 0.]], dtype=float32))
# => [[0. 0. 0.]
#     [0. 0. 0.]
#     [0. 1. 0.]]
- target_key
Key name in example for target.
- Type: str
- pred_key
Key name in prediction for unnormalized model output pred.
- Type: Optional[str]
- num_classes
Number of output classes of the model. Used to generate a matrix of shape [num_classes, num_classes].
- Type: int
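The row and column convention (rows are the known values, columns are the predictions) can be checked with a small plain-Python sketch. confusion_matrix is a hypothetical helper working on decoded class indices, not the ConfusionMatrix implementation.

```python
def confusion_matrix(labels, predicted_labels, num_classes):
    """Counts examples by (actual class row, predicted class column)."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for actual, predicted in zip(labels, predicted_labels):
        matrix[actual][predicted] += 1
    return matrix


# The single example above: true class 2, predicted class 1.
print(confusion_matrix([2], [1], num_classes=3))
# [[0, 0, 0], [0, 0, 0], [0, 1, 0]]

# Binary case, treating class 0 as positive and class 1 as negative.
labels = [0, 0, 1, 1, 1]
predicted_labels = [0, 1, 0, 1, 1]
print(confusion_matrix(labels, predicted_labels, num_classes=2))
# [[1, 1], [1, 2]], i.e. TP=1, FN=1, FP=1, TN=2
```

Because each example contributes a matrix with a single 1, summing the per-example matrices elementwise yields the confusion matrix of the whole dataset, which matches the use of SumStat in the example output above.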
- fedjax.metrics.unreduced_cross_entropy_loss(targets, preds, is_sparse_targets=True)
Returns unreduced cross entropy loss.
- Return type: