
Useful and Cool torch.Tensor Functions

PyTorch is a deep learning framework that simplifies the process of building, training, and deploying models. Data, which is stored in tensors, is a big part of training models, and preparing that data is arguably the most tedious part of the process. To make life easier for deep learning tinkerers, PyTorch ships with a plethora of built-in functions. Here are five underrated and cool tensor functions (out of many) that could come in handy for you:

  • torch.index_select()
  • torch.masked_select()
  • torch.where()
  • torch.eye()
  • torch.unique()
# Import torch and other required modules
import torch

1) torch.index_select()

torch.index_select() lets you pick multiple values, rows, or columns out of a tensor when you know their indices. This is especially useful when you need to pick several columns out of a larger tensor while keeping the result's number of dimensions the same as the input's (more on this below).
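To illustrate picking whole rows or columns, here is a minimal sketch on a 2-D tensor. The matrix `M` and its values are made up for illustration. Selecting along `dim=1` returns columns and selecting along `dim=0` returns rows; in both cases the result stays 2-D, and only the selected dimension shrinks to the number of indices. The last lines contrast this with plain slicing, which drops the dimension when you take a single column.

```python
import torch

# Hypothetical 3x4 matrix used only for illustration:
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
M = torch.arange(12).reshape(3, 4)
idx = torch.tensor([0, 2])

# dim=1: pick columns 0 and 2 from every row
picked_cols = torch.index_select(M, 1, idx)
print(picked_cols.tolist())  # [[0, 2], [4, 6], [8, 10]]
print(picked_cols.shape)     # torch.Size([3, 2])

# dim=0: pick rows 0 and 2 instead
picked_rows = torch.index_select(M, 0, idx)
print(picked_rows.tolist())  # [[0, 1, 2, 3], [8, 9, 10, 11]]

# Selecting a single column this way keeps the result 2-D...
one_col = torch.index_select(M, 1, torch.tensor([1]))
print(one_col.shape)         # torch.Size([3, 1])
# ...whereas plain slicing drops that dimension
print(M[:, 1].shape)         # torch.Size([3])
```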

Parameters:

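  • input (torch.Tensor): the tensor to select from
  • dim (int): the dimension/axis to select along
  • index (torch.Tensor): a 1-D tensor containing the indices to select (NOTE: indices must be passed as an integer torch.Tensor, i.e. int64 or int32, not as a Python list)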
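To make the note about the index type concrete, the sketch below tries passing the indices as a plain Python list and then as an integer tensor. The variable names are illustrative; the sketch assumes only that `index_select` rejects a non-Tensor `index` argument with a `TypeError`, which is how PyTorch's argument parser reports a wrong argument type.

```python
import torch

X = torch.tensor([1, 2, 3, 4, 5])
idx_list = [0, 3]  # indices as a plain Python list

try:
    torch.index_select(X, 0, idx_list)  # `index` must be a Tensor, not a list
    list_accepted = True
except TypeError:
    list_accepted = False

# Converting the list to an int64 tensor first works
result = torch.index_select(X, 0, torch.tensor(idx_list, dtype=torch.long))

print(list_accepted)    # False
print(result.tolist())  # [1, 4]
```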
# Example 1 
X = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 3])
torch.index_select(X, 0, indices)
tensor([1, 4])

Here, we select indices 0 and 3 from X along axis 0; the values at those positions are 1 and 4, respectively.
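One way to sanity-check the example above is to compare it with ordinary integer-tensor indexing, which gives the same values for a 1-D index along a single dimension. The sketch also shows that indices may repeat and need not be sorted: the output simply follows the order of the index tensor. The index values `[4, 0, 0]` are arbitrary and chosen only for illustration.

```python
import torch

X = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 3])

# Same values as advanced (integer-tensor) indexing along dim 0
same_as_indexing = torch.equal(torch.index_select(X, 0, indices), X[indices])
print(same_as_indexing)  # True

# Repeated, out-of-order indices: output follows the index order
reordered = torch.index_select(X, 0, torch.tensor([4, 0, 0]))
print(reordered.tolist())  # [5, 1, 1]
```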