
Putting the "Re" in Reformer: Ungraded Lab

This ungraded lab explores Reversible Residual Networks. You will use these networks in this week's assignment, which uses the Reformer model. The Reformer is based on the Transformer model you already know, but adds two unique features:

  • Locality Sensitive Hashing (LSH) Attention, which reduces the compute cost of dot-product attention, and
  • Reversible Residual Network (RevNet) layers, which reduce the memory needed to store activations for backpropagation during training.
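The memory saving from reversible layers comes from one property: a block's inputs can be recomputed exactly from its outputs, so they do not have to be stored for the backward pass. Below is a minimal NumPy sketch, assuming the standard RevNet coupling of two functions F and G over an input split into halves. The choices F = tanh and G = sin and the helper names `rev_forward` and `rev_inverse` are hypothetical, for illustration only.

```python
import numpy as np

# Hypothetical stand-ins for the two sub-layers of a reversible block.
def F(x):
    return np.tanh(x)

def G(x):
    return np.sin(x)

def rev_forward(x1, x2):
    """Reversible residual block: the input arrives as two halves (x1, x2)."""
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def rev_inverse(y1, y2):
    """Recover the inputs from the outputs by undoing the two steps in reverse order."""
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2

rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=4), rng.normal(size=4)
y1, y2 = rev_forward(x1, x2)
r1, r2 = rev_inverse(y1, y2)
print(np.allclose(r1, x1), np.allclose(r2, x2))  # True True
```

F and G are only ever evaluated, never inverted, so in principle they can be arbitrary layers. In the Reformer paper, F is the attention layer and G is the feed-forward layer.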

In this ungraded lab, we'll start with a quick review of Residual Networks and their implementation in Trax. Then we'll discuss the RevNet architecture and its use in the Reformer.


import trax
from trax import layers as tl               # core building block
import numpy as np                          # regular ol' numpy
from trax.models.reformer.reformer import (
    ReversibleHalfResidualV2 as ReversibleHalfResidual,
)                                           # Reformer's reversible half-residual layer
from trax import fastmath                   # uses jax, offers numpy on steroids
from trax import shapes                     # data signatures: dimensionality and type
from trax.fastmath import numpy as jnp      # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature

Part 1.0 Residual Networks

Deep Residual Networks (Resnets) were introduced to improve convergence in deep networks. Residual Networks introduce a shortcut connection around one or more layers in a deep network as shown in the diagram below from the original paper.

Figure 1: Residual Network diagram from original paper
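In equation form, a residual block computes y = x + F(x). The shortcut carries x around the wrapped layer(s) F, and the two paths are added. Here is a minimal NumPy sketch, where `residual_block` and the inner layers `scale` and `zero` are hypothetical helpers for illustration:

```python
import numpy as np

def residual_block(x, F):
    """Shortcut connection: add the block's input to the output of F."""
    return x + F(x)

# Hypothetical inner layers for illustration.
scale = lambda x: 0.5 * x
zero = lambda x: np.zeros_like(x)

x = np.array([1.0, 2.0, 3.0])
print(residual_block(x, scale))  # [1.5 3.  4.5]
print(residual_block(x, zero))   # [1. 2. 3.]
```

If F outputs zeros, the block reduces to the identity. This is one common intuition for why very deep residual networks remain easy to optimize.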

The Trax documentation describes an implementation of Resnets using branch. We'll explore that here by implementing a simple resnet built from simple function-based layers. Specifically, we'll build a 4-layer network based on two functions, 'F' and 'G'.

Figure 2: 4-stage residual network
Don't worry about the lengthy equations. Those are simply there to be referenced later in the notebook.
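To see how branch fits into a residual network, consider a single residual stage. It copies its input down two paths, an identity path and a path through the layer, and then adds the results. The sketch below mimics this in plain NumPy and stacks four such stages. It assumes the stages alternate between F and G. The exact wiring in Figure 2 may differ. `branch`, `add`, and `residual_stage` are hypothetical helpers, not Trax functions.

```python
import numpy as np

def branch(x, *layers):
    """Send a copy of x through each layer; return all outputs as a tuple."""
    return tuple(layer(np.copy(x)) for layer in layers)

def add(outputs):
    """Sum the outputs of the two parallel paths."""
    a, b = outputs
    return a + b

identity = lambda x: x

def residual_stage(x, layer):
    # Branch into an identity path and a layer path, then add: x + layer(x).
    return add(branch(x, identity, layer))

# Hypothetical stand-ins for the functions F and G.
F = lambda x: x + 1
G = lambda x: 2 * x

x = np.array([1.0])
for layer in (F, G, F, G):  # 4 stages, alternating F and G (assumed wiring)
    x = residual_stage(x, layer)
print(x)  # [57.]
```

Tracing the stages by hand gives 1 → 3 → 9 → 19 → 57. Each stage adds the layer's output to the running value instead of replacing it.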

Part 1.1 Branch

Trax's Branch layer figures prominently in the residual network, so we will examine it first. As the figure above shows, we need a function that copies an input and sends it down multiple paths. This is done with a branch layer, one of the Trax 'combinators': Branch applies a list of layers in parallel to copies of its inputs. Let's try it out! First we need some layers to play with, so let's build some from functions.

# simple functions, each taking one input and producing one output
bl_add1 = tl.Fn("add1", lambda x0: (x0 + 1), n_out=1)
bl_add2 = tl.Fn("add2", lambda x0: (x0 + 2), n_out=1)
bl_add3 = tl.Fn("add3", lambda x0: (x0 + 3), n_out=1)
# try them out
x = np.array([1])
print(bl_add1(x), bl_add2(x), bl_add3(x))
# some information about our new layers
print(
    "name:",
    bl_add1.name,
    "number of inputs:",
    bl_add1.n_in,
    "number of outputs:",
    bl_add1.n_out,
)
[2] [3] [4]
name: add1 number of inputs: 1 number of outputs: 1
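Trax expresses this combination with `tl.Branch(bl_add1, bl_add2, bl_add3)`. To make the expected behavior concrete, here is a plain-Python mimic that also tracks input and output counts. `FnLayer` and `BranchLayer` are hypothetical classes for illustration, not the Trax implementation. Because every sublayer here reads the same single input, the mimic fixes `n_in` at 1. It sets `n_out` to the total number of sublayer outputs.

```python
import numpy as np

class FnLayer:
    """Hypothetical stand-in for a one-input, one-output function layer."""
    def __init__(self, name, fn):
        self.name, self.fn = name, fn
        self.n_in, self.n_out = 1, 1

    def __call__(self, x):
        return self.fn(x)

class BranchLayer:
    """Hypothetical branch mimic: every sublayer receives its own copy of the same input."""
    def __init__(self, *layers):
        self.layers = layers
        self.n_in = 1  # all sublayers here read the same single input
        self.n_out = sum(layer.n_out for layer in layers)

    def __call__(self, x):
        return tuple(layer(np.copy(x)) for layer in self.layers)

add1 = FnLayer("add1", lambda x: x + 1)
add2 = FnLayer("add2", lambda x: x + 2)
add3 = FnLayer("add3", lambda x: x + 3)

b = BranchLayer(add1, add2, add3)
print(b(np.array([1])))  # (array([2]), array([3]), array([4]))
print(b.n_in, b.n_out)   # 1 3
```

One input fans out into three independent outputs. This copy-and-fan-out step is what a residual connection needs to keep an untouched copy of its input alongside the transformed one.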