Updated 4 years ago
Five tensor functions one should know
Lesson 01 - Assignment
PyTorch is a Python-based library for machine learning. Its fundamental data structure is the tensor (torch.Tensor), a multi-dimensional matrix whose elements all share a single data type. PyTorch provides many built-in functions for basic tensor operations; this notebook demonstrates the following five:
- torch.take
- torch.matrix_power
- torch.cat
- torch.where
- torch.narrow
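As a quick preview, the sketch below calls each of the five listed functions once on small tensors. It also shows the single-data-type rule: mixing integers and floats in one tensor promotes every element to a float type. The tensors a and m are illustrative inputs chosen for this sketch, not values from the notebook's later examples.

```python
import torch

# A tensor holds one data type: mixing ints and floats promotes everything to float.
t = torch.tensor([1, 2.5])
print(t.dtype)  # the default float dtype, normally torch.float32

# Illustrative inputs (assumptions for this sketch)
a = torch.tensor([[0, 7, 5],
                  [4, 5, 2]])
m = torch.tensor([[1., 1.],
                  [0., 1.]])

taken = torch.take(a, torch.tensor([0, 2]))           # flat indices 0 and 2 of a
squared = torch.matrix_power(m, 2)                    # m @ m
joined = torch.cat((a, a), dim=0)                     # rows of a followed by rows of a
clipped = torch.where(a > 4, a, torch.zeros_like(a))  # keep values > 4, else 0
middle = torch.narrow(a, 1, 1, 2)                     # dim 1, start 1, length 2

print(taken.tolist())    # [0, 5]
print(squared.tolist())  # [[1.0, 2.0], [0.0, 1.0]]
print(joined.shape)      # torch.Size([4, 3])
print(clipped.tolist())  # [[0, 7, 5], [0, 5, 0]]
print(middle.tolist())   # [[7, 5], [5, 2]]
```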
# Import torch and other required modules
import torch
Function 1 - torch.take
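Returns a new tensor with the elements of input at the given indices. The input tensor is treated as if it were viewed as a 1-D tensor, and the result has the same shape as the indices tensor.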
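To make the "viewed as a 1-D tensor" behaviour concrete, the sketch below compares torch.take with flattening the tensor first and then using ordinary integer indexing. The 2-D index tensor idx is an illustrative assumption; it also shows that the result takes the shape of the indices rather than the shape of the input.

```python
import torch

a = torch.tensor([[0, 7, 5],
                  [4, 5, 2]])
# Illustrative 2-D index tensor (an assumption for this sketch)
idx = torch.tensor([[1, 3],
                    [0, 5]])

# torch.take indexes into the row-major flattened view [0, 7, 5, 4, 5, 2]
result = torch.take(a, idx)
# Equivalent formulation: flatten first, then index with the same tensor
manual = a.flatten()[idx]

print(result.tolist())              # [[7, 4], [0, 2]]
print(torch.equal(result, manual))  # True
```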
# Example 1 - working
# Initialise a 2 x 3 tensor
a = torch.tensor([[0, 7, 5],
                  [4, 5, 2]])
print(a)
# Take the elements at flat indices 1, 3 and 0 of a, in that order
torch.take(a, torch.tensor([1, 3, 0]))
tensor([[0, 7, 5],
        [4, 5, 2]])
tensor([7, 4, 0])
torch.take treats a as if it were flattened in row-major order to [0, 7, 5, 4, 5, 2], so indices 1, 3 and 0 select 7, 4 and 0 respectively.
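Because the indices refer to positions in the flattened tensor, they must stay within its element count. The sketch below shows a failing call: a has 6 elements, so an index of 6 is one past the end. The exact exception type and message depend on the PyTorch version, which is why the sketch catches both IndexError and RuntimeError; the failed flag is a helper introduced only for this illustration.

```python
import torch

a = torch.tensor([[0, 7, 5],
                  [4, 5, 2]])

# a has 6 elements, so valid flat indices here are 0 to 5
print(a.numel())  # 6

try:
    torch.take(a, torch.tensor([6]))  # one past the last element
    failed = False
except (IndexError, RuntimeError) as err:
    # Exception type and wording vary between PyTorch versions
    failed = True
    print("torch.take failed:", err)
```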