Useful and Cool torch.Tensor Functions
PyTorch is a deep learning framework that simplifies building, training, and deploying models. Data, which PyTorch stores in tensors, is central to training a model, yet preparing that data is arguably the most tedious part of the process. To make life easier for deep learning tinkerers, PyTorch ships with a plethora of built-in functions. Here are five underrated and cool torch.Tensor functions (of many) that could come in handy for you:
torch.index_select()
torch.masked_select()
torch.where()
torch.eye()
torch.unique()
# Import torch and other required modules
import torch
1) torch.index_select()
torch.index_select() lets you pick multiple values, rows, or columns from a tensor when you know their indices. This is especially useful if you need to pick multiple columns of a larger tensor while keeping the same number of dimensions as the original (this is illustrated below).
Parameters:
torch.Tensor as input
int dimension/axis to manipulate
torch.Tensor containing indices to select (NOTE: indices must be passed as a torch.Tensor)
# Example 1
X = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 3])
torch.index_select(X, 0, indices)
tensor([1, 4])
Here, we select indices 0 and 3 of X along axis 0, which hold the values 1 and 4, respectively.
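To make the point about columns and shape concrete, here is a minimal sketch on a 2-D tensor. The tensor M and the variable names are hypothetical and chosen only for illustration. It shows that selecting along axis 1 returns whole columns, and that the result stays 2-D even when only one index is given, whereas plain integer indexing drops that axis.

```python
import torch

# Hypothetical 3x4 tensor, used only for illustration:
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
M = torch.arange(12).reshape(3, 4)

# Select columns 0 and 2 (axis 1); the indices must be an integer tensor
cols = torch.tensor([0, 2])
picked = torch.index_select(M, 1, cols)
print(picked)
# tensor([[ 0,  2],
#         [ 4,  6],
#         [ 8, 10]])

# With a single index, the result still has two dimensions
single = torch.index_select(M, 1, torch.tensor([1]))
print(single.shape)   # torch.Size([3, 1])

# Plain integer indexing removes the selected axis instead
print(M[:, 1].shape)  # torch.Size([3])
```

Only the size of the selected axis changes; every other axis keeps its original length, so the output can usually be passed on to code that expects a 2-D tensor.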