torch.cat在给定维度上对输入的张量序列进行连接操作

torch.cat(inputs, dimension=0) → Tensor
参数:
inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的 python 序列
dimension (int, optional) – 沿着此维连接张量序列

例子1  dim = 0:

import torch
import numpy
a = torch.randn(2, 2, 3) # 生成通道为2 的2行3列张量
print(a.data.size())
print(type(a))
print('a:\n', a)
b = torch.cat((a, a), dim=0) # dim=0表示在通道上连接两个张量a,a,通道数翻倍,行列不变
print(b.data.size())
print('b:\n', b)

cat结果如下:可以看出通道数变化而行列没有任何变化!

torch.Size([2, 2, 3])
<class 'torch.Tensor'>
a:
 tensor([[[ 0.4702, -1.1235,  1.6840],
         [ 0.8872,  1.1102, -1.8428]],

        [[ 0.8182, -0.5710, -1.2721],
         [ 0.2145,  1.2150,  0.7147]]])
torch.Size([4, 2, 3])
b:
 tensor([[[ 0.4702, -1.1235,  1.6840],
         [ 0.8872,  1.1102, -1.8428]],

        [[ 0.8182, -0.5710, -1.2721],
         [ 0.2145,  1.2150,  0.7147]],

        [[ 0.4702, -1.1235,  1.6840],
         [ 0.8872,  1.1102, -1.8428]],

        [[ 0.8182, -0.5710, -1.2721],
         [ 0.2145,  1.2150,  0.7147]]])

例子2  dim = 1:

import torch
import numpy
a = torch.randn(2, 2, 3)
print(a.data.size())
print(type(a))
print('a:\n', a)
b = torch.cat((a, a), dim=1) # dim=1表示在行上连接两个张量a,a,行增加而通道数和列不变
print(b.data.size())
print('b:\n', b)

cat结果如下:可以看出行增加,而通道数和列没有任何变化!

torch.Size([2, 2, 3])
<class 'torch.Tensor'>
a:
 tensor([[[ 0.7129, -0.1589, -1.4144],
         [ 1.2887,  2.2833,  0.7735]],

        [[-1.8561,  0.2988, -0.3955],
         [-1.8440,  1.8290,  0.1959]]])
torch.Size([2, 4, 3])
b:
 tensor([[[ 0.7129, -0.1589, -1.4144],
         [ 1.2887,  2.2833,  0.7735],
         [ 0.7129, -0.1589, -1.4144],
         [ 1.2887,  2.2833,  0.7735]],

        [[-1.8561,  0.2988, -0.3955],
         [-1.8440,  1.8290,  0.1959],
         [-1.8561,  0.2988, -0.3955],
         [-1.8440,  1.8290,  0.1959]]])

例子3  dim = 2:

import torch
import numpy
a = torch.randn(2, 2, 3)
print(a.data.size())
print(type(a))
print('a:\n', a)
b = torch.cat((a, a), dim=2) # dim=2表示在列上连接两个张量a,a,列增加,而行和通道数不变
print(b.data.size())
print('b:\n', b)

cat结果如下:可以看出列增加,而通道数和行没有任何变化!

torch.Size([2, 2, 3])
<class 'torch.Tensor'>
a:
 tensor([[[ 0.7935,  0.3900, -1.0024],
         [-0.2843, -0.5554,  0.4073]],

        [[ 1.0388, -0.4608,  0.4172],
         [ 0.6668, -1.2096, -1.2609]]])
torch.Size([2, 2, 6])
b:
 tensor([[[ 0.7935,  0.3900, -1.0024,  0.7935,  0.3900, -1.0024],
         [-0.2843, -0.5554,  0.4073, -0.2843, -0.5554,  0.4073]],

        [[ 1.0388, -0.4608,  0.4172,  1.0388, -0.4608,  0.4172],
         [ 0.6668, -1.2096, -1.2609,  0.6668, -1.2096, -1.2609]]])

<think>我们正在讨论的是PyTorch中的torch.cat函数,特别是dim参数的使用。根据引用内容,dim参数指定了拼接维度。需要注意的是,对于不同维度张量,dim的含义不同,不能简单地理解为行或列(尤其是在高维张量中)。关键点:1.torch.cat用于在给定维度拼接多个张量。2.dim参数指定拼接维度索引(从0开始)。3.所有张量除了在拼接维度上可以不同,其他维度必须相同。示例说明:引用[2]提供了一个二维张量的例子:X=[[1,2,3],[4,5,6]]Y=[[7,8,9],[1,4,7]]当dim=0时,沿着第0维(即行方向,垂直拼接):A=torch.cat([X,Y],dim=0)的结果是:[[1,2,3],[4,5,6],[7,8,9],[1,4,7]]当dim=1时,沿着第1维(即列方向,水平拼接):B=torch.cat([X,Y],dim=1)的结果是:[[1,2,3,7,8,9],[4,5,6,1,4,7]]但是,在深度学习中,我们经常处理更高维的张量,比如四维张量(batch,channel,height,width)。这时,dim参数的含义如下:-dim=0:沿着batch维度拼接(增加样本数)-dim=1:沿着channel维度拼接(增加通道数)-dim=2:沿着height维度拼接(增加高度)-dim=3:沿着width维度拼接(增加宽度)引用[4]强调了这一点:不能简单地以二维情况类推,而要根据实际维度来理解。下面,我们通过一个四维张量的例子来说明:假设有两个四维张量:tensor1:shape(2,3,32,32)#2个样本,3个通道,32x32图像tensor2:shape(3,3,32,32)#3个样本,3个通道,32x32图像1.沿dim=0拼接(batch维度):result=torch.cat([tensor1,tensor2],dim=0)结果形状:(2+3,3,32,32)=(5,3,32,32)2.沿dim=1拼接(channel维度):注意:此时要求两个张量在除了通道维度外的其他维度都相同,即batch_size、height、width必须相同。假设tensor1:(2,3,32,32),tensor2:(2,5,32,32)#相同batch_size,height,widthresult=torch.cat([tensor1,tensor2],dim=1)结果形状:(2,3+5,32,32)=(2,8,32,32)重要规则:除了拼接维度(dim指定的维度)可以不同,其他所有维度必须完全相同。错误示例:如果两个张量的形状在非拼接维度上不同,则会报错。例如,尝试在dim=1上拼接两个形状为(2,3,32,32)和(3,3,32,32)的张量,由于batch_size不同(2vs3),会抛出错误。总结步骤:1.确定要拼接张量列表(所有张量必须是相同维度的,比如都是4维)。2.选择一个维度(dim)进行拼接,确保其他维度的大小相同。3.调用torch.cat(tensors,dim=指定的维度)代码验证(四维张量):```pythonimporttorch#创建两个四维张量tensor_a=torch.randn(2,3,32,32)#2个样本,3通道,32x32tensor_b=torch.randn(2,2,32,32)#2个样本,2通道,32x32#沿通道维度拼接(dim=1)tensor_c=torch.cat([tensor_a,tensor_b],dim=1)print(tensor_c.shape)#输出:torch.Size([2,5,32,32])#如果尝试在dim=0上拼接,要求其他维度相同,但这里通道数不同(3vs2),所以不能拼接?实际上,在dim=0拼接时,其他维度(包括通道)必须相同。#下面这个操作会报错,因为除了dim=0,其他维度(这里是通道数)不相同:#tensor_d=torch.cat([tensor_a,tensor_b],dim=0)#报错#修正:创建两个其他维度完全相同的张量(包括通道数)tensor_a1=torch.randn(2,3,32,32)tensor_b1=torch.randn(3,3,32,32)#样本数不同,但通道数相同tensor_d=torch.cat([tensor_a1,tensor_b1],dim=0)print(tensor_d.shape)#输出:torch.Size([5,3,32,32])```注意:在dim=0上拼接时,要求其他维度(从dim=1到最后一维)都相同;在dim=1上拼接时,要求dim=0,2,3,...都相同,以此类推。因此,在使用torch.cat时,务必检查非拼接维度的尺寸是否一致。最后,再次强调:dim参数是维度的索引,从0开始。在拼接时,我们沿着该维度将多个张量连接起来,该维度的值会相加,而其他维度的值保持不变(且必须相同)。</think>### `torch.cat()` 中 `dim` 参数使用详解 `torch.cat()` 是 PyTorch 中用于拼接张量的核心函数,`dim` 参数指定拼接维度。**关键原则**:`dim` 值对应张量维度的索引(从0开始),拼接后该维度尺寸会增大,其他维度保持不变[^3][^4]。 --- #### 一、`dim` 参数核心规则 1. **维度索引规则** - `dim=0`:沿**第0维**(通常为 batch 维度拼接 - `dim=1`:沿**第1维**(通常为通道/特征维度拼接 - 更高维度以此类推 2. **拼接条件** 所有输入张量在**非拼接维度**上必须尺寸相同,否则报错[^3]。 --- #### 二、不同维度场景示例 ##### 1. 二维张量(矩阵) ```python import torch X = torch.tensor([[1,2,3], [4,5,6]]) # shape (2,3) Y = torch.tensor([[7,8,9], [1,4,7]]) # shape (2,3) # dim=0: 沿行拼接 (垂直拼接) A = torch.cat([X, Y], dim=0) # tensor([[1,2,3], # [4,5,6], # [7,8,9], # [1,4,7]]) # shape (4,3) # dim=1: 沿列拼接 (水平拼接) B = torch.cat([X, Y], dim=1) # tensor([[1,2,3,7,8,9], # [4,5,6,1,4,7]]) # shape (2,6) ``` ##### 2. 四维张量(CNN 特征图) ```python # 模拟CNN特征图: (batch, channel, height, width) t1 = torch.randn(2, 3, 32, 32) # 2张图, 3通道 t2 = torch.randn(3, 3, 32, 32) # 3张图, 3通道 # dim=0: 沿batch维度拼接 cat0 = torch.cat([t1, t2], dim=0) print(cat0.shape) # torch.Size([5, 3, 32, 32]) # batch=2+3=5 # dim=1: 沿通道维度拼接 t3 = torch.randn(2, 2, 32, 32) # 与t1 batch相同 cat1 = torch.cat([t1, t3], dim=1) print(cat1.shape) # torch.Size([2, 5, 32, 32]) # channel=3+2=5 ``` ##### 3. 三维张量序列数据) ```python # 序列数据: (batch, seq_len, features) seq1 = torch.randn(4, 10, 8) # 4个序列, 10步长, 8维特征 seq2 = torch.randn(4, 5, 8) # 4个序列, 5步长, 8维特征 # dim=1: 沿序列长度拼接 cat_seq = torch.cat([seq1, seq2], dim=1) print(cat_seq.shape) # torch.Size([4, 15, 8]) # seq_len=10+5=15 ``` --- #### 三、常见错误与解决 1. **维度不匹配错误** ```python # 错误示例: 非dim维尺寸不同 a = torch.rand(2, 3) b = torch.rand(2, 4) torch.cat([a, b], dim=0) # 报错! 第1维尺寸不同(3 vs 4) ``` 2. **错误理解维度含义** ```python # 危险误区: 认为dim=0总是"行" # 对于shape (C,H,W) 的图像: img1 = torch.rand(3, 256, 256) img2 = torch.rand(3, 256, 256) torch.cat([img1, img2], dim=0) # 结果shape (6,256,256) 通道数翻倍! ``` --- #### 四、最佳实践建议 1. **形状检查**:拼接前用 `.shape` 确认所有张量在非 `dim` 维度尺寸相同 2. **维度可视化**:高维数据可画出维度示意图: ``` 四维张量: [Batch, Channel, Height, Width] dim=0 → 增加Batch dim=1 → 增加Channel dim=2 → 增加Height dim=3 → 增加Width ``` 3. **使用 `torch.stack` 对比**:当需要**新增维度**时使用 `stack`,扩展维度时用 `cat`[^3] > **关键记忆点**:`dim` 值不是行列方向,而是**维度索引**,其物理意义取决于张量形状的实际含义[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

lzdjlu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值