torch_sparse.SparseTensor函数通过row索引稀疏矩阵得到一个新的稀疏矩阵
from torch_sparse import SparseTensor
import torch
row = torch.tensor([0, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([2, 2, 0, 1, 3, 2, 4, 3])
val = torch.linspace(1, 8, 8)
c = SparseTensor(row=row, col=col, value=val)
print('c = {}'.format(c.to_dense()))
print(c)
d2 = c[row] # 通过行索引c的特定行得到一个新的稀疏矩阵。
# 这个操作相当于下面的步骤:
# 将原密集张量的第 0 行取出构成新张量的第 0 行;
# 将原密集张量的第 1 行取出构成新张量的第 1 行;
# 将原密集张量的第 2 行取出构成新张量的第 2 行(第 2 行有三个元素);
# 将原密集张量的第 2 行取出构成新张量的第 3 行;
# 将原密集张量的第 2 行取出构成新张量的第 4 行;
# ...
print(d2)
d3 = c[torch.tensor([2])]
print(d3)
# c = tensor([[0., 0., 1., 0., 0.],
# [0., 0., 2., 0., 0.],
# [3., 4., 0., 5., 0.],
# [0., 0., 6., 0., 7.],
# [0., 0., 0., 8., 0.]])
# SparseTensor(row=tensor([0, 1, 2, 2, 2, 3, 3, 4]),
# col=tensor([2, 2, 0, 1, 3, 2, 4, 3]),
# val=tensor([1., 2., 3., 4., 5., 6., 7., 8.]),
# size=(5, 5), nnz=8, density=32.00%)
# SparseTensor(row=tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 6, 6, 7]),
# col=tensor([2, 2, 0, 1, 3, 0, 1, 3, 0, 1, 3, 2, 4, 2, 4, 3]),
# val=tensor([1., 2., 3., 4., 5., 3., 4., 5., 3., 4., 5., 6., 7., 6., 7., 8.]),
# size=(8, 5), nnz=16, density=40.00%)
# SparseTensor(row=tensor([0, 0, 0]),
# col=tensor([0, 1, 3]),
# val=tensor([3., 4., 5.]),
# size=(1, 5), nnz=3, density=60.00%)