Source code for fedjax.core.sqlite_federated_data

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FederatedData backed by SQLite."""

from typing import Callable, Iterable, Iterator, Optional, Tuple, List
import zlib

from fedjax.core import client_datasets
from fedjax.core import federated_data
from fedjax.core import serialization
import numpy as np
import sqlite3


def decompress_and_deserialize(data: bytes):
  data = zlib.decompress(data)
  return serialization.msgpack_deserialize(data)


[docs]class SQLiteFederatedData(federated_data.FederatedData): """Federated dataset backed by SQLite. The SQLite database should contain a table named "federated_data" created with the following command:: CREATE TABLE federated_data ( client_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, num_examples INTEGER NOT NULL ); where, - `client_id` is the bytes client id. - `data` is the serialized client dataset examples. - `num_examples` is the number of examples in the client dataset. By default we use zlib compressed msgpack blobs for `data` (see decompress_and_deserialize()). """
[docs] @staticmethod def new( path: str, parse_examples: Callable[ [bytes], client_datasets.Examples] = decompress_and_deserialize ) -> 'SQLiteFederatedData': """Opens a federated dataset stored as an SQLite3 database. Args: path: Path to the SQLite database file. parse_examples: Function for deserializing client dataset examples. Returns: SQLite3DataSource. """ connection = sqlite3.connect(path) return SQLiteFederatedData(connection, parse_examples)
[docs] def __init__(self, connection: sqlite3.Connection, parse_examples: Callable[[bytes], client_datasets.Examples], start: Optional[federated_data.ClientId] = None, stop: Optional[federated_data.ClientId] = None, preprocess_client: federated_data .ClientPreprocessor = federated_data.NoOpClientPreprocessor, preprocess_batch: client_datasets .BatchPreprocessor = client_datasets.NoOpBatchPreprocessor): self._connection = connection self._parse_examples = parse_examples self._start = start self._stop = stop self._preprocess_client = preprocess_client self._preprocess_batch = preprocess_batch
def slice( self, start: Optional[federated_data.ClientId] = None, stop: Optional[federated_data.ClientId] = None) -> 'SQLiteFederatedData': start, stop = federated_data.intersect_slice_ranges(self._start, self._stop, start, stop) return SQLiteFederatedData(self._connection, self._parse_examples, start, stop, self._preprocess_client, self._preprocess_batch) def preprocess_client( self, fn: Callable[[federated_data.ClientId, client_datasets.Examples], client_datasets.Examples] ) -> 'SQLiteFederatedData': return SQLiteFederatedData(self._connection, self._parse_examples, self._start, self._stop, self._preprocess_client.append(fn), self._preprocess_batch) def preprocess_batch( self, fn: Callable[[client_datasets.Examples], client_datasets.Examples] ) -> 'SQLiteFederatedData': return SQLiteFederatedData(self._connection, self._parse_examples, self._start, self._stop, self._preprocess_client, self._preprocess_batch.append(fn)) def _range_where(self) -> str: """Builds appropriate WHERE clauses for start/stop ranges.""" if self._start is None and self._stop is None: return '(1)' elif self._start is not None and self._stop is not None: return '(:start <= client_id AND client_id < :stop)' elif self._start is None: return '(client_id < :stop)' else: return '(:start <= client_id)' def num_clients(self) -> int: cursor = self._connection.execute( f'SELECT COUNT(*) FROM federated_data WHERE {self._range_where()};', { 'start': self._start, 'stop': self._stop }) return cursor.fetchone()[0] def client_ids(self) -> Iterator[federated_data.ClientId]: cursor = self._connection.execute( f'SELECT client_id FROM federated_data WHERE {self._range_where()} ORDER BY rowid;', { 'start': self._start, 'stop': self._stop }) while True: result = cursor.fetchone() if result is None: break yield result[0] def client_sizes(self) -> Iterator[Tuple[federated_data.ClientId, int]]: cursor = self._connection.execute( f'SELECT client_id, num_examples FROM federated_data WHERE {self._range_where()} ORDER BY rowid;', { 'start': self._start, 'stop': self._stop }) while True: result = cursor.fetchone() if result is None: break yield tuple(result) def client_size(self, client_id: federated_data.ClientId) -> int: if ((self._start is None or self._start <= client_id) and (self._stop is None or client_id < self._stop)): cursor = self._connection.execute( 'SELECT num_examples FROM federated_data WHERE client_id = ?', [client_id]) result = cursor.fetchone() if result is not None: return result[0] raise KeyError def clients( self ) -> Iterator[Tuple[federated_data.ClientId, client_datasets.ClientDataset]]: for k, v in self._read_clients(): yield k, self._client_dataset(k, v) def _read_clients(self): cursor = self._connection.execute( f'SELECT client_id, data FROM federated_data WHERE {self._range_where()} ORDER BY rowid;', { 'start': self._start, 'stop': self._stop }) while True: result = cursor.fetchone() if result is None: break yield tuple(result) def shuffled_clients( self, buffer_size: int, seed: Optional[int] = None ) -> Iterator[Tuple[federated_data.ClientId, client_datasets.ClientDataset]]: rng = np.random.RandomState(seed) while True: for k, v in client_datasets.buffered_shuffle(self._read_clients(), buffer_size, rng): yield k, self._client_dataset(k, v) def get_clients( self, client_ids: Iterable[federated_data.ClientId] ) -> Iterator[Tuple[federated_data.ClientId, client_datasets.ClientDataset]]: for client_id in client_ids: yield client_id, self.get_client(client_id) def get_client( self, client_id: federated_data.ClientId) -> client_datasets.ClientDataset: if ((self._start is None or self._start <= client_id) and (self._stop is None or client_id < self._stop)): cursor = self._connection.execute( 'SELECT data FROM federated_data WHERE client_id = ?', [client_id]) result = cursor.fetchone() if result is not None: return self._client_dataset(client_id, result[0]) raise KeyError def _client_dataset(self, client_id: federated_data.ClientId, data: bytes) -> client_datasets.ClientDataset: examples = self._preprocess_client(client_id, self._parse_examples(data)) return client_datasets.ClientDataset(examples, self._preprocess_batch)
[docs]class SQLiteFederatedDataBuilder(federated_data.FederatedDataBuilder): """Builds SQLite files from a python dictionary containing an arbitrary mapping of client IDs to NumPy examples."""
[docs] def __init__(self, path: str): """Initializes SQLiteBuilder by opening a connection and setting up the database with columns. Args: path: Path of file to write to (e.g. /tmp/sqlite_federated_data.sqlite). """ self._connection = sqlite3.connect(path) self._connection.execute(""" CREATE TABLE federated_data ( client_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, num_examples INTEGER NOT NULL );""") self._connection.commit()
def __enter__(self): return self def __exit__(self, exc_type, exc_value, exc_traceback): self._connection.close() def add_many(self, client_ids_examples: Iterable[Tuple[bytes, client_datasets.Examples]]): def prepare_parameters(ce): client_id = ce[0] examples = ce[1] num_examples = client_datasets.num_examples(examples, validate=True) data = zlib.compress(serialization.msgpack_serialize(examples)) return client_id, data, num_examples client_ids_datas_num_examples = map(prepare_parameters, client_ids_examples) self._connection.executemany('INSERT INTO federated_data VALUES (?, ?, ?);', client_ids_datas_num_examples) self._connection.commit()
class TFFSQLiteClientsIterator: """Iterator over clients stored in TensorFlow Federated (TFF) SQLite tables. TFF uses a "1 example per record" format, where each row in the SQLite table is a key-value pair from the client id to a protocol buffer message of a single example. A custom `parse_examples` function is thus necessary for different datasets in order to parse different example formats. The TFF. SQLite database should contain two tables named "examples" and "client_metadata" created with the following command:: CREATE TABLE examples ( split_name TEXT NOT NULL, client_id TEXT NOT NULL, serialized_example_proto BLOB NOT NULL); CREATE INDEX idx_examples_client_id ON examples (client_id); CREATE INDEX idx_examples_client_id_split ON examples (split_name, client_id); CREATE TABLE client_metadata ( client_id TEXT NOT NULL, split_name TEXT NOT NULL, num_examples INTEGER NOT NULL); CREATE INDEX idx_metadata_client_id ON client_metadata (client_id); where, - `split_name` is the split name (e.g. "train" or "test"). - `client_id` is the client id. - `serialized_example_proto` is a single serialized `tf.train.Example` protocol buffer message. - `num_examples` is the number of examples in the client dataset. """ def __init__(self, path: str, parse_examples: Callable[[List[bytes]], client_datasets.Examples], split_name: str): self._connection = sqlite3.connect(path) self._parse_examples = parse_examples self._split_name = split_name self._client_ids_cursor = self._connection.execute( 'SELECT client_id FROM client_metadata WHERE split_name = ? ORDER BY client_id;', [split_name]) def __del__(self): self._connection.close() def __iter__( self ) -> Iterator[Tuple[federated_data.ClientId, client_datasets.ClientDataset]]: return self def __next__( self) -> Tuple[federated_data.ClientId, client_datasets.ClientDataset]: client_ids_result = self._client_ids_cursor.fetchone() if client_ids_result is None: raise StopIteration client_id = client_ids_result[0] examples_cursor = self._connection.execute( 'SELECT serialized_example_proto FROM examples WHERE split_name = ? AND client_id = ? ORDER BY rowid;', [self._split_name, client_id]) examples = [r[0] for r in examples_cursor.fetchall()] return client_id, client_datasets.ClientDataset( self._parse_examples(examples))