1.裁剪运算
- 常用于对tensor中的一些元素进行范围约束,约束范围有利于参数学习的稳定,防止过拟合;可以防止梯度离散;方便模型进行量化
import torch
a = torch.rand(2,2) * 10
print("a:\n",a)
b = a.clamp(2,5)
print("b:\n",b)
2.索引与数据筛选
- torch.where(condition,x,y)
- 按照条件(condition)从x和y中选出满足条件的元素组成新的tensor
import torch
a = torch.rand(4,4)
b = torch.rand(4,4)
print("a:\n",a)
print("b:\n",b)
out = torch.where(a>0.5,a,b)
print('out:\n',out)
- torch.index_select(input,dim,index)
- 按照指定索引输出tensor
- 参数dim:维度
- 参数index:要输出的维度数据的索引,按照给出索引顺序,输出新tensor
import torch
a = torch.rand(4,4)
print("a:\n",a)
out = torch.index_select(a,dim=0,index=torch.tensor([0,3,2]))
print('out:\n',out)
- torch.gather(input,dim,index)
- 在指定维度上按照索引复制输出tensor
- 参数dim:维度
- 参数index:选择数据的具体索引值
import torch
a = torch.linspace(1,16,16).view(4,4)
print("a:\n",a)
out = torch.gather(a,dim=0,index=torch.tensor([[0,1,1,1],[0,1,2,2],[0,1,3,3]]))
print('out:\n',out)
- torch.masked_select(input,mask)
import torch
a = torch.linspace(1,16,16).view(4,4)
print("a:\n",a)
mask = torch.gt(a,8)
print('mask:\n',mask)
out = torch.masked_select(a,mask)
print('out:\n',out)
- torch.take(input,indices)
- 将input看成一维tensor,按照indices的值为索引得到输出tensor
import torch
a = torch.linspace(1,16,16).view(4,4)
print("a:\n",a)
out = torch.take(a,index=torch.tensor([0,15,13,10]))
print('out:\n',out)
import torch
a = torch.tensor([[0,1,2,0],[2,3,0,1]])
print("a:\n",a)
out = torch.nonzero(a)
print('out:\n',out)
3.组合/拼接
- torch.cat(seq,dim):按照已经存在的维度进行拼接
import torch
a = torch.zeros((2,4))
b = torch.ones((2,4))
print("a:\n",a)
print("b:\n",b)
out = torch.cat((a,b),dim=1)
print('out:\n',out)
- torch.stack(seq,dim):按照新的维度进行拼接
import torch
a = torch.linspace(1,6,6).view(2,3)
b = torch.linspace(7,12,6).view(2,3)
print("a:\n",a)
print("b:\n",b)
out = torch.stack((a,b),dim=1)
print('out:\n',out)
4.切片
- torch.chunk(tensor,chunks,dim)
- 按照摸个维度平均分块,如果无法平均,最后一个维度值的个数会小于平均值
- 参数chunks:分块的个数
import torch
a = torch.rand((3,4))
print("a:\n",a)
out = torch.chunk(a,2,dim=0)
print('out:\n',out)
- torch.split(tensor,split_size_or_sections,dim)
- 按照某个维度,依照第二个参数给出的列表或数对tensor进行分割
- 参数split_size_or_sections:如果传入一个数值,则效果和torch.chunk相同;若是列表,则会按照列表中的值进行对应切分
import torch
a = torch.rand((10,4))
print("a:\n",a)
out = torch.split(a,[1,3,6],dim=0)
print('out:\n',out)
5.变形操作
- torch.reshape(input,shape)
import torch
a = torch.rand(2,3)
print("a:\n",a)
out = torch.reshape(a,(3,2))
print('out:\n',out)
import torch
a = torch.rand(2,3)
print("a:\n",a)
out = torch.t(a)
print('out:\n',out)
- torch.transpose(input,dim0,dim1)
import torch
a = torch.rand(2,3)
print("a:\n",a)
out = torch.transpose(a, 0, 1)
print('out:\n',out)
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.squeeze(a)
print('out:\n',out)
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.unbind(a,1)
print('out:\n',out)
- torch.unsqueeze(input,dim)
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.unsqueeze(a,-1)
print('out:\n',out)
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.flip(a,dims=[1,2])
print('out:\n',out)
- torch.rot90(input,k,dim)
- 按照指定维度(dim)和旋转次数(k)进行张量旋转
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.rot90(a,2)
print('out:\n',out)
6.填充
import torch
a = torch.full((2,3),10)
print("a:\n",a)