squeeze()

squeeze 函数是 PyTorch 中的一个函数,用于从张量(Tensor)中去除所有长度为 1 的维度。这在处理神经网络模型的输出时非常有用,因为模型的输出可能包含一些不必要的单一维度,例如在批处理大小为 1 的情况下,输出可能会有一个额外的批次维度。

简单来说就是从张量(Tensor)中去除所有长度为 1 的维度。

实例:

import torch

x = torch.randn(1, 2, 1, 2, 2)
print(x.shape)  # 输出: torch.Size([1, 2, 1, 2, 2])

y = torch.squeeze(x)
print(y.shape)  # 输出: torch.Size([2, 2, 2])

z = torch.squeeze(x, dim=1)
print(z.shape)  # 输出: torch.Size([1, 2, 1, 2, 2])

v= torch.squeeze(x, dim=0)
print(v.shape)  # 输出: torch.Size([1, 2, 2, 2])

输出:

#x:

torch.Size([1, 2, 1, 2, 2])

#y:不指定维度就去除所有长度为 1 的维度
torch.Size([2, 2, 2])

#z:指定维度长度不为1则无效
torch.Size([1, 2, 1, 2, 2])

#v:指定维度长度为1生效
torch.Size([2, 1, 2, 2])

03-17
### NumPy 中 `squeeze` 方法的使用说明 `squeeze` 是一种用于移除数组中大小为 1 的维度的操作,在多个库中有不同的实现方式,但在核心功能上保持一致。以下是关于 `squeeze` 操作的具体解释: #### 基本概念 在多维数组或张量的数据结构中,有时会存在一些多余的单维度(即形状为 1 的维度)。这些维度可能是在数据预处理过程中引入的,或者是为了适配某些算法而保留的。然而,在实际计算时,它们可能会增加不必要的复杂度。因此,`squeeze` 方法被设计用来删除这些冗余的维度。 对于 NumPy 数组而言,`squeeze` 可以接受一个可选参数 `axis` 来指定要移除的轴[^1]。如果未提供此参数,则默认移除所有尺寸为 1 的轴。 #### 参数详解 - **无参调用**: 当不传递任何额外参数给 `np.squeeze()` 函数时,它将自动检测并消除输入数组里所有的单位长度维度。 ```python import numpy as np arr = np.array([[[0], [1], [2]]]) result = np.squeeze(arr) print(result) # 输出: [0 1 2] print(result.shape) # 输出: (3,) ``` - **带 axis 参数调用**: 用户可以通过设置 `axis` 参数来控制具体哪些轴上的单位长度会被去除。这允许更精确地调整输出形式而不影响其他部分。 ```python arr_reshaped = np.arange(9).reshape((1,3,3)) squeezed_along_axis_zero = np.squeeze(arr_reshaped,axis=0) print(squeezed_along_axis_zero) # 输出: [[0 1 2][3 4 5][6 7 8]] print(squeezed_along_axis_zero.shape) # 输出: (3, 3) ``` 上述例子展示了当指定了特定方向之后的结果变化情况。 #### PyTorch 下的应用实例 除了 NumPy 外,PyTorch 库也提供了类似的 `squeeze` 功能用于简化张量结构。下面是一个简单的演示案例: ```python import torch tensor_example = torch.rand((1,4)) print(tensor_example.size()) # 输出: torch.Size([1, 4]) squeezed_tensor = tensor_example.squeeze() print(squeezed_tensor.size()) # 输出: torch.Size([4]) ``` 这里可以看到原始的一行四列二维矩阵经过挤压变成了单纯的一维向量[^2]。 #### Pandas DataFrame 上的表现差异 值得注意的是,在 Pandas 数据帧对象上调用 `.squeeze()` 并不是为了改变其几何形态而是尝试将其转换成 Series 类型的对象。此时可以利用附加选项如 'index' 或者 'columns' 进一步明确意图[^4]: ```python df_single_row = pd.DataFrame([[1]], columns=['A']) series_converted = df_single_row.squeeze() type(series_converted).__name__ # 返回值应为'Series' ``` 以上代码片段表明即使是从仅含有一条记录的小表格出发也能顺利过渡到序列状态下去继续开展后续工作流程。 --- ### 总结 无论在哪种框架下运用,“挤掉”那些多余却不起作用的空间总是有助于提升程序运行效率以及减少内存占用率等问题的发生几率。掌握好这一技巧无疑会对日常编程实践带来极大便利!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值