1. torch.mul
—— 逐元素乘法(Element-wise Multiplication)
功能:
对两个 tensor 进行逐个元素的乘法操作。如果维度不同,需要满足广播机制(Broadcasting)的要求。
输入要求:
- 两个 tensor 的形状相同,或可以广播成相同形状。
- 支持任意维度(1D, 2D, 3D, nD)。
返回结果:
- 一个新的 tensor,形状与广播结果一致,所有元素是对应位置元素的乘积。
示例:
import torch
# 相同形状
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
print(torch.mul(a, b))
# 输出:
# tensor([[ 5, 12],
# [21, 32]])
# 广播情况
a = torch.tensor([[1], [2], [3]]) # shape: (3, 1)
b = torch.tensor([10, 20]) # shape: (2,)
print(torch.mul(a, b))
# 输出:
# tensor([[10, 20],
# [20, 40],
# [30, 60]])
常见应用:
- 损失函数计算,如平方差损失。
- 特征掩码、注意力加权等。
2. torch.mm
—— 二维矩阵乘法(Matrix Multiplication for 2D Tensors)
功能:
执行标准的二维矩阵乘法(矩阵点乘)。
输入要求:
- 两个 2D tensor。
- 第一个 tensor 的形状为
(m, n)
,第二个为(n, p)
。
返回结果:
- 形状为
(m, p)
的 tensor。
示例:
a = torch.tensor([[1, 2], [3, 4]]) # shape: (2, 2)
b = torch.tensor([[5, 6], [7, 8]]) # shape: (2, 2)
print(torch.mm(a, b))
# 输出:
# tensor([[19, 22],
# [43, 50]])
计算方式是经典的线性代数矩阵乘法:
[1*5 + 2*7, 1*6 + 2*8] = [19, 22]
[3*5 + 4*7, 3*6 + 4*8] = [43, 50]
常见应用:
- 神经网络中线性层、前向传播。
- 线性代数计算中的矩阵乘。
3. torch.bmm
—— 批量矩阵乘法(Batch Matrix Multiplication)
功能:
对一批(多个)二维矩阵进行并行矩阵乘法。
输入要求:
- 两个 3D tensor。
- 第一个 tensor 形状为
(b, m, n)
,第二个为(b, n, p)
,b
是 batch size。
返回结果:
- 一个形状为
(b, m, p)
的 tensor,包含b
个(m, p)
的结果。
示例:
a = torch.randn(5, 2, 3) # 5个 (2x3) 的矩阵
b = torch.randn(5, 3, 4) # 5个 (3x4) 的矩阵
result = torch.bmm(a, b)
print(result.shape) # 输出:torch.Size([5, 2, 4])
每个 batch 内部等价于 torch.mm(a[i], b[i])
。
常见应用:
- RNN 中批量处理时间序列。
- Transformer 中的批量注意力计算。
4. torch.matmul
—— 广义矩阵乘法(General Matrix Multiplication)
功能:
通用乘法接口,自动根据输入的维度来选择:
- 标量乘法(1D × 1D)
- 矩阵乘法(2D × 2D)
- 批次矩阵乘法(3D × 3D 或 nD × nD)
输入规则及行为:
A.shape | B.shape | 结果行为 |
---|---|---|
(n,) | (n,) | 点积,返回标量 |
(m, n) | (n,) | 矩阵 × 向量,返回 (m,) |
(m, n) | (n, p) | 矩阵 × 矩阵,返回 (m, p) |
(b, m, n) | (b, n, p) | 批量矩阵乘,返回 (b, m, p) |
广播 nD | 广播 nD | 广播后批量矩阵乘 |
示例 1:向量点积
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
print(torch.matmul(a, b)) # 输出:32
示例 2:矩阵 × 向量(向量与矩阵的每一个[i,:]向量相乘都是向量点积)
a = torch.tensor([[1, 2], [3, 4]]) # shape: (2, 2)
b = torch.tensor([5, 6]) # shape: (2,)
print(torch.matmul(a, b)) # 输出: tensor([17, 39])
1*5+2*6=17
3*5+4*6=39
示例 3:批次矩阵乘法
a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
print(torch.matmul(a, b).shape) # 输出: torch.Size([10, 3, 5])
常见应用:
- 推荐作为统一接口使用。
- 自动适配各类矩阵乘操作,不需区分 mm、bmm。
总结表格
函数名称 | 功能说明 | 输入维度 | 返回维度 | 特点 |
---|---|---|---|---|
torch.mul | 逐元素乘法 | 任意维度(支持广播) | 与输入形状一致 | element-wise,逐元素乘 |
torch.mm | 传统矩阵乘法(2D) | (m, n) × (n, p) | (m, p) | 只能用于二维矩阵乘 |
torch.bmm | 批量矩阵乘法(3D) | (b, m, n) × (b, n, p) | (b, m, p) | 只能用于三维批量操作 |
torch.matmul | 广义矩阵乘法(推荐) | 1D、2D、nD(自动判断) | 自动推断 | 自动广播,通用、最灵活 |
如你希望进一步对比实际运行时间、与 NumPy 对照、或在实际模型中如何使用,我也可以继续详细扩展。是否需要我再配合实际神经网络例子来说明?