Source code for fedjax.models.stackoverflow

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stack Overflow recurrent models."""

from typing import Optional

from fedjax.core import metrics
from fedjax.core import models

import haiku as hk
import jax.numpy as jnp


def create_lstm_model(vocab_size: int = 10000,
                      embed_size: int = 96,
                      lstm_hidden_size: int = 670,
                      lstm_num_layers: int = 1,
                      share_input_output_embeddings: bool = False,
                      expected_length: Optional[float] = None) -> models.Model:
  """Creates LSTM language model.

  Word-level language model for Stack Overflow. Defaults to the model used in:

  Adaptive Federated Optimization
    Sashank Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush,
    Jakub Konečný, Sanjiv Kumar, H. Brendan McMahan.
    https://arxiv.org/abs/2003.00295

  Args:
    vocab_size: The number of possible output words. This does not include
      special tokens like PAD, BOS, EOS, or OOV.
    embed_size: Embedding size for each word.
    lstm_hidden_size: Hidden size for LSTM cells.
    lstm_num_layers: Number of LSTM layers.
    share_input_output_embeddings: Whether to share the input embeddings with
      the output logits.
    expected_length: Expected average sentence length used to scale the
      training loss down by `1. / expected_length`. This constant term is used
      so that the total loss over all the words in a sentence can be scaled
      down to per-word cross entropy values by a constant factor instead of
      dividing by the number of words, which can vary across batches. Defaults
      to no scaling.

  Returns:
    Model.
  """
  # TODO(jaero): Replace these with direct references from dataset.
  pad = 0
  bos = 1
  eos = 2
  oov = vocab_size + 3
  full_vocab_size = vocab_size + 4
  # We do not guess EOS, and if we guess OOV, it's treated as a mistake.
  logits_mask = [0. for _ in range(full_vocab_size)]
  for i in (pad, bos, eos, oov):
    logits_mask[i] = -jnp.inf
  logits_mask = tuple(logits_mask)

  def forward_pass(batch):
    x = batch['x']
    # [time_steps, batch_size, ...].
    x = jnp.transpose(x)
    # [time_steps, batch_size, embed_dim].
    embedding_layer = hk.Embed(full_vocab_size, embed_size)
    embeddings = embedding_layer(x)

    lstm_layers = []
    for _ in range(lstm_num_layers):
      lstm_layers.extend([
          hk.LSTM(hidden_size=lstm_hidden_size),
          jnp.tanh,
          # Projection changes dimension from lstm_hidden_size to embed_size.
          hk.Linear(embed_size)
      ])
    rnn_core = hk.DeepRNN(lstm_layers)
    initial_state = rnn_core.initial_state(batch_size=embeddings.shape[1])
    # [time_steps, batch_size, hidden_size].
    output, _ = hk.static_unroll(rnn_core, embeddings, initial_state)

    if share_input_output_embeddings:
      output = jnp.dot(output, jnp.transpose(embedding_layer.embeddings))
      output = hk.Bias(bias_dims=[-1])(output)
    else:
      output = hk.Linear(full_vocab_size)(output)
    # [batch_size, time_steps, full_vocab_size].
    output = jnp.transpose(output, axes=(1, 0, 2))
    return output

  def train_loss(batch, preds):
    """Returns total loss per sentence optionally scaled down to token level."""
    targets = batch['y']
    per_token_loss = metrics.unreduced_cross_entropy_loss(targets, preds)
    # Don't count padded values in loss.
    per_token_loss *= targets != pad
    sentence_loss = jnp.sum(per_token_loss, axis=-1)
    if expected_length is not None:
      return sentence_loss * (1. / expected_length)
    return sentence_loss

  transformed_forward_pass = hk.transform(forward_pass)
  return models.create_model_from_haiku(
      transformed_forward_pass=transformed_forward_pass,
      sample_batch={
          'x': jnp.zeros((1, 1), dtype=jnp.int32),
          'y': jnp.zeros((1, 1), dtype=jnp.int32),
      },
      train_loss=train_loss,
      eval_metrics={
          'accuracy_in_vocab':
              metrics.SequenceTokenAccuracy(
                  masked_target_values=(pad, eos), logits_mask=logits_mask),
          'accuracy_no_eos':
              metrics.SequenceTokenAccuracy(masked_target_values=(pad, eos)),
          'num_tokens':
              metrics.SequenceTokenCount(masked_target_values=(pad,)),
          'sequence_length':
              metrics.SequenceLength(masked_target_values=(pad,)),
          'sequence_loss':
              metrics.SequenceCrossEntropyLoss(masked_target_values=(pad,)),
          'token_loss':
              metrics.SequenceTokenCrossEntropyLoss(
                  masked_target_values=(pad,)),
          'token_oov_rate':
              metrics.SequenceTokenOOVRate(
                  oov_target_values=(oov,), masked_target_values=(pad,)),
          'truncation_rate':
              metrics.SequenceTruncationRate(
                  eos_target_value=eos, masked_target_values=(pad,)),
      })
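
For reference, the sketch below shows one way to instantiate and exercise this model outside of federated training. It is not part of the module source: it assumes the `fedjax.core.models.Model` interface returned by `create_model_from_haiku` (`init`, `apply_for_train`, `train_loss`) and uses a tiny made-up batch of token ids purely for illustration.

# Usage sketch (not part of fedjax.models.stackoverflow); assumes the fedjax
# Model API (init / apply_for_train / train_loss) and an illustrative batch.
import jax
import jax.numpy as jnp

from fedjax.models import stackoverflow

# Build the model with its default hyperparameters (vocab_size=10000, etc.).
model = stackoverflow.create_lstm_model()

# Initialize Haiku parameters from a PRNG key.
rng = jax.random.PRNGKey(0)
params = model.init(rng)

# Tiny illustrative batch of two padded sentences of token ids:
# 'x' holds input tokens, 'y' the next-token targets; 0 is PAD, 1 BOS, 2 EOS.
batch = {
    'x': jnp.array([[1, 5, 6, 7, 8], [1, 9, 10, 2, 0]], dtype=jnp.int32),
    'y': jnp.array([[5, 6, 7, 8, 2], [9, 10, 2, 0, 0]], dtype=jnp.int32),
}

# Forward pass, then the per-sentence training loss defined above.
preds = model.apply_for_train(params, batch, rng)
per_sentence_loss = model.train_loss(batch, preds)
print(per_sentence_loss.shape)  # (2,)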