torch_sparse.SparseTensor函数索引

torch_sparse.SparseTensor函数通过row索引稀疏矩阵得到一个新的稀疏矩阵

from torch_sparse import SparseTensor
import torch

row = torch.tensor([0, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([2, 2, 0, 1, 3, 2, 4, 3])
val = torch.linspace(1, 8, 8)
c = SparseTensor(row=row, col=col, value=val)
print('c = {}'.format(c.to_dense()))
print(c)

d2 = c[row]    # 通过行索引c的特定行得到一个新的稀疏矩阵。
# 这个操作相当于下面的步骤:
# 将原密集张量的第 0 行取出构成新张量的第 0 行;
# 将原密集张量的第 1 行取出构成新张量的第 1 行;
# 将原密集张量的第 2 行取出构成新张量的第 2 行(第 2 行有三个元素);
# 将原密集张量的第 2 行取出构成新张量的第 3 行;
# 将原密集张量的第 2 行取出构成新张量的第 4 行;
# ...
print(d2)
d3 = c[torch.tensor([2])]
print(d3)

# c = tensor([[0., 0., 1., 0., 0.],
#         [0., 0., 2., 0., 0.],
#         [3., 4., 0., 5., 0.],
#         [0., 0., 6., 0., 7.],
#         [0., 0., 0., 8., 0.]])
# SparseTensor(row=tensor([0, 1, 2, 2, 2, 3, 3, 4]),
#              col=tensor([2, 2, 0, 1, 3, 2, 4, 3]),
#              val=tensor([1., 2., 3., 4., 5., 6., 7., 8.]),
#              size=(5, 5), nnz=8, density=32.00%)
# SparseTensor(row=tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 6, 6, 7]),
#              col=tensor([2, 2, 0, 1, 3, 0, 1, 3, 0, 1, 3, 2, 4, 2, 4, 3]),
#              val=tensor([1., 2., 3., 4., 5., 3., 4., 5., 3., 4., 5., 6., 7., 6., 7., 8.]),
#              size=(8, 5), nnz=16, density=40.00%)
# SparseTensor(row=tensor([0, 0, 0]),
#              col=tensor([0, 1, 3]),
#              val=tensor([3., 4., 5.]),
#              size=(1, 5), nnz=3, density=60.00%)
### 获取 PyTorchSparse COO 张量的元素 在 PyTorch 中,Sparse COO (Coordinate Format) 张量存储非零元素及其对应的索引位置。要访问这些张量中的特定元素,可以通过其 `indices` 和 `values` 属性来实现。 对于给定的一个稀疏张量 `sp_adj`,可以先通过 `.indices()` 方法获得该张量中所有非零元素的位置,再利用 `.values()` 获得相应位置上的值[^4]。然而,直接获取某个具体坐标的单个元素并不像密集型张量那样直观,因为这涉及到查找操作。一种方法是构建一个掩码或者使用高级索引来匹配目标坐标并提取相应的值。 下面是一个简单的例子展示如何创建一个稀疏COO张量以及尝试从中检索特定位置的数据: ```python import torch # 创建一个具有重复索引的稀疏 COO 张量 sp_adj = torch.sparse_coo_tensor( indices=[[1, 1, 1, 2]], values=[2., 4., 6., 9.], size=(3,) ) print(f"Sparse Tensor:\n{sp_adj}") def get_sparse_element(sparse_tensor, index): """从稀疏 COO 张量中获取指定索引处的元素""" idx_mask = all(dim_idx == sparse_tensor.indices().squeeze() for dim_idx in enumerate(index)) matching_values = sparse_tensor.values()[idx_mask].sum() return matching_values.item() if matching_values.numel() > 0 else None index_to_find = [1] element_at_index = get_sparse_element(sp_adj, index_to_find) if element_at_index is not None: print(f"Element at index {index_to_find}: {element_at_index}") else: print(f"No single value found exactly at index {index_to_find}.") ``` 上述代码定义了一个辅助函数 `get_sparse_element` 来帮助定位和求和位于相同索引下的所有条目,并返回最终的结果作为所请求位置的实际数值表示形式。需要注意的是,在这个特殊情况下,由于存在多个相同的索引 `[1]`,因此结果将是这些位置上所有值之和。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值