Updated 4 years ago
Useful functions to reshape PyTorch tensors
Reshape functions
PyTorch is an open source machine learning framework that accelerates the path from research prototyping to production deployment. Its core data structure is the tensor, and it provides many functions for inspecting and changing a tensor's shape. Here I explain the following ones:
- size()
- len()
- numel()
- reshape()
- flatten()
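As a quick preview before the detailed examples, the sketch below calls each listed function once on a single 3x4 tensor. The tensor built with `torch.arange` is an illustrative choice, not part of the original examples; the comments state what each call reports for this particular shape.

```python
import torch

# Illustrative 3x4 tensor holding the values 0..11
t = torch.arange(12, dtype=torch.float32).reshape(3, 4)

print(t.size())                # torch.Size([3, 4]): length of every dimension
print(len(t))                  # 3: length of the first dimension only
print(t.numel())               # 12: total number of elements (3 * 4)
print(t.reshape(2, 6).size())  # torch.Size([2, 6]): same 12 elements, new shape
print(t.flatten().size())      # torch.Size([12]): all dimensions collapsed into one
```

Note that `reshape()` only works when the new shape has the same total number of elements as the original, which is why `numel()` is a useful companion check.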
# Import torch and other required modules
import torch
Get the tensor size - Tensor.size()
We can get the size of a tensor by calling its size() method, which returns a torch.Size object:
# Example 1 - size of a 2D tensor
# Suppose we have the following tensor:
t = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3]
], dtype=torch.float32)
t.size()
# Output: torch.Size([3, 4])
The output shows that we have a tensor with 3 rows and 4 columns: the first dimension has length 3 and the second has length 4.
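size() can also return the length of a single dimension when given a dimension index, and the shape attribute returns the same torch.Size object. The sketch below repeats the tensor from Example 1 so it runs on its own; the variable names `rows` and `cols` are just illustrative.

```python
import torch

t = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3]
], dtype=torch.float32)

# .shape gives the same result as .size()
print(t.shape)     # torch.Size([3, 4])

# Pass a dimension index to get the length of one dimension as an int
print(t.size(0))   # 3 (rows)
print(t.size(1))   # 4 (columns)
print(t.size(-1))  # 4 (negative indices count back from the last dimension)

# torch.Size is a subclass of tuple, so it can be indexed and unpacked
rows, cols = t.size()
print(rows, cols)  # 3 4
```

Because torch.Size behaves like a tuple, it can be compared directly against a plain tuple such as (3, 4) when checking shapes.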