The Transformer Decoder: Ungraded Lab Notebook
In this notebook, you'll explore the transformer decoder and how to implement it with Trax.
Background
In the last lecture notebook, you saw how to translate the mathematics of attention into NumPy code. Here, you'll see how multi-head causal attention fits into a GPT-2 transformer decoder, and how to build one with Trax layers. In the assignment notebook, you'll implement causal attention from scratch, but here, you'll exploit the handy-dandy tl.CausalAttention() layer.
The schematic below illustrates the components and flow of a transformer decoder. Note that while the algorithm diagram flows from bottom to top, the overview and the Trax layer code that follow are written top-down.
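As a reminder of what "causal" means before handing the work to Trax, here is a minimal single-head sketch in NumPy. It is not the internals of tl.CausalAttention(); the function name causal_attention and the toy shapes are assumptions made for illustration only. The key step is the lower-triangular mask, which stops each position from attending to later positions.

```python
import numpy as np


def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    Illustrative sketch only (hypothetical helper, not a Trax function).
    q, k, v: arrays of shape (seq_len, d_head).
    Returns the attended values and the attention weights.
    """
    seq_len, d_head = q.shape
    scores = q @ k.T / np.sqrt(d_head)

    # Lower-triangular mask: position i may attend to positions 0..i only.
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(mask, scores, -1e9)

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


# Toy input: 4 positions, 8 features, used as queries, keys and values.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
out, weights = causal_attention(x, x, x)
```

Every entry of weights above the main diagonal is zero, so the output at a given position depends only on that position and the ones before it.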
Imports
import sys
import os
import time
import numpy as np
import gin
import textwrap
wrapper = textwrap.TextWrapper(width=70)
import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp
# to print the entire np array
np.set_printoptions(threshold=sys.maxsize)
Sentence gets embedded, add positional encoding
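Embed the words, then create vectors representing each word's position in its sentence. Positions are the integers in range(max_len), that is 0, 1, ..., max_len - 1, where max_len is the maximum sequence length supported by the positional encoding.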
def PositionalEncoder(vocab_size, d_model, dropout, max_len, mode):
    """Returns a list of layers that:
    1. takes a block of text as input,
    2. embeds the words in that text, and
    3. adds positional encoding,
       i.e. associates a number in range(max_len) with
       each word in each sentence of embedded input text

    The input is a list of tokenized blocks of text

    Args:
        vocab_size (int): vocab size.
        d_model (int): depth of embedding.
        dropout (float): dropout rate (how much to drop out).
        max_len (int): maximum symbol length for positional encoding.
        mode (str): 'train' or 'eval'.
    """
    # Embedding inputs and positional encoder
    return [
        # Add embedding layer of dimension (vocab_size, d_model)
        tl.Embedding(vocab_size, d_model),
        # Use dropout with rate and mode specified
        tl.Dropout(rate=dropout, mode=mode),
        # Add positional encoding layer with maximum input length and mode specified
        tl.PositionalEncoding(max_len=max_len, mode=mode)]
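To make the embed-then-add-position idea concrete, here is a rough NumPy sketch of what such a stack computes in 'eval' mode, where dropout is inactive. It assumes the fixed sinusoidal encoding from the original Transformer formulation and an even d_model; the Trax layer may differ in detail. The helper names sinusoidal_positions and embed_with_positions, the random embedding table, and the toy sizes are all hypothetical.

```python
import numpy as np


def sinusoidal_positions(max_len, d_model):
    """Table of shape (max_len, d_model); row p encodes position p.

    Assumes d_model is even. Hypothetical helper, not a Trax function.
    """
    positions = np.arange(max_len)[:, np.newaxis]    # (max_len, 1)
    dims = np.arange(0, d_model, 2)[np.newaxis, :]   # (1, d_model // 2)
    angles = positions / np.power(10000.0, dims / d_model)

    table = np.zeros((max_len, d_model))
    table[:, 0::2] = np.sin(angles)  # even feature indices
    table[:, 1::2] = np.cos(angles)  # odd feature indices
    return table


def embed_with_positions(token_ids, embedding_table, max_len):
    """Look up embeddings and add the encoding for positions 0..seq_len-1."""
    seq_len = token_ids.shape[-1]
    assert seq_len <= max_len, "sequence is longer than max_len"
    d_model = embedding_table.shape[1]
    embedded = embedding_table[token_ids]            # (batch, seq_len, d_model)
    return embedded + sinusoidal_positions(max_len, d_model)[:seq_len]


# Toy setup: vocabulary of 10 tokens, d_model = 4, max_len = 6.
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(10, 4))
token_ids = np.array([[1, 5, 2], [7, 7, 0]])         # batch of 2, length 3
encoded = embed_with_positions(token_ids, embedding_table, max_len=6)
pos_table = sinusoidal_positions(6, 4)
```

The same position vector is added to every sentence in the batch, so a token's final representation depends on both which word it is and where it sits in the sequence.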