
In Course 1, you implemented logistic regression and naive Bayes for sentiment analysis. However, if you were to give your old models an example like:

This movie was almost good.

Your old models would have predicted a positive sentiment for that review. However, the sentence actually has a negative sentiment: it indicates that the movie was not good. To fix this kind of misclassification, you will write a program that uses deep neural networks to identify sentiment in text. By completing this assignment, you will:

  • Understand how to build and design a model using layers
  • Train a model using a training loop
  • Use a binary cross-entropy loss function
  • Compute the accuracy of your model
  • Predict using your own input
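The loss and the metric named in the list above can be sketched in plain NumPy. This is a minimal illustration, assuming 0/1 labels and predicted probabilities for the positive class. The helper names `binary_cross_entropy` and `accuracy` are hypothetical, and the assignment itself may compute these quantities differently through Trax layers. Binary cross-entropy averages -[y log p + (1 - y) log(1 - p)] over the N examples, and accuracy is the fraction of thresholded predictions that match the labels.

```python
import numpy as np

def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip probabilities away from exactly 0 and 1 so log() stays finite
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1 - eps)
    losses = -(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
    return losses.mean()

def accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of examples whose thresholded prediction matches the label."""
    y_pred = (np.asarray(y_prob) > threshold).astype(int)
    return (y_pred == np.asarray(y_true)).mean()

# Hypothetical labels and model outputs, for illustration only
labels = [1, 0, 1, 0]
probs = [0.9, 0.2, 0.4, 0.1]

print(binary_cross_entropy(labels, probs))
print(accuracy(labels, probs))  # 3 of 4 predictions are correct
```

The third example (label 1, probability 0.4) is both misclassified and responsible for most of the loss, which is why training drives the loss down rather than optimizing accuracy directly: the loss still rewards moving 0.4 toward 1 even before the prediction flips.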

As you can tell, this model follows a similar structure to the one you previously implemented in the second course of this specialization.

  • Indeed, most of the deep nets you will implement have a similar structure; only the model architecture, the inputs, and the outputs change. Before starting the assignment, we will introduce you to Trax, the Google library we use for building and training models.

We will also show you how to compute the gradient of a function f simply by using .grad(f).

  • Trax source code can be found on GitHub: Trax
  • The Trax code also uses the JAX library: JAX
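To illustrate what computing the gradient of f with .grad(f) means, here is a minimal sketch that calls `jax.grad` directly, assuming JAX is installed. Trax's `trax.fastmath.grad` dispatches to the backend's gradient function, which with the default JAX backend is `jax.grad`. The function f below is an illustrative choice, not taken from the assignment.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = x^2 + 3x, so the analytic derivative is df/dx = 2x + 3
    return x ** 2 + 3 * x

# jax.grad returns a new function that evaluates df/dx at a point.
# The input must be a float; jax.grad rejects integer inputs.
grad_f = jax.grad(f)

x = jnp.array(2.0)
print(f(x))       # f(2) = 4 + 6 = 10
print(grad_f(x))  # df/dx at 2 = 2*2 + 3 = 7
```

Because `grad_f` is itself an ordinary function, it can be passed to `jax.grad` again to obtain higher derivatives; this composability is what training loops rely on when differentiating a loss with respect to model weights.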

Part 1: Import libraries and try out Trax

  • Let's import libraries and look at an example of using the Trax library.
import os 
import random as rnd

# import relevant libraries
import trax

# set random seeds to make this notebook easier to replicate
trax.supervised.trainer_lib.init_random_number_generators(31)

# import trax.fastmath.numpy
import trax.fastmath.numpy as np

# import trax.layers
from trax import layers as tl

# import Layer and the tweet-processing helpers from the course-provided utils.py file
from utils import Layer, load_tweets, process_tweet
# Create an array using trax.fastmath.numpy
a = np.array(5.0)

# View the returned array
display(a)

print(type(a))
DeviceArray(5., dtype=float32)
<class 'jax.interpreters.xla.DeviceArray'>
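Under the default JAX backend, `trax.fastmath.numpy` is backed by `jax.numpy`, so the array above behaves much like a NumPy array. The sketch below assumes JAX and NumPy are installed and uses `jax.numpy` directly rather than Trax. Note that the class name printed above comes from an older JAX release; recent JAX versions report a different array class.

```python
import numpy as onp        # regular NumPy, aliased so it doesn't clash with the notebook's `np`
import jax.numpy as jnp

a = jnp.array(5.0)                 # 0-d array, like `a` above
b = jnp.array([1.0, 2.0, 3.0])     # 1-d array

# Arithmetic and broadcasting follow NumPy semantics
c = a * b
print(c)
print(c.dtype)   # float32 by default, since JAX disables 64-bit types unless configured

# Convert to a standard NumPy array when another library expects one
c_np = onp.asarray(c)
print(type(c_np))
```

Converting with `numpy.asarray` is useful at the boundary between Trax/JAX code and libraries such as plotting or evaluation tools that expect plain `numpy.ndarray` inputs.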