PyTorch笔记4----------Tensor的操作

1.裁剪运算

  • 常用于对tensor中的一些元素进行范围约束,约束范围有利于参数学习的稳定,防止过拟合;可以防止梯度离散;方便模型进行量化
import torch
a = torch.rand(2,2) * 10
print("a:\n",a)
b = a.clamp(2,5)
print("b:\n",b)

2.索引与数据筛选

  • torch.where(condition,x,y)
    • 按照条件(condition)从x和y中选出满足条件的元素组成新的tensor
import torch
a = torch.rand(4,4)
b = torch.rand(4,4)
print("a:\n",a)
print("b:\n",b)
out = torch.where(a>0.5,a,b)
print('out:\n',out)
  • torch.index_select(input,dim,index)
    • 按照指定索引输出tensor
    • 参数dim:维度
    • 参数index:要输出的维度数据的索引,按照给出索引顺序,输出新tensor
import torch
a = torch.rand(4,4)
print("a:\n",a)
out = torch.index_select(a,dim=0,index=torch.tensor([0,3,2]))
print('out:\n',out)
  •  torch.gather(input,dim,index)
    • 在指定维度上按照索引复制输出tensor
    • 参数dim:维度
    • 参数index:选择数据的具体索引值
import torch
a = torch.linspace(1,16,16).view(4,4)
print("a:\n",a)
out = torch.gather(a,dim=0,index=torch.tensor([[0,1,1,1],[0,1,2,2],[0,1,3,3]]))
print('out:\n',out)
  • torch.masked_select(input,mask)
    • 按照mask输出tensor,输出为向量
import torch
a = torch.linspace(1,16,16).view(4,4)
print("a:\n",a)
mask = torch.gt(a,8)
print('mask:\n',mask)
out = torch.masked_select(a,mask)
print('out:\n',out)
  • torch.take(input,indices)
    • 将input看成一维tensor,按照indices的值为索引得到输出tensor
import torch
a = torch.linspace(1,16,16).view(4,4)
print("a:\n",a)
out = torch.take(a,index=torch.tensor([0,15,13,10]))
print('out:\n',out)
  • torch.nonzero(input)
    • 输出非0元素坐标
import torch
a = torch.tensor([[0,1,2,0],[2,3,0,1]])
print("a:\n",a)
out = torch.nonzero(a)
print('out:\n',out)

3.组合/拼接

  • torch.cat(seq,dim):按照已经存在的维度进行拼接
import torch
a = torch.zeros((2,4))
b = torch.ones((2,4))
print("a:\n",a)
print("b:\n",b)
out = torch.cat((a,b),dim=1)
print('out:\n',out)
  • torch.stack(seq,dim):按照新的维度进行拼接
import torch
a = torch.linspace(1,6,6).view(2,3)
b = torch.linspace(7,12,6).view(2,3)
print("a:\n",a)
print("b:\n",b)
out = torch.stack((a,b),dim=1)
print('out:\n',out)

4.切片 

  • torch.chunk(tensor,chunks,dim)
    • 按照摸个维度平均分块,如果无法平均,最后一个维度值的个数会小于平均值
    • 参数chunks:分块的个数
import torch
a = torch.rand((3,4))
print("a:\n",a)
out = torch.chunk(a,2,dim=0)
print('out:\n',out)
  • torch.split(tensor,split_size_or_sections,dim)
    • 按照某个维度,依照第二个参数给出的列表或数对tensor进行分割
    • 参数split_size_or_sections:如果传入一个数值,则效果和torch.chunk相同;若是列表,则会按照列表中的值进行对应切分
import torch
a = torch.rand((10,4))
print("a:\n",a)
out = torch.split(a,[1,3,6],dim=0)
print('out:\n',out)

5.变形操作

  • torch.reshape(input,shape)
    • 对输入的tensor的shape进行改变
import torch
a = torch.rand(2,3)
print("a:\n",a)
out = torch.reshape(a,(3,2))
print('out:\n',out)
  • torch.t(input) 
    • 对2维tensor进行转置
import torch
a = torch.rand(2,3)
print("a:\n",a)
out = torch.t(a)
print('out:\n',out)
  • torch.transpose(input,dim0,dim1)
    • 交换两个维度
import torch
a = torch.rand(2,3)
print("a:\n",a)
out = torch.transpose(a, 0, 1)
print('out:\n',out)
  • torch.squeeze(input,dim)
    • 去除维度大小为1的维度
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.squeeze(a)
print('out:\n',out)
  • torch.unbind(tensor,dim)
    • 去除某个维度
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.unbind(a,1)
print('out:\n',out)
  • torch.unsqueeze(input,dim)
    • 在指定位置添加维度
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.unsqueeze(a,-1)
print('out:\n',out)
  • torch.flip(input,dim)
    • 按照给定维度翻转张量
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.flip(a,dims=[1,2])
print('out:\n',out)
  • torch.rot90(input,k,dim)
    • 按照指定维度(dim)和旋转次数(k)进行张量旋转
import torch
a = torch.rand(1,2,3)
print("a:\n",a)
out = torch.rot90(a,2)
print('out:\n',out)

6.填充

  • torch.full()
import torch
a = torch.full((2,3),10)
print("a:\n",a)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值