
Useful functions to reshape PyTorch tensors

Reshape functions

PyTorch is an open source machine learning framework that accelerates the path from research prototyping to production deployment. Most of the work in PyTorch involves tensors, so it helps to know how to inspect and change their shape. Here I explain some functions for doing that:

  • size()
  • len()
  • numel()
  • reshape()
  • flatten()
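As a quick preview, the sketch below calls each of the five functions on one small tensor. It assumes only that `torch` is installed. The 3x4 tensor is the same one used in the example that follows, and the target shape `(2, 6)` is an arbitrary illustrative choice.

```python
import torch

# A 3x4 tensor holding 12 elements
t = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3],
], dtype=torch.float32)

print(t.size())                # torch.Size([3, 4]): length of every dimension
print(len(t))                  # 3: length of the first dimension only
print(t.numel())               # 12: total number of elements (3 * 4)
print(t.reshape(2, 6).size())  # torch.Size([2, 6]): same 12 elements, new shape
print(t.flatten().size())      # torch.Size([12]): all elements in one dimension
```

Note that `reshape()` only succeeds when the new shape holds the same number of elements, which is why `numel()` is handy to check first.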
# Import torch
import torch

Get the tensor size - Tensor.size()

We can get the size of a tensor by calling its size() method, which returns a torch.Size object:

# Example 1 - the size of a 2D tensor
# Suppose we have the following tensor:
t = torch.tensor([
    [1,1,1,1],
    [2,2,2,2],
    [3,3,3,3]
], dtype=torch.float32)

t.size()
torch.Size([3, 4])

The result, torch.Size([3, 4]), tells us the tensor has 3 rows and 4 columns.
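size() can also return the length of a single dimension when given its index, and the `.shape` attribute gives the same result as calling size() with no argument. The sketch below rebuilds the same 3x4 tensor so it runs on its own; the variable names `rows` and `cols` are just illustrative.

```python
import torch

t = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3],
], dtype=torch.float32)

rows = t.size(0)    # length of dimension 0 (rows)
cols = t.size(1)    # length of dimension 1 (columns)
print(rows, cols)   # 3 4
print(t.shape)      # torch.Size([3, 4]): same value as t.size()
print(t.size(-1))   # 4: negative indices count from the last dimension

# torch.Size is a subclass of tuple, so it can be unpacked directly
r, c = t.size()
```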