torch.split 函数测试

1. description

torch.split 的作用是将矩阵A按照指定的方式进行切割成不同大小的子矩阵。

  • excel
    在这里插入图片描述

2. pytorch

  • python
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    bs = 2
    seq_len = 3
    model_dim = 4
    ma_total = bs * seq_len * model_dim
    a_matrix = torch.arange(ma_total).reshape((bs, seq_len, model_dim)).to(torch.float)
    n_split = 2
    #a_matrix_dim0 = torch.split(a_matrix, n_split, dim=0)
    a_matrix_dim0 = torch.split(a_matrix, [1,1], dim=0)
    print(f"a_matrix=\n{a_matrix}")
    print(f"a_matrix_dim0=\n{a_matrix_dim0}")
    i = 0
    for bm in a_matrix_dim0:
        print(f"a_matrix_dim0[{i}]=\n{bm}")
        i += 1
    print(f"a_matrix=\n{a_matrix}")
    a_matrix_dim1 = torch.split(a_matrix, n_split, dim=1)
    print(f"a_matrix=\n{a_matrix}")
    print(f"a_matrix_dim1=\n{a_matrix_dim1}")
    i = 0
    for bm in a_matrix_dim1:
        print(f"a_matrix_dim1[{i}]=\n{bm}")
        i += 1
    print(f"a_matrix=\n{a_matrix}")
    a_matrix_dim2 = torch.split(a_matrix, n_split, dim=2)
    print(f"a_matrix=\n{a_matrix}")
    print(f"a_matrix_dim2=\n{a_matrix_dim2}")
    i = 0
    for bm in a_matrix_dim2:
        print(f"a_matrix_dim2[{i}]=\n{bm}")
        i += 1
  • result:
a_matrix=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
a_matrix_dim0=
(tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]]), tensor([[[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]]))
a_matrix_dim0[0]=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]])
a_matrix_dim0[1]=
tensor([[[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
a_matrix=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
a_matrix=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
a_matrix_dim1=
(tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.]]]), tensor([[[ 8.,  9., 10., 11.]],

        [[20., 21., 22., 23.]]]))
a_matrix_dim1[0]=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.]]])
a_matrix_dim1[1]=
tensor([[[ 8.,  9., 10., 11.]],

        [[20., 21., 22., 23.]]])
a_matrix=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
a_matrix=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
a_matrix_dim2=
(tensor([[[ 0.,  1.],
         [ 4.,  5.],
         [ 8.,  9.]],

        [[12., 13.],
         [16., 17.],
         [20., 21.]]]), tensor([[[ 2.,  3.],
         [ 6.,  7.],
         [10., 11.]],

        [[14., 15.],
         [18., 19.],
         [22., 23.]]]))
a_matrix_dim2[0]=
tensor([[[ 0.,  1.],
         [ 4.,  5.],
         [ 8.,  9.]],

        [[12., 13.],
         [16., 17.],
         [20., 21.]]])
a_matrix_dim2[1]=
tensor([[[ 2.,  3.],
         [ 6.,  7.],
         [10., 11.]],

        [[14., 15.],
         [18., 19.],
         [22., 23.]]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值