# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FederatedData interface for providing access to a federated dataset."""
import abc
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple
from fedjax.core import client_datasets
import numpy as np
# A client id is an opaque byte string. Ids are compared with the ordinary
# ordering of bytes wherever a range of client ids is needed (e.g. slicing).
ClientId = bytes
class ClientPreprocessor:
  """A fixed pipeline of functions applied to all examples of one client.

  Closely mirrors :class:`fedjax.BatchPreprocessor`; the main difference is
  that every function in a ClientPreprocessor additionally receives the
  ``client_id``. See the discussion in :class:`fedjax.BatchPreprocessor`
  regarding when to use which.
  """

  def __init__(self,
               fns: Iterable[Callable[[ClientId, client_datasets.Examples],
                                      client_datasets.Examples]] = ()):
    """Stores ``fns`` as a tuple, keeping the order in which they were given."""
    self._fns = tuple(fns)

  def __call__(self, client_id: ClientId,
               examples: client_datasets.Examples) -> client_datasets.Examples:
    """Runs every function in order and returns the final examples.

    With an empty pipeline the input object itself is returned. Otherwise the
    functions are chained starting from a shallow copy of ``examples``, and the
    final result is checked with ``client_datasets.assert_consistent_rows``.

    Args:
      client_id: Id of the client the examples belong to.
      examples: All examples of that client.

    Returns:
      The preprocessed examples.
    """
    if self._fns:
      # Start from a shallow copy so a function that mutates its argument
      # cannot modify the caller's mapping.
      processed = dict(examples)
      for fn in self._fns:
        processed = fn(client_id, processed)
      client_datasets.assert_consistent_rows(processed)
      return processed
    return examples

  def append(
      self, fn: Callable[[ClientId, client_datasets.Examples],
                         client_datasets.Examples]
  ) -> 'ClientPreprocessor':
    """Returns a new ClientPreprocessor that runs ``fn`` after existing ones."""
    return ClientPreprocessor((*self._fns, fn))

  def __str__(self) -> str:
    return 'ClientPreprocessor(' + str(self._fns) + ')'

  def __repr__(self) -> str:
    return self.__str__()
# A shared default preprocessor with an empty function chain: calling it
# returns the input examples unchanged.
NoOpClientPreprocessor = ClientPreprocessor()
class FederatedData(abc.ABC):
  """Abstract interface for reading a federated dataset.

  A FederatedData object maps client ids to client datasets and to client
  metadata.

  **Access methods with better I/O efficiency**

  A large federated dataset usually cannot be loaded into memory all at once,
  whereas loading any single client dataset is assumed to be feasible.
  Implementations differ by on disk storage format. Because sequential reads
  are much faster than random reads on most storage technologies,
  FederatedData offers three ways of accessing client datasets, listed from
  most to least I/O friendly,

  1.  :meth:`clients` and :meth:`shuffled_clients` are sequential read
      friendly, and thus recommended whenever appropriate.
  2.  :meth:`get_clients` requires random read, but prefetching is possible.
      This should be preferred over :meth:`get_client`.
  3.  :meth:`get_client` is usually the slowest way of accessing client
      datasets, and is mostly intended for interactive exploration of a small
      number of clients.

  **Preprocessing**

  A :class:`~fedjax.ClientDataset` produced by FederatedData can hold a
  :class:`~fedjax.BatchPreprocessor`, customizable via
  :meth:`preprocess_batch`. In addition, a "client" level
  :class:`ClientPreprocessor`, customizable via :meth:`preprocess_client`, can
  transform the examples of an entire client dataset before the
  :class:`~fedjax.ClientDataset` is constructed.
  """

  @abc.abstractmethod
  def slice(self,
            start: Optional[ClientId] = None,
            stop: Optional[ClientId] = None) -> 'FederatedData':
    """Returns a new FederatedData limited to a half-open range of client ids.

    A client is included in the result when its id is,

    -   Greater than or equal to ``start``, if ``start`` is not None;
    -   Less than ``stop``, if ``stop`` is not None.

    Args:
      start: Inclusive lower bound of the client id range, or None for no
        lower bound.
      stop: Exclusive upper bound of the client id range, or None for no upper
        bound.

    Returns:
      FederatedData.
    """

  # Client metadata access.

  @abc.abstractmethod
  def num_clients(self) -> int:
    """Returns the number of clients.

    An implementation may raise an exception when obtaining the count is too
    expensive or otherwise impossible.
    """

  @abc.abstractmethod
  def client_ids(self) -> Iterator[ClientId]:
    """Returns an iterator of client ids as bytes.

    The order of iteration is unspecified.
    """

  @abc.abstractmethod
  def client_sizes(self) -> Iterator[Tuple[ClientId, int]]:
    """Returns an iterator of all (client id, client size) pairs.

    Often more efficient than calling :meth:`client_size` once per client. The
    order of iteration is unspecified.
    """

  @abc.abstractmethod
  def client_size(self, client_id: ClientId) -> int:
    """Returns the number of examples in the dataset of ``client_id``."""

  # Client access.

  @abc.abstractmethod
  def clients(self) -> Iterator[Tuple[ClientId, client_datasets.ClientDataset]]:
    """Iterates over clients in a deterministic order.

    An implementation is free to pick whichever order makes iteration
    efficient.
    """

  @abc.abstractmethod
  def shuffled_clients(
      self,
      buffer_size: int,
      seed: Optional[int] = None
  ) -> Iterator[Tuple[ClientId, client_datasets.ClientDataset]]:
    """Iterates over clients with a repeated buffered shuffling.

    The shuffle buffer should hold at least ``buffer_size`` clients. Iteration
    should repeat forever, usually in a different order on each pass.

    Args:
      buffer_size: Buffer size for shuffling.
      seed: Optional random number generator seed.

    Returns:
      Iterator.
    """

  @abc.abstractmethod
  def get_clients(
      self, client_ids: Iterable[ClientId]
  ) -> Iterator[Tuple[ClientId, client_datasets.ClientDataset]]:
    """Gets multiple clients with a single call.

    Results are produced in the same order as ``client_ids``.

    Args:
      client_ids: Client ids to load.

    Returns:
      Iterator.
    """

  @abc.abstractmethod
  def get_client(self, client_id: ClientId) -> client_datasets.ClientDataset:
    """Gets a single client dataset.

    Prefer :meth:`clients`, :meth:`shuffled_clients`, or :meth:`get_clients`
    when possible.

    Args:
      client_id: Client id to load.

    Returns:
      The corresponding ClientDataset.
    """

  # Preprocessing.

  @abc.abstractmethod
  def preprocess_client(
      self, fn: Callable[[ClientId, client_datasets.Examples],
                         client_datasets.Examples]
  ) -> 'FederatedData':
    """Registers a client level preprocessing function.

    ``fn`` is called on all examples of a client before they are passed on to
    construct a ClientDataset.
    """

  @abc.abstractmethod
  def preprocess_batch(
      self, fn: Callable[[client_datasets.Examples], client_datasets.Examples]
  ) -> 'FederatedData':
    """Registers a batch level preprocessing function.

    ``fn`` is called after batching in ClientDatasets.
    """
# Functions for treating a federated dataset as a single centralized dataset.


def shuffle_repeat_batch_federated_data(
    fd: FederatedData,
    batch_size: int,
    client_buffer_size: int,
    example_buffer_size: int,
    seed: Optional[int] = None) -> Iterator[client_datasets.Examples]:
  """Shuffle-repeat-batches all client datasets in a federated dataset.

  Intended for training a centralized baseline. Two levels of buffered
  shuffling are applied: first over clients, then over examples.

  The result is an infinite stream of batches; itertools.islice() can be used
  to cap the number of batches, if so desired.

  Args:
    fd: Federated dataset.
    batch_size: Desired batch size.
    client_buffer_size: Buffer size for client level shuffling.
    example_buffer_size: Buffer size for example level shuffling.
    seed: Optional RNG seed.

  Yields:
    Batches of preprocessed examples.
  """
  rng = np.random.RandomState(seed)
  # The client level shuffle gets its own seed, drawn from `rng` before `rng`
  # is handed to the example level shuffle.
  client_seed = rng.randint(1 << 32)
  shuffled = fd.shuffled_clients(client_buffer_size, client_seed)
  # Only the datasets are needed downstream; client ids are dropped.
  datasets = (dataset for _, dataset in shuffled)
  yield from client_datasets.buffered_shuffle_batch_client_datasets(
      datasets, batch_size=batch_size, buffer_size=example_buffer_size, rng=rng)
def padded_batch_federated_data(fd: FederatedData,
                                hparams: Optional[
                                    client_datasets.PaddedBatchHParams] = None,
                                **kwargs) -> Iterator[client_datasets.Examples]:
  """Padded-batches all client datasets in a federated dataset.

  Useful for evaluation on the entire federated dataset.

  Args:
    fd: Federated dataset.
    hparams: See :func:`fedjax.padded_batch_client_datasets`.
    **kwargs: See :func:`fedjax.padded_batch_client_datasets`.

  Yields:
    Batches of preprocessed examples.
  """
  client_stream = fd.clients()
  # Only the datasets are needed downstream; client ids are dropped.
  datasets = (dataset for _, dataset in client_stream)
  yield from client_datasets.padded_batch_client_datasets(
      datasets, hparams, **kwargs)
def intersect_slice_ranges(
    current_start: Optional[ClientId], current_stop: Optional[ClientId],
    new_start: Optional[ClientId], new_stop: Optional[ClientId]
) -> Tuple[Optional[ClientId], Optional[ClientId]]:
  """Intersects the current slice range with a newly requested slice range.

  Helper for FederatedData implementations, used to make sure that slicing
  never enlarges the range of client ids. On either side, None stands for "no
  bound".

  Args:
    current_start: Current start of the slice range.
    current_stop: Current stop of the slice range.
    new_start: New start of the slice range.
    new_stop: New stop of the slice range.

  Returns:
    Normalized slice range that is the intersection of the two input ranges.
  """
  # Lower bound: the larger of the two starts, when both are present.
  if new_start is None:
    start = current_start
  elif current_start is None:
    start = new_start
  else:
    start = max(current_start, new_start)

  # Upper bound: the smaller of the two stops, when both are present.
  if new_stop is None:
    stop = current_stop
  elif current_stop is None:
    stop = new_stop
  else:
    stop = min(current_stop, new_stop)

  return start, stop
class RepeatableIterator:
  """Repeats a base iterable after the end of each pass of iteration.

  Because this is a stateful object, it is not thread safe, and all usual
  caveats about accessing the same iterator at different locations apply. For
  example, if we make two map calls to the same RepeatableIterator, we must
  make sure we do not interleave `next()` calls on these. For example, the
  following is safe because we finish iterating on m1 before starting to
  iterate on m2, ::

    it = RepeatableIterator(range(4))
    m1 = map(lambda x: x + 1, it)
    m2 = map(lambda x: x * x, it)
    # We finish iterating on m1 before starting to iterate on m2.
    print(list(m1), list(m2))
    # [1, 2, 3, 4] [0, 1, 4, 9]

  Whereas interleaved access leads to confusing results, ::

    it = RepeatableIterator(range(4))
    m1 = map(lambda x: x + 1, it)
    m2 = map(lambda x: x * x, it)
    print(next(m1), next(m2))
    # 1 1
    print(next(m1), next(m2))
    # 3 9
    print(next(m1), next(m2))
    # StopIteration!

  In the first pass of iteration, values fetched from the base iterator are
  copied onto an internal buffer (except for a few builtin containers where
  copying is unnecessary). When each pass of iteration finishes (i.e. when
  __next__() raises StopIteration), the iterator resets itself to the start of
  the buffer, thus allowing any number of subsequent passes.

  In most cases, if repeated iterations are required, it is sufficient to
  simply copy values from an iterator into a list. However, sometimes an
  iterator produces values via potentially expensive I/O operations (e.g.
  loading client datasets); RepeatableIterator can interleave I/O and JAX
  compute to decrease accelerator idle time in this case.
  """

  def __init__(self, base: Iterable[Any]):
    """Wraps ``base``.

    Args:
      base: Iterable to repeat. list, tuple, dict, str and bytes objects are
        re-iterated directly; values of any other iterable are buffered during
        the first pass.
    """
    self._iter = iter(base)
    if isinstance(base, (list, tuple, dict, str, bytes)):
      # No copying for builtin containers that are already repeatable: the
      # container itself serves as the buffer.
      self._first_pass = False
      self._buf = base
    else:
      # General case, copying required.
      self._first_pass = True
      self._buf = []

  def __iter__(self) -> Iterator[Any]:
    return self

  def __next__(self) -> Any:
    """Returns the next value, rewinding to the start at the end of a pass.

    Raises:
      StopIteration: At the end of every pass. The following call starts a
        fresh pass over the buffered values.
    """
    try:
      value = next(self._iter)
    except StopIteration:
      # Rewind at the end of EVERY pass, not just the first one. Rewinding only
      # while `_first_pass` was set left builtin containers unrepeatable and
      # every other iterable repeatable exactly once.
      self._first_pass = False
      self._iter = iter(self._buf)
      raise
    if self._first_pass:
      self._buf.append(value)
    return value
class SubsetFederatedData(FederatedData):
  """A simple wrapper restricting a concrete FederatedData to some client ids.

  This is useful when we wish to create a smaller FederatedData out of
  arbitrary client ids, where slicing is not possible.
  """

  def __init__(self,
               base: FederatedData,
               client_ids: Iterable[ClientId],
               validate: bool = True):
    """Initializes the subset federated dataset.

    Args:
      base: Base concrete FederatedData.
      client_ids: Client ids to include in the subset. All client ids must be
        in base.client_ids(), otherwise the behavior of SubsetFederatedData is
        undefined when validate=False. A set is stored as is (not copied); any
        other iterable is converted to a set.
      validate: Whether to validate client ids.

    Raises:
      ValueError: If ``validate`` is true and some client ids are not in
        ``base``.
    """
    self._base = base
    if not isinstance(client_ids, set):
      client_ids = set(client_ids)
    if validate:
      bad_client_ids = client_ids.difference(base.client_ids())
      if bad_client_ids:
        raise ValueError('Some client ids are not in the base FederatedData, '
                         f'showing up to 10: {sorted(bad_client_ids)[:10]}')
    self._client_ids = client_ids

  def _check_member(self, client_id: ClientId) -> None:
    """Raises KeyError carrying ``client_id`` if it is not in this subset."""
    if client_id not in self._client_ids:
      raise KeyError(client_id)

  def slice(self,
            start: Optional[ClientId] = None,
            stop: Optional[ClientId] = None) -> FederatedData:
    if start is None and stop is None:
      client_ids = self._client_ids
    else:
      client_ids = {
          i for i in self._client_ids
          if (start is None or start <= i) and (stop is None or i < stop)
      }
    # The ids were filtered from an already accepted set, so validation is
    # skipped.
    return SubsetFederatedData(
        self._base.slice(start, stop), client_ids, validate=False)

  def num_clients(self) -> int:
    return len(self._client_ids)

  def client_ids(self) -> Iterator[ClientId]:
    # Ids are sorted for deterministic iteration order.
    return iter(sorted(self._client_ids))

  def client_sizes(self) -> Iterator[Tuple[ClientId, int]]:
    # Follows the iteration order of the base, keeping only subset members.
    for client_id, size in self._base.client_sizes():
      if client_id in self._client_ids:
        yield client_id, size

  def client_size(self, client_id: ClientId) -> int:
    self._check_member(client_id)
    return self._base.client_size(client_id)

  def clients(self) -> Iterator[Tuple[ClientId, client_datasets.ClientDataset]]:
    # Ids are sorted for deterministic iteration order.
    yield from self.get_clients(sorted(self._client_ids))

  def shuffled_clients(
      self,
      buffer_size: int,
      seed: Optional[int] = None
  ) -> Iterator[Tuple[ClientId, client_datasets.ClientDataset]]:
    if not self._client_ids:
      # With no clients, the repeat loop below would spin forever without ever
      # yielding a value; end the iteration right away instead.
      return
    rng = np.random.RandomState(seed)
    while True:
      yield from client_datasets.buffered_shuffle(self.clients(), buffer_size,
                                                  rng)

  def get_clients(
      self, client_ids: Iterable[ClientId]
  ) -> Iterator[Tuple[ClientId, client_datasets.ClientDataset]]:
    # Membership is checked as results come back from the base, so
    # `client_ids` is consumed exactly once, by the base only.
    for client_id, dataset in self._base.get_clients(client_ids):
      self._check_member(client_id)
      yield client_id, dataset

  def get_client(self, client_id: ClientId) -> client_datasets.ClientDataset:
    self._check_member(client_id)
    return self._base.get_client(client_id)

  def preprocess_client(
      self, fn: Callable[[ClientId, client_datasets.Examples],
                         client_datasets.Examples]
  ) -> FederatedData:
    return SubsetFederatedData(
        self._base.preprocess_client(fn), self._client_ids, validate=False)

  def preprocess_batch(
      self, fn: Callable[[client_datasets.Examples], client_datasets.Examples]
  ) -> FederatedData:
    return SubsetFederatedData(
        self._base.preprocess_batch(fn), self._client_ids, validate=False)
class FederatedDataBuilder(abc.ABC):
  """FederatedDataBuilder interface.

  To be implemented as a context manager for building file formats from pairs
  of client IDs and client NumPy examples.

  Note that :meth:`add_many` below does not specify any raised exceptions. One
  could imagine formats where adding can fail in some way: out-of-order or
  duplicate inputs, remote files and network failures, individual entries too
  big for a format, etc. Implementations are therefore free to raise whatever
  is relevant and fit for their particular use cases. The same holds for
  __init__, __enter__, and __exit__, where implementations are responsible for
  raising exceptions as they see fit, for example when an invalid file path is
  passed, or when there were any issues finalizing the builder.

  Eg of end behavior (assuming examples are mappings from feature name to
  NumPy array)::

    with FederatedDataBuilder(path) as builder:
      builder.add_many([
          (b'k1', {'x': np.array([b'v1'], dtype=object)}),
          (b'k2', {'x': np.array([b'v2'], dtype=object)}),
      ])
  """

  @abc.abstractmethod
  def __enter__(self):
    """Provides the value bound by 'as' in the with statement.

    Returning self keeps the required functionality within the same class, so
    that :meth:`add_many` can be called inside the with block.

    Returns:
      self
    """

  @abc.abstractmethod
  def __exit__(self, exc_type, exc_value, exc_traceback):
    """Finalizes the builder once it leaves the 'with' block.

    Args:
      exc_type: Class of the exception raised inside the with block, if any.
      exc_value: The exception instance raised inside the with block, if any.
      exc_traceback: Traceback associated with the exception, if any.
    """

  @abc.abstractmethod
  def add_many(self,
               client_ids_examples: Iterable[Tuple[bytes,
                                                   client_datasets.Examples]]):
    """Bulk adds client ID and client NumPy examples pairs to the file format.

    Args:
      client_ids_examples: Iterable of tuples of client id and examples.
    """