# Source code for fedjax.datasets.cifar100

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Federated cifar100."""

import os.path
from typing import List, Optional, Tuple

from fedjax.core import client_datasets
from fedjax.core import federated_data
from fedjax.core import sqlite_federated_data
from fedjax.core import util
from fedjax.datasets import downloads

import numpy as np

# TensorFlow handle obtained through fedjax's `util.import_tf()` helper rather
# than a direct `import tensorflow`.
# NOTE(review): presumably this keeps TensorFlow an optional/lazy dependency --
# confirm against `fedjax.core.util`.
tf = util.import_tf()

# Split names accepted by `load_split`.
SPLITS = ('train', 'test')
# Expected digest and size of the LZMA compressed SQLite file downloaded from
# TFF; both are passed to `downloads.validate_file` in `load_split`.
# NOTE(review): 64 hex characters, presumably SHA-256 -- confirm against
# `downloads.validate_file`.
_TFF_SQLITE_COMPRESSED_HEXDIGEST = '23d3916c9caa33395737ee560cc7cb77bbd05fc7b73647ee7be3a7e764172939'
_TFF_SQLITE_COMPRESSED_NUM_BYTES = 153019920
# Expected digest and size, keyed by split name, of the per-split SQLite files
# that `load_split` produces from the TFF file and validates before returning.
_FEDJAX_SQLITE_HEXDIGEST = {
    'train': 'a4dc2f4ac4c9b6e7bd5234a1b568389384c00c903dcd34001c3cf50a4a81c713',
    'test': '74416c6ee8f41b1086f0e0e3c5289ac6df5a641100df48e8b583411a55de891f'
}
_FEDJAX_SQLITE_NUM_BYTES = {'train': 140521472, 'test': 28135424}


def cite():
  """Returns the BibTeX entry to use when citing the dataset."""
  bibtex_entry = (
      '@TECHREPORT{Krizhevsky09learningmultiple, '
      'author = {Alex Krizhevsky}, '
      'title = {Learning multiple layers of features from tiny images}, '
      'institution = {}, '
      'year = {2009} }')
  return bibtex_entry
def _parse_tf_examples(vs: List[bytes]) -> client_datasets.Examples:
  """Parses serialized tf.Example records into a dict of numpy arrays.

  Args:
    vs: Serialized tf.Example records.

  Returns:
    Examples holding the 'coarse_label', 'image' and 'label' features as numpy
    arrays, with 'image' cast to uint8.
  """
  feature_spec = {
      'coarse_label': tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
      'image': tf.io.FixedLenFeature(shape=(32, 32, 3), dtype=tf.int64),
      'label': tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
  }
  parsed = tf.io.parse_example(vs, features=feature_spec)
  # Pixels are parsed as int64; narrow them to uint8 before leaving TensorFlow.
  parsed['image'] = tf.cast(parsed['image'], dtype=tf.uint8)
  return tf.nest.map_structure(lambda tensor: tensor.numpy(), parsed)
def load_split(split: str,
               mode: str = 'sqlite',
               cache_dir: Optional[str] = None) -> federated_data.FederatedData:
  """Loads a single cifar100 split as federated data.

  Features:

  - image: [N, 32, 32, 3] uint8 pixels.
  - coarse_label: [N] int64 coarse labels in the range [0, 20).
  - label: [N] int64 labels in the range [0, 100).

  Args:
    split: Name of the split. One of SPLITS.
    mode: 'sqlite'.
    cache_dir: Directory to cache files in 'sqlite' mode.

  Returns:
    FederatedData.

  Raises:
    ValueError: If `split` is not one of SPLITS, if `cache_dir` is given for a
      mode other than 'sqlite', or if `mode` is not supported.
  """
  if split not in SPLITS:
    raise ValueError(f'Invalid split={split!r}')
  sqlite_mode = mode == 'sqlite'
  if not (cache_dir is None or sqlite_mode):
    raise ValueError('Caching locally is only supported in "sqlite" mode')
  if not sqlite_mode:
    raise ValueError(f'Unsupported mode={mode!r}')

  # Fetch the LZMA compressed SQLite file published by TFF and make sure it is
  # still the exact file the expected checksums were computed for.
  tff_url = ('https://storage.googleapis.com/tff-datasets-public/'
             'cifar100.sqlite.lzma')
  lzma_path = downloads.maybe_download(tff_url, cache_dir)
  downloads.validate_file(lzma_path, _TFF_SQLITE_COMPRESSED_NUM_BYTES,
                          _TFF_SQLITE_COMPRESSED_HEXDIGEST)
  tff_sqlite_path = downloads.maybe_lzma_decompress(lzma_path)

  # The converted per-split file lives next to the decompressed TFF file.
  fedjax_sqlite_path = os.path.join(
      os.path.dirname(tff_sqlite_path), f'federated_cifar100_{split}.sqlite')
  if not os.path.exists(fedjax_sqlite_path):
    # Convert TFF "1 example per record" to "all client examples per record".
    with sqlite_federated_data.SQLiteFederatedDataBuilder(
        fedjax_sqlite_path) as builder:
      tff_clients = sqlite_federated_data.TFFSQLiteClientsIterator(
          tff_sqlite_path, _parse_tf_examples, split)
      builder.add_many(
          (entry[0], entry[1].all_examples()) for entry in tff_clients)
  else:
    downloads.log(f'Reusing cached file {fedjax_sqlite_path!r}')

  # Whether freshly built or reused, the final SQLite file must match the
  # expected size and digest for this split.
  downloads.validate_file(fedjax_sqlite_path, _FEDJAX_SQLITE_NUM_BYTES[split],
                          _FEDJAX_SQLITE_HEXDIGEST[split])
  return sqlite_federated_data.SQLiteFederatedData.new(fedjax_sqlite_path)
def load_data(
    mode: str = 'sqlite',
    cache_dir: Optional[str] = None
) -> Tuple[federated_data.FederatedData, federated_data.FederatedData]:
  """Loads partially preprocessed cifar100 splits.

  Features:

  - x: [N, 32, 32, 3] uint8 pixels.
  - y: [N] int32 labels in the range [0, 100).

  Additional preprocessing (e.g. centering and normalizing) depends on whether
  a split is used for training or eval. For example,::

    import functools
    from fedjax.datasets import cifar100
    # Load partially preprocessed splits.
    train, test = cifar100.load_data()
    # Preprocessing for training.
    train_for_train = train.preprocess_batch(
        functools.partial(cifar100.preprocess_batch, is_train=True))
    # Preprocessing for eval.
    train_for_eval = train.preprocess_batch(
        functools.partial(cifar100.preprocess_batch, is_train=False))
    test = test.preprocess_batch(
        functools.partial(cifar100.preprocess_batch, is_train=False))

  Features after this preprocessing:

  - x: [N, 32, 32, 3] float32 preprocessed pixels.
  - y: [N] int32 labels in the range [0, 100).

  Alternatively, you can apply the same preprocessing as TensorFlow Federated
  following tff.simulation.baselines.cifar100.create_image_classification_task.
  For example,::

    from fedjax.datasets import cifar100
    train, test = cifar100.load_data()
    train = train.preprocess_batch(cifar100.preprocess_batch_tff)
    test = test.preprocess_batch(cifar100.preprocess_batch_tff)

  Features after this preprocessing:

  - x: [N, 24, 24, 3] float32 preprocessed pixels.
  - y: [N] int32 labels in the range [0, 100).

  Note: ``preprocess_batch`` and ``preprocess_batch_tff`` are just convenience
  wrappers around :meth:`preprocess_image` and :meth:`preprocess_image_tff`,
  respectively, for use with :meth:`fedjax.FederatedData.preprocess_batch`.

  Args:
    mode: 'sqlite'.
    cache_dir: Directory to cache files in 'sqlite' mode.

  Returns:
    A (train, test) tuple of federated data.
  """
  # Both splits are loaded (train first) before any client preprocessing is
  # attached.
  train_split, test_split = (
      load_split(name, mode, cache_dir) for name in ('train', 'test'))
  return (train_split.preprocess_client(preprocess_client),
          test_split.preprocess_client(preprocess_client))
def preprocess_client(
    client_id: federated_data.ClientId,
    examples: client_datasets.Examples) -> client_datasets.Examples:
  """Maps a client's raw features onto the ``x``/``y`` feature names.

  Args:
    client_id: Id of the client the examples belong to. Unused.
    examples: Client examples containing 'image' and 'label' features.

  Returns:
    Examples with 'x' set to the unchanged 'image' feature and 'y' set to the
    'label' feature cast to int32.
  """
  del client_id  # Every client is processed the same way.
  pixels = examples['image']
  labels = examples['label'].astype(np.int32)
  return {'x': pixels, 'y': labels}


# Mean and stddev computed assuming RGB values lie in [0,1].
CIFAR100_PIXELS_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
# Stored as the reciprocal of the per-channel stddev so that normalizing is a
# multiplication.
CIFAR100_PIXELS_INVERSE_STDDEV = (
    1 / np.array([0.2023, 0.1994, 0.2010], dtype=np.float32))
def preprocess_image(image: np.ndarray, is_train: bool) -> np.ndarray:
  """Augments and preprocesses CIFAR-100 images by cropping, flipping, and normalizing.

  Preprocessing procedure and values taken from `pytorch-cifar
  <https://github.com/kuangliu/pytorch-cifar/blob/49b7aa97b0c12fe0d4054e670403a16b6b834ddd/main.py#L30-L40>`_.

  When ``is_train`` is True, a single random crop offset and a single random
  flip decision are drawn from ``np.random`` and applied to the whole batch.

  Args:
    image: [N, 32, 32, 3] uint8 pixels.
    is_train: Whether we are preprocessing for training or eval.

  Returns:
    Processed [N, 32, 32, 3] float32 pixels.
  """
  if is_train:
    # Pad 4 zero pixels on all sides and then randomly crop back to (32, 32, 3).
    num_paddings = 4
    image = np.pad(
        image, [(0, 0), (num_paddings, num_paddings),
                (num_paddings, num_paddings), (0, 0)],
        mode='constant')
    # Start offsets for cropping. The padded image is 2 * num_paddings larger
    # than the crop, so the valid offsets are 0..2 * num_paddings inclusive.
    # np.random.randint excludes its upper bound, hence the + 1; without it
    # the crop could never be shifted fully towards the bottom/right.
    i, j = np.random.randint(num_paddings * 2 + 1, size=[2])
    image = image[:, i:i + 32, j:j + 32, :]
    # Random horizontal flip (axis -2 is the width axis).
    if np.random.randint(2):
      image = np.flip(image, axis=-2)
  # Center and normalize. Pixels are first scaled to [0, 1] because the mean
  # and stddev constants assume that range.
  image = ((image.astype(np.float32) / 255 - CIFAR100_PIXELS_MEAN) *
           CIFAR100_PIXELS_INVERSE_STDDEV)
  return image
def preprocess_batch(examples: client_datasets.Examples,
                     is_train: bool) -> client_datasets.Examples:
  """Applies :meth:`preprocess_image` to the 'x' feature of a batch.

  Args:
    examples: Batch of examples containing 'x' pixels and 'y' labels.
    is_train: Whether we are preprocessing for training or eval.

  Returns:
    Batch with preprocessed 'x' and unchanged 'y'.
  """
  return {'x': preprocess_image(examples['x'], is_train), 'y': examples['y']}


def preprocess_image_tff(image: np.ndarray, crop_height: int, crop_width: int,
                         distort: bool) -> np.ndarray:
  """Augments and preprocesses CIFAR-100 images following TensorFlow Federated.

  Images are cropped to (``crop_height``, ``crop_width``) and then standardized
  per image as ``(x - mean) / max(stddev, 1 / sqrt(num_pixels))``.

  Args:
    image: [N, 32, 32, 3] uint8 pixels.
    crop_height: Desired height for cropping images. Must be between 1 and 32.
    crop_width: Desired width for cropping images. Must be between 1 and 32.
    distort: Whether to distort the images via a random crop and a random
      horizontal flip (drawn once from ``np.random`` for the whole batch).
      Otherwise images are center cropped.

  Returns:
    Processed [N, ``crop_height``, ``crop_width``, 3] float32 pixels.

  Raises:
    ValueError: If ``crop_height`` or ``crop_width`` is not between 1 and 32.
  """
  if crop_height < 1 or crop_width < 1 or crop_height > 32 or crop_width > 32:
    raise ValueError('The crop_height and crop_width must be between 1 and 32.')
  if distort:
    # Distort via random crops and flips.
    # Start offsets for cropping should match tf.image.random_crop.
    crop_shape = np.array((crop_height, crop_width, 3)).astype(np.int32)
    shape = np.array(image.shape).astype(np.int32)[1:]
    limit = shape - crop_shape + 1
    offset = np.random.uniform(
        high=np.iinfo(limit.dtype).max, size=shape.shape).astype(
            limit.dtype) % limit
    begin_i, begin_j, _ = offset
    end_i, end_j, _ = offset + crop_shape
    image = image[:, begin_i:end_i, begin_j:end_j, :]
    # Random horizontal flip should match tf.image.random_flip_left_right.
    if np.random.randint(2):
      image = np.flip(image, axis=-2)
  else:
    # Center crop should match tf.image.resize_with_crop_or_pad.
    height_offset = (32 - crop_height) // 2
    width_offset = (32 - crop_width) // 2
    image = image[:, height_offset:height_offset + crop_height,
                  width_offset:width_offset + crop_width, :]
  image = image.astype(np.float32)
  # Normalize should match tf.image.per_image_standardization, which divides
  # by max(stddev, 1 / sqrt(num_pixels)). The floor only guards against
  # division by zero on (nearly) uniform images, so it must be the reciprocal
  # of sqrt(num_pixels), not sqrt(num_pixels) itself.
  num_pixels = np.prod(image.shape[-3:], dtype=np.float32)
  image_mean = np.mean(image, axis=(-1, -2, -3), keepdims=True)
  image_std = np.std(image, axis=(-1, -2, -3), keepdims=True)
  image_adjusted_std = np.maximum(image_std, 1 / np.sqrt(num_pixels))
  return (image - image_mean) / image_adjusted_std


def preprocess_batch_tff(examples: client_datasets.Examples,
                         crop_height: int = 24,
                         crop_width: int = 24,
                         distort: bool = False) -> client_datasets.Examples:
  """Preprocesses CIFAR-100 examples following TensorFlow Federated.

  Preprocessing procedure and values taken from
  `tff.simulation.baselines.cifar100.create_image_classification_task
  <https://www.tensorflow.org/federated/api_docs/python/tff/simulation/baselines/cifar100/create_image_classification_task>`_.

  Args:
    examples: Batch of examples containing

      - x: [N, 32, 32, 3] uint8 pixels.
      - y: [N] int32 labels in the range [0, 100).
    crop_height: Desired height for cropping images. Must be between 1 and 32.
      Defaults to the value used in TensorFlow Federated.
    crop_width: Desired width for cropping images. Must be between 1 and 32.
      Defaults to the value used in TensorFlow Federated.
    distort: Whether to distort the image via random crops and flips. Defaults
      to center cropping.

  Returns:
    Processed batch of examples containing

    - x: [N, ``crop_height``, ``crop_width``, 3] float32 pixels.
    - y: [N] int32 labels in the range [0, 100).
  """
  return {
      'x':
          preprocess_image_tff(examples['x'], crop_height, crop_width, distort),
      'y':
          examples['y'],
  }