
Working with JAX numpy and calculating perplexity: Ungraded Lecture Notebook

Normally you would import numpy and rename it as np.

However, in this week's assignment you will notice that this convention has changed.

Standard numpy is now imported without an alias, and trax.fastmath.numpy is imported as np.

The reason for this change is that you will be using Trax's numpy (which is compatible with JAX) far more often. Trax's numpy supports most of the same functions as regular numpy, so in most cases you won't notice the change.

import numpy
import trax
import trax.fastmath.numpy as np

# Setting random seeds
trax.supervised.trainer_lib.init_random_number_generators(32)
numpy.random.seed(32)
INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0
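Under its default JAX backend, Trax's numpy forwards calls to jax.numpy, so the drop-in behaviour can be sketched with jax.numpy directly. This is a minimal illustration, not assignment code. It assumes JAX is installed and imports jax.numpy as np to mirror this week's convention, with regular numpy left unaliased.

```python
import numpy
import jax.numpy as np  # stand-in for trax.fastmath.numpy (assumes the JAX backend)

data = [[1.0, 2.0], [3.0, 4.0]]

# The same function names and arguments work in both modules.
regular_mean = numpy.mean(numpy.array(data), axis=0)
jax_mean = np.mean(np.array(data), axis=0)

print(regular_mean)  # [2. 3.]
print(jax_mean)      # same values, printed from a JAX array
```

Only the module prefix changes. The computation (a column-wise mean) is identical in both cases.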

One important difference to keep in mind is that the resulting objects have different types depending on which numpy you use. With regular numpy you get a numpy.ndarray, but with Trax's numpy you get a jax.interpreters.xla.DeviceArray. The two types support largely the same operations, and either one can be converted into the other. So if you see error logs that mention the DeviceArray type, don't worry: treat it as you would an ndarray and march ahead.
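The type difference and the conversions between the two array types can be sketched as follows, again assuming JAX is installed and using np for jax.numpy as in this notebook. The exact JAX type name depends on the installed version: older releases report DeviceArray, and newer ones report a jax.Array implementation class.

```python
import numpy
import jax.numpy as np  # stand-in for trax.fastmath.numpy (assumes the JAX backend)

regular = numpy.arange(3.0)  # numpy.ndarray, float64 by default
jax_array = np.arange(3.0)   # JAX array, float32 unless 64-bit mode is enabled

print(type(regular))    # <class 'numpy.ndarray'>
print(type(jax_array))  # name varies with the JAX version

# Converting in either direction takes a single call.
back_to_numpy = numpy.asarray(jax_array)
to_jax = np.asarray(regular)

print(type(back_to_numpy))                     # <class 'numpy.ndarray'>
print(numpy.allclose(back_to_numpy, regular))  # True
```

Note the default dtypes: JAX uses 32-bit floats unless told otherwise, so values that round-trip through a JAX array may lose some precision.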

You can get a randomized numpy array by using the numpy.random.random() function.

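This is one of the features that Trax's numpy does not currently support in the same way as regular numpy, because JAX draws random numbers from explicit keys rather than from a global random state. For that reason, the cell below uses the regular numpy module.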

numpy_array = numpy.random.random((5,10))
print(f"The regular numpy array looks like this:\n\n {numpy_array}\n")
print(f"It is of type: {type(numpy_array)}")
The regular numpy array looks like this:

 [[0.85888927 0.37271115 0.55512878 0.95565655 0.7366696  0.81620514 0.10108656 0.92848807 0.60910917 0.59655344]
 [0.09178413 0.34518624 0.66275252 0.44171349 0.55148779 0.70371249 0.58940123 0.04993276 0.56179184 0.76635847]
 [0.91090833 0.09290995 0.90252139 0.46096041 0.45201847 0.99942549 0.16242374 0.70937058 0.16062408 0.81077677]
 [0.03514717 0.53488673 0.16650012 0.30841038 0.04506241 0.23857613 0.67483453 0.78238275 0.69520163 0.32895445]
 [0.49403187 0.52412136 0.29854125 0.46310814 0.98478429 0.50113492 0.39807245 0.72790532 0.86333097 0.02616954]]

It is of type: <class 'numpy.ndarray'>
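For contrast, here is a minimal sketch of how a comparable random array is drawn on the JAX side. It uses jax.random directly (an assumption: this notebook's own code does not call it) and shows the explicit-key style that replaces numpy's global seed. The values will differ from the numpy array above even though both use the seed 32.

```python
import numpy
import jax

# JAX has no global seed: randomness is driven by an explicit PRNG key.
key = jax.random.PRNGKey(32)

# Split the key so the same key is never reused for two different draws.
key, subkey = jax.random.split(key)
jax_array = jax.random.uniform(subkey, shape=(5, 10))  # values in [0, 1)

# Reusing the same subkey reproduces exactly the same numbers.
same_again = jax.random.uniform(subkey, shape=(5, 10))

# Convert to a regular numpy.ndarray if a numpy-only function needs it.
values = numpy.asarray(jax_array)

print(jax_array.shape)                         # (5, 10)
print(numpy.array_equal(values, same_again))   # True
```

Because the key is passed explicitly, the same key always yields the same numbers. This is why JAX code splits keys instead of relying on a hidden random state.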