在 PyTorch 中,masked_fill
、masked_select
和 masked_scatter
是三种常用的掩码(mask)操作方法,它们通过布尔类型的掩码张量(mask
)对原始张量进行条件筛选或修改。以下是它们的详细解释和对比:
1. masked_fill
作用:将原始张量中 mask
为 True
的位置用指定值填充,其余位置保持不变。
参数:
• mask
(BoolTensor):与原始张量形状相同的布尔掩码。
• value
(标量):用于填充的值。
特点:
• 原地操作:会直接修改原始张量(除非使用 masked_fill_
的 in-place 版本)。
• 保持形状:输出张量形状与输入张量一致。
示例:
import torch
x = torch.tensor([[1, 2], [3, 4]])
mask = torch.tensor([[False, True]