np.where
numpy.where() 有两种用法:
- numpy.where(condition, x, y)
满足condition, 输出x,不满足输出y
condition = [[True, False], [True, True]]
x = [[1, 2], [3, 4]]
y = [[9, 8], [7, 6]]
np.where(condition, x, y)
>>>
[[1 8]
[3 4]]
- numy.where(condition)
只有condiiton,输出满足condition的元素的下标, 这里的坐标以tuple的形式给出,通常原数组有多少维,输出的tuple中就包含几个数组,分别对应符合条件元素的各维坐标
x = np.arange(12).reshape((2, 3, 2))
z = np.where(x > 7)
>>>
[[[ 0 1]
[ 2 3]
[ 4 5]]
[[ 6 7]
[ 8 9]
[10 11]]]
(array([1, 1, 1, 1]), array([1, 1, 2, 2]), array([0, 1, 0, 1]))
torch.where
torch.where()函数的作用是按照一定的规则合并两个tensor类型
import torch
a = torch.randn(2, 3)
b = torch.ones_like(a)
com = torch.where(a > 0, a, b)
>>>
tensor([[-0.9728, 1.2321, -0.1471],
[ 0.3736, 0.5832, 0.5332]])
tensor([[1.0000, 1.2321, 1.0000],
[0.3736, 0.5832, 0.5332]])