取行列
import torch
x = torch.tensor([[1,2,3],[4,5,6],[7,8,9],[0,0,0]])
y1 = x[:1,:] #行<=1,所有列
y2 = x[1:,:] #行>1,所有列,等价于x[1:]
y3 = x[:1,:2] #行<=1,列<=2
y4 = x[1:,2:] #行>1,列>2
print("原矩阵\n",x)
print("#行<=1,所有列\n",y1)
print("#行>1,所有列\n",y2)
print("#行<=1,列<=2\n",y3)
print("#行>1,列>2\n",y4)
结果
原矩阵
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[0, 0, 0]])
#行<=1,所有列
tensor([[1, 2, 3]])
#行>1,所有列 tensor([[4, 5, 6],
[7, 8, 9],
[0, 0, 0]])
#行<=1,列<=2
tensor([[1, 2]])
#行>1,列>2
tensor([[6],
[9],
[0]])