# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Federated Shakespeare."""
import functools
from typing import Optional, Tuple
from fedjax.core import client_datasets
from fedjax.core import federated_data
from fedjax.core import sqlite_federated_data
from fedjax.datasets import downloads
import numpy as np
SPLITS = ('train', 'test')
def cite() -> str:
  """Returns BibTeX citation for the dataset.

  Returns:
    A BibTeX `@inproceedings` entry, as a single multi-line string, for
    "Communication-efficient learning of deep networks from decentralized
    data" (McMahan et al., 2017).
  """
  # The entry is returned verbatim; continuation lines are part of the string.
  return """@inproceedings{mcmahan2017communication,
title={Communication-efficient learning of deep networks from
decentralized data},
author={McMahan, Brendan and Moore, Eider and Ramage, Daniel and
Hampson, Seth and y Arcas, Blaise Aguera},
booktitle={Artificial Intelligence and Statistics},
pages={1273--1282},
year={2017},
organization={PMLR}
}"""
def load_split(split: str,
               mode: str = 'sqlite',
               cache_dir: Optional[str] = None) -> federated_data.FederatedData:
  """Loads a shakespeare split.

  Features:

  - snippets: [N] bytes array of snippet text.

  Args:
    split: Name of the split. One of SPLITS.
    mode: 'sqlite'. No other mode is supported.
    cache_dir: Directory to cache files in 'sqlite' mode.

  Returns:
    FederatedData.

  Raises:
    ValueError: If `split` is not one of SPLITS, if `cache_dir` is given
      together with a mode other than 'sqlite', or if `mode` is unsupported.
  """
  if split not in SPLITS:
    raise ValueError(f'Invalid split={split!r}')
  if mode != 'sqlite':
    # A requested cache_dir gets the more specific message, as it is checked
    # before the generic unsupported-mode error.
    if cache_dir is not None:
      raise ValueError('Caching locally is only supported in "sqlite" mode')
    raise ValueError(f'Unsupported mode={mode!r}')
  # The path returned by maybe_download is opened as SQLite-backed data.
  path = downloads.maybe_download(
      f'https://storage.googleapis.com/gresearch/fedjax/shakespeare/shakespeare_{split}.sqlite',
      cache_dir)
  return sqlite_federated_data.SQLiteFederatedData.new(path)
def load_data(
    sequence_length: int = 80,
    mode: str = 'sqlite',
    cache_dir: Optional[str] = None
) -> Tuple[federated_data.FederatedData, federated_data.FederatedData]:
  """Loads preprocessed shakespeare splits.

  Preprocessing is done using :meth:`fedjax.FederatedData.preprocess_client`
  and :meth:`preprocess_client`.

  Features (M below is possibly different from N in load_split):

  - x: [M, sequence_length] int32 input labels, in the range of [0,
    shakespeare.VOCAB_SIZE)
  - y: [M, sequence_length] int32 output labels, in the range of [0,
    shakespeare.VOCAB_SIZE)

  Args:
    sequence_length: The fixed sequence length after preprocessing.
    mode: 'sqlite'.
    cache_dir: Directory to cache files in 'sqlite' mode.

  Returns:
    A (train, test) tuple of federated data.
  """
  train = load_split('train', mode, cache_dir)
  test = load_split('test', mode, cache_dir)
  # Bind sequence_length so the same per-client preprocessing function is
  # applied to both splits.
  preprocess = functools.partial(
      preprocess_client, sequence_length=sequence_length)
  return (train.preprocess_client(preprocess),
          test.preprocess_client(preprocess))
def _build_look_up_table(vocab: bytes,
                         num_reserved: int) -> Tuple[np.ndarray, int]:
  """Creates the byte -> integer label look-up table.

  Args:
    vocab: bytes object whose byte values make up the vocabulary. The byte at
      position i receives label `num_reserved + i`; when a byte value appears
      more than once, its last position determines the label.
    num_reserved: How many labels at the start of the label domain are kept
      free. No byte of `vocab` is mapped into [0, num_reserved).

  Returns:
    (table, vocab_size) tuple. `table` is a [256] int32 ndarray holding the
    label of every possible byte value. Byte values absent from `vocab` all
    share the label `vocab_size - 1`.
  """
  vocab_size = num_reserved + len(vocab) + 1
  out_of_vocab = vocab_size - 1
  # Start with every byte value out-of-vocabulary, then overwrite known ones.
  table = np.full([256], out_of_vocab, dtype=np.int32)
  # Iterating a bytes object yields ints; later duplicates overwrite earlier.
  for label, byte_value in enumerate(vocab, start=num_reserved):
    table[byte_value] = label
  return table, vocab_size
# Reserved labels. They occupy the first three integer labels, which is why
# the look-up table below is built with num_reserved=3.
PAD = 0
BOS = 1
EOS = 2

# Byte vocabulary taken from the Federated Learning for Text Generation
# tutorial:
# https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
# The order of the bytes matters: it determines the label of each byte.
TABLE, VOCAB_SIZE = _build_look_up_table(
    b'dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r',
    num_reserved=3)

# Label used for out-of-vocabulary bytes (the last label).
OOV = VOCAB_SIZE - 1
def preprocess_client(client_id: federated_data.ClientId,
                      examples: client_datasets.Examples,
                      sequence_length: int) -> client_datasets.Examples:
  """Turns snippets into sequences of integer labels.

  Features (M below is possibly different from N in load_split):

  - x: [M, sequence_length] int32 input labels, in the range of [0,
    shakespeare.VOCAB_SIZE)
  - y: [M, sequence_length] int32 output labels, in the range of [0,
    shakespeare.VOCAB_SIZE)

  All snippets in a client dataset are first joined into a single sequence (with
  BOS/EOS added), and then split into pairs of `sequence_length` chunks for
  language model training. For example, with sequence_length=3,
  `[b'ABCD', b'E']` becomes::

    Input sequences:  [[BOS, A, B], [C, D, EOS], [BOS, E, PAD]]
    Output sequences: [[A, B, C], [D, EOS, BOS], [E, EOS, PAD]]

  A client without any snippets yields empty `[0, sequence_length]` arrays.

  Note: This is not equivalent to the `TensorFlow Federated text generation tutorial <https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation#load_and_preprocess_the_federated_shakespeare_data>`_
  (The processing logic there loses ~1/sequence_length portion of the tokens).

  Args:
    client_id: Not used.
    examples: Unprocessed examples (e.g. from `load_split()`).
    sequence_length: The fixed sequence length after preprocessing. Must be
      positive.

  Returns:
    Processed examples.

  Raises:
    ValueError: If `sequence_length` is not positive.
  """
  del client_id
  if sequence_length < 1:
    raise ValueError(
        f'sequence_length must be positive, got {sequence_length!r}')
  snippets = examples['snippets']
  # Join all snippets into a single label sequence. Each snippet contributes
  # its own bytes plus one BOS and one EOS label.
  joined_length = sum(len(snippet) + 2 for snippet in snippets)
  joined = np.zeros([joined_length], dtype=np.int32)
  offset = 0
  for snippet in snippets:
    end = offset + 1 + len(snippet)
    joined[offset] = BOS
    # Iterating bytes yields ints, which index directly into TABLE.
    joined[offset + 1:end] = TABLE[list(snippet)]
    joined[end] = EOS
    offset = end + 1
  # Inputs and outputs are the joined sequence shifted by one position, so each
  # holds one label fewer than `joined`. Clamp at 0 so that a client without
  # snippets does not produce a negative padded length.
  num_labels = max(joined_length - 1, 0)
  # Split into input/output sequences of size `sequence_length`: round up to a
  # whole number of chunks and fill the tail with PAD.
  padded_length = ((num_labels + sequence_length - 1) //
                   sequence_length * sequence_length)
  input_labels = np.full([padded_length], PAD, dtype=np.int32)
  input_labels[:num_labels] = joined[:num_labels]
  output_labels = np.full([padded_length], PAD, dtype=np.int32)
  output_labels[:num_labels] = joined[1:]
  return {
      'x': input_labels.reshape([-1, sequence_length]),
      'y': output_labels.reshape([-1, sequence_length])
  }