一、张量的操作:拼接、切分、索引和变换
(1)张量拼接与切分
1.1 torch.cat()
功能:将张量按维度dim进行拼接
• tensors: 张量序列
• dim : 要拼接的维度
torch.cat(tensors, dim=0, out=None)
函数用于沿着指定维度dim
将多个张量拼接在一起。它接受一个张量列表tensors
作为输入,并返回一个拼接后的张量。参数dim
指定了拼接的维度,默认为0。
flag = True
if flag:
t = torch.ones((2, 3)) #创建一个形状为(2, 3)的全为1的张量t
t_0 = torch.cat([t, t], dim=0) # 将张量t沿着dim=0的维度拼接两次,得到张量t_0
t_1 = torch.cat([t, t, t], dim=1)# 将张量t沿着dim=1的维度拼接三次,得到张量t_1
print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))
t_0:tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]) shape:torch.Size([4, 3])
t_1:tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1.]]) shape:torch.Size([2, 9])
- format()用于将t_0、t_0.shape、t_1和t_1.shape的值插入到字符串中的占位符位置。具体来说,{}是占位符,format()方法将会将对应的值填充到这些占位符的位置。
- 例如,“t_0:{} shape:{}\nt_1:{} shape:{}”.format(t_0, t_0.shape, t_1, t_1.shape)中的{}将会分别被t_0、t_0.shape、t_1和t_1.shape的值所替代。
1.2 torch.stack()
功能:在新创建的维度dim上进行拼接
• tensors:张量序列
• dim :要拼接的维度
torch.stack(tensors, dim=0, out=None)
函数用于沿着指定维度dim
将多个张量堆叠在一起。它接受一个张量列表tensors
作为输入,并返回一个堆叠后的张量。参数dim
指定了堆叠的维度,默认为0。
以下是示例代码:
import torch
# 示例张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
tensor3 = torch.tensor([7, 8, 9])
# 使用torch.cat()函数拼接张量
cat_result = torch.cat([tensor1, tensor2, tensor3], dim=0)
print("torch.cat()拼接结果:", cat_result)
# 使用torch.stack()函数堆叠张量
stack_result = torch.stack([tensor1, tensor2, tensor3], dim=0)
print("torch.stack()堆叠结果:", stack_result)
输出结果为:
torch.cat()拼接结果: tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
torch.stack()堆叠结果: tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
1.3 torch.chunk()
torch.chunk()
是PyTorch中的一个函数,用于将张量沿着指定维度进行分块。
函数的参数包括:
input
:要分块的输入张量。chunks
:要分成的块数。dim
(可选):指定要分块的维度,默认为0。
函数返回一个元组,包含分块后的张量块。
以下是一个示例代码:
import torch
x = torch.tensor([