如何使用Pytorch Geometric的InMemoryDataset方法构建自己的图数据集
时间: 2024-12-19 11:26:57 浏览: 119
要使用PyTorch Geometric的`InMemoryDataset`方法构建自己的图数据集,首先需要了解几个关键概念:
1. `Data`类:这是PyTorch Geometric的核心,包含了图的各种属性,如节点特征(x)、边连接信息(edge_index)以及可能的标签(y)。
2. `InMemoryDataset`:这是一个抽象基类,用于存储预先加载到内存的数据。它需要子类实现两个方法:`__init__`来初始化数据结构,以及`getitem`来按索引访问数据。
下面是构建一个简单的图数据集的步骤:
- **定义数据元素**:
- `x` 是节点特征矩阵。
- `edge_index` 是边的连接矩阵,通常是一个二维张量,每个行代表一条边,第一列是起点,第二列是终点。
- 可选地,还可以有`y` 表示节点或整个图的标签。
```python
from torch_geometric.data import InMemoryDataset, Data
class CustomGraphDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(CustomGraphDataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = self.load_data()
def load_data(self):
# 使用示例数据填充Data实例
x = torch.tensor([[2, 1], [5, 6], [3, 7], [12, 0]], dtype=torch.float) # 节点特征
y = torch.tensor([0, 1, 0, 1], dtype=torch.float) # 标签 (可选)
edge_index = torch.tensor([[0, 2, 1, 0, 3], [3, 1, 0, 1, 2]], dtype=torch.long) # 边连接
data_list = [Data(x=x[i], y=y[i] if y is not None else None, edge_index=edge_index[i]) for i in range(len(x))]
return data_list, None # 假设数据不需要切分
# 创建并使用数据集
dataset = CustomGraphDataset(root='./data')
graph = dataset[0] # 获取第一个样本
```
- **使用预处理**(可选):如果需要在加载数据之前执行一些转换,可以在`pre_transform`方法中定义。
创建电商购买图数据集的示例:
```python
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) # 无向边表示
x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # 节点特征
custom_edge_dataset = CustomGraphDataset(
root='./data',
edge_index=edge_index,
x=x,
)
# 访问一个样本,假设无标签
buying_graph = custom_edge_dataset[0]
```
阅读全文
相关推荐


















