
Five tensor functions one should know

Lesson 01 - Assignment

PyTorch is a Python-based library for machine learning. torch.Tensor is its fundamental data structure for storing multi-dimensional data. A tensor is a multi-dimensional matrix whose elements all share a single data type. PyTorch provides many built-in functions for basic tensor operations; five of them are shown in this notebook:

  • torch.take
  • torch.matrix_power
  • torch.cat
  • torch.where
  • torch.narrow
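As a quick illustration of the description above, every tensor has a shape and exactly one dtype. This is a minimal sketch; the variable names t and mixed are illustrative only.

```python
import torch

# A 2 x 3 tensor of integers; every element shares one dtype
t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
print(t.shape)  # torch.Size([2, 3])
print(t.dtype)  # torch.int64

# Mixing ints and floats promotes every element to one floating-point dtype
mixed = torch.tensor([1, 2.5, 3])
print(mixed.dtype)  # torch.float32 (the default floating-point dtype)
```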
# Import torch and other required modules
import torch

Function 1 - torch.take

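Returns a new tensor containing the elements of input at the given indices. The input is treated as if it were flattened to a 1-D tensor, so each index refers to a position in row-major order, and the result has the same shape as the index tensor.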
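torch.take indexes into its input as though the input were flattened to one dimension. The sketch below checks this against flatten() followed by ordinary indexing; the tensors x and idx are illustrative, and x is kept contiguous to keep the comparison simple.

```python
import torch

x = torch.tensor([[10, 20, 30],
                  [40, 50, 60]])
idx = torch.tensor([0, 4, 5])

# torch.take reads x in row-major (flattened) order: [10, 20, 30, 40, 50, 60]
taken = torch.take(x, idx)

# Equivalent: flatten to 1-D, then index with the same positions
manual = x.flatten()[idx]

print(taken)                       # tensor([10, 50, 60])
print(torch.equal(taken, manual))  # True
```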

# Example 1 - working

# initialising a tensor
a = torch.tensor([[0, 7, 5],
                  [4, 5, 2]])
print(a)
# Take the elements at flat indices 1, 3 and 0, in that order
print(torch.take(a, torch.tensor([1, 3, 0])))
tensor([[0, 7, 5],
        [4, 5, 2]])
tensor([7, 4, 0])

torch.take treats a as the flattened tensor [0, 7, 5, 4, 5, 2] and returns the elements at flat indices 1, 3 and 0, namely 7, 4 and 0.
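The output of torch.take takes the shape of the index tensor, not the shape of the input. A minimal sketch reusing the tensor a from the example above, with a hypothetical 2 x 2 index tensor:

```python
import torch

a = torch.tensor([[0, 7, 5],
                  [4, 5, 2]])

# A 2 x 2 index tensor; positions refer to the flattened a = [0, 7, 5, 4, 5, 2]
idx = torch.tensor([[1, 3],
                    [0, 5]])
out = torch.take(a, idx)
print(out)
# tensor([[7, 4],
#         [0, 2]])
```

This makes torch.take convenient for gathering values into any layout you need, as long as you can express the positions as flat indices.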