1. description
torch.split 的作用是将矩阵A按照指定的方式进行切割成不同大小的子矩阵。
- excel
2. pytorch
- python
import torch
import torch.nn as nn
torch.set_printoptions(precision=3, sci_mode=False)
if __name__ == "__main__":
run_code = 0
bs = 2
seq_len = 3
model_dim = 4
ma_total = bs * seq_len * model_dim
a_matrix = torch.arange(ma_total).reshape((bs, seq_len, model_dim)).to(torch.float)
n_split = 2
#a_matrix_dim0 = torch.split(a_matrix, n_split, dim=0)
a_matrix_dim0 = torch.split(a_matrix, [1,1], dim=0)
print(f"a_matrix=\n{a_matrix}")
print(f"a_matrix_dim0=\n{a_matrix_dim0}")
i = 0
for bm in a_matrix_dim0:
print(f"a_matrix_dim0[{i}]=\n{bm}")
i += 1
print(f"a_matrix=\n{a_matrix}")
a_matrix_dim1 = torch.split(a_matrix, n_split, dim=1)
print(f"a_matrix=\n{a_matrix}")
print(f"a_matrix_dim1=\n{a_matrix_dim1}")
i = 0
for bm in a_matrix_dim1:
print(f"a_matrix_dim1[{i}]=\n{bm}")
i += 1
print(f"a_matrix=\n{a_matrix}")
a_matrix_dim2 = torch.split(a_matrix, n_split, dim=2)
print(f"a_matrix=\n{a_matrix}")
print(f"a_matrix_dim2=\n{a_matrix_dim2}")
i = 0
for bm in a_matrix_dim2:
print(f"a_matrix_dim2[{i}]=\n{bm}")
i += 1
- result:
a_matrix=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
a_matrix_dim0=
(tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]]]), tensor([[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]]))
a_matrix_dim0[0]=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]]])
a_matrix_dim0[1]=
tensor([[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
a_matrix=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
a_matrix=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
a_matrix_dim1=
(tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.]]]), tensor([[[ 8., 9., 10., 11.]],
[[20., 21., 22., 23.]]]))
a_matrix_dim1[0]=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.]]])
a_matrix_dim1[1]=
tensor([[[ 8., 9., 10., 11.]],
[[20., 21., 22., 23.]]])
a_matrix=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
a_matrix=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
a_matrix_dim2=
(tensor([[[ 0., 1.],
[ 4., 5.],
[ 8., 9.]],
[[12., 13.],
[16., 17.],
[20., 21.]]]), tensor([[[ 2., 3.],
[ 6., 7.],
[10., 11.]],
[[14., 15.],
[18., 19.],
[22., 23.]]]))
a_matrix_dim2[0]=
tensor([[[ 0., 1.],
[ 4., 5.],
[ 8., 9.]],
[[12., 13.],
[16., 17.],
[20., 21.]]])
a_matrix_dim2[1]=
tensor([[[ 2., 3.],
[ 6., 7.],
[10., 11.]],
[[14., 15.],
[18., 19.],
[22., 23.]]])