目录
在Pytorch做深度学习过程中,CNN的卷积和池化过程中会用到torch.cat、squeeze()、unsqueeze()和size()函数,下面分别做讲解:
一、sequeeze()函数
x.squeeze(dim)
用途:进行维度压缩,去掉tensor中维数为1的维度
参数设置:如果设置dim=a,就是去掉指定维度中维数为1的
示例:
import torch
x = torch.tensor([[[1],[2]],[[3],[4]]])
print('x:',x)
print(x.shape)
x1 = x.squeeze()
print('x1:',x1)
print(x1.shape)
x2 = x.squeeze(2)
print('x2:',x2)
print(x2.shape)
输出:
x: tensor([[[1],
[2]],
[[3],
[4]]])
torch.Size([2, 2, 1]