
Reformer Efficient Attention: Ungraded Lab

The videos describe two 'reforms' made to the Transformer to make it more memory and compute efficient. Reversible layers reduce memory use, and Locality Sensitive Hashing (LSH) reduces the cost of dot-product attention for large input sizes. This ungraded lab looks more closely at LSH and how it is used in the Reformer model.

Specifically, the notebook has three goals:

  • review dot-product self-attention for reference
  • examine LSH-based self-attention
  • extend our understanding of and familiarity with the Trax infrastructure

Outline

Part 1.0 Trax Efficient Attention classes

Trax is similar to other popular NN development platforms such as Keras (now integrated into TensorFlow) and PyTorch in that it uses 'layers' as a useful level of abstraction. Layers are often represented as classes. We're going to improve our understanding of Trax by locally extending the classes used in the attention layers. We will extend only the 'forward' functions and utilize the existing attention layers as parent classes. The original code can be found at github:trax/layers/Research/Efficient_attention. This link references release 1.3.4; note that the code lives under the 'research' directory because this is an area of active research. When accessing the code on GitHub for review in this assignment, be sure you select the 1.3.4 release tag, as the master copy may have newer changes:

Figure 1: Reference Tag 1.3.4 on github

While Trax uses classes liberally, we have not built many classes in the course so far. Let's spend a few moments reviewing the classes we will be using.

Figure 2: Classes from Trax/layers/Research/Efficient_Attention.py that we will be utilizing.

Starting on the right in the diagram above (Figure 2), you see EfficientAttentionBase. The parent of this class is base.Layer, which has the routines used by all layers. EfficientAttentionBase leaves many routines to be overridden by child classes, but it has an important feature in its forward routine: it supports a use_reference_code capability that selects implementations which limit some of the complexities to provide a more easily understood version of the algorithms. In particular, it implements a nested loop that treats each (example, head) pair independently. This simplifies our work, as we need only worry about matrix operations on one (example, head) at a time. This loop calls forward_unbatched, which is the child-class method we will be overriding.
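To make this concrete, here is a small conceptual sketch (ours, not the Trax source) of what such a reference loop looks like; for illustration we simply treat contiguous feature slices as 'heads' and use plain numpy:

import numpy as onp  # plain numpy is enough for this illustration

def forward_unbatched_sketch(x_single_head):
    # stand-in for the per-(example, head) attention computation
    return x_single_head

def reference_forward_sketch(x, n_heads):
    # x: (batch, seqlen, n_heads * d_head); process each (example, head) separately
    batch, seqlen, d_total = x.shape
    d_head = d_total // n_heads
    out = onp.zeros_like(x)
    for example in range(batch):        # loop over examples
        for head in range(n_heads):     # loop over heads
            lo, hi = head * d_head, (head + 1) * d_head
            out[example, :, lo:hi] = forward_unbatched_sketch(x[example, :, lo:hi])
    return out

x = onp.ones((2, 6, 8))                                # (batch, seqlen, n_heads * d_head)
print(reference_forward_sketch(x, n_heads=2).shape)    # (2, 6, 8)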

On the top left are the outlines of the two child classes we will be using. The SelfAttention layer is a 'traditional' implementation of dot-product attention. We will be implementing the forward_unbatched version of this to highlight the differences between it and the LSH implementation.
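As a reminder of the math we will be re-implementing, here is a minimal sketch of causal dot-product self-attention for a single (example, head). The function name and shapes are our own, and this is not the exact Trax implementation:

from trax.fastmath import numpy as np  # fastmath numpy, as used later in this notebook

def dot_product_self_attention_sketch(q, k, v):
    """q, k, v: [seqlen, d_head] arrays for one (example, head)."""
    seqlen, d_head = q.shape
    dots = np.matmul(q, k.T) / np.sqrt(float(d_head))       # [seqlen, seqlen] scores
    mask = np.tril(np.ones((seqlen, seqlen))) > 0.5          # causal: attend to self and past only
    dots = np.where(mask, dots, np.full_like(dots, -1e9))    # mask out future positions
    dots = np.exp(dots - np.max(dots, axis=-1, keepdims=True))
    weights = dots / np.sum(dots, axis=-1, keepdims=True)    # softmax over the key axis
    return np.matmul(weights, v)                             # [seqlen, d_head]

q = k = v = np.eye(4)                                        # tiny toy input
print(dot_product_self_attention_sketch(q, k, v).shape)      # (4, 4)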

Below that is LSHSelfAttention. This is the layer used in the Reformer architecture. We will override its forward_unbatched section and some of the utility functions it uses to explore its implementation in more detail.
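As a preview of what those utility functions compute, here is a small sketch of the angular LSH hashing step that assigns positions to buckets (random rotations, then argmax over the rotations and their negations). The function name, shapes, and RNG handling here are our own simplifications, not the Trax API:

from trax.fastmath import numpy as np
import numpy as onp

def hash_vectors_sketch(vecs, n_buckets, rng_seed=0):
    """vecs: [seqlen, d_head]; returns one bucket id per position."""
    rng = onp.random.default_rng(rng_seed)
    # random rotation matrix: [d_head, n_buckets // 2]
    rotations = np.array(rng.standard_normal((vecs.shape[-1], n_buckets // 2)))
    rotated = np.matmul(vecs, rotations)                     # [seqlen, n_buckets // 2]
    rotated = np.concatenate([rotated, -rotated], axis=-1)   # [seqlen, n_buckets]
    return np.argmax(rotated, axis=-1)                       # bucket id per position

print(hash_vectors_sketch(np.eye(8), n_buckets=4))           # one bucket id in [0, 4) per row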

The code we will be working with is drawn from the Trax source, and as such has implementation details that make it a bit harder to follow. However, this allows us to use the results alongside the rest of the Trax infrastructure. I will try to briefly describe these details as they arise. The Trax documentation can also be referenced.

Part 1.2 Trax Details

The goal in this notebook is to override a few routines in the Trax classes with our own versions. To maintain their functionality in a full Trax environment, many of the details we might otherwise ignore in simplified example routines will be maintained in this code. Here are some of the considerations that may impact our code:

  • Trax operates with multiple backend libraries, so we will see special cases that utilize backend-specific features.
  • 'Fancy' numpy indexing is not supported in all backend environments and must be emulated in other ways (see the short sketch after this list).
  • Some operations don't have gradients for backprop and must either be excluded from the gradient calculation or have their re-evaluation forced.
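For example, here is a tiny illustration (our own) of replacing fancy indexing with an explicit gather via np.take; whether a particular backend actually requires this is an implementation detail, this just shows the shape of the workaround:

from trax.fastmath import numpy as np

x = np.arange(10) * 10                 # [0, 10, 20, ..., 90]
idx = np.array([3, 1, 4])
gathered = np.take(x, idx, axis=0)     # explicit gather: [30, 10, 40]
# equivalent in intent to the 'fancy indexing' form x[idx],
# which not every backend supports directly
print(gathered)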

Here are some of the functions we may see:

  • Abstracted as fastmath, Trax supports multiple backends such as JAX and TensorFlow 2.
  • tie_in: some non-numeric operations must be invoked during backpropagation. Normally the gradient compute graph would determine invocation, but these functions are not included in it. To force re-evaluation, they are 'tied' to other numeric operations using tie_in.
  • stop_gradient: some operations are intentionally excluded from backprop gradient calculations by setting their gradients to zero.
  • Below we will execute from trax.fastmath import numpy as np, which provides accelerated forms of numpy functions. This is, however, a subset of numpy.
import os
import trax
from trax import layers as tl  # core building block
import jax
from trax import fastmath  # uses jax, offers numpy on steroids

# fastmath.use_backend('tensorflow-numpy')
import functools
from trax.fastmath import numpy as np  # note, using fastmath subset of numpy!
from trax.layers import (
    tie_in,
    length_normalized,
    apply_broadcasted_dropout,
    look_adjacent,
    permute_via_gather,
    permute_via_sort,
)
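
As a quick sanity check of the imports above, here is a small usage example (ours, not part of the lab) of the fastmath numpy subset together with stop_gradient; we assume fastmath.grad and fastmath.stop_gradient are available in this Trax version:

def loss_sketch(w):
    # stop_gradient makes the second term a constant as far as backprop is concerned
    target = fastmath.stop_gradient(w * 2.0)
    return np.sum((w - target) ** 2)

grad_fn = fastmath.grad(loss_sketch)
print(grad_fn(np.array(3.0)))          # gradient flows only through the first `w`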