Hands-on Graph Neural Networks with PyTorch & PyTorch Geometric

This article introduces Message Passing, the core component of graph neural networks (GNNs), and walks through the construction of a SageConv layer, including its message, aggregation, and update functions. It demonstrates how node features interact through linear transformations and activation functions built with torch.nn.


MessagePassing

Message passing is the essence of GNNs; it describes how node embeddings are learned. I talked about it in my last post, so here I will just briefly run through it using terms that conform to the PyG documentation.
x_i^{(k)} = \gamma^{(k)} \left( x_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)} \left( x_i^{(k-1)}, x_j^{(k-1)}, e_{i,j} \right) \right)
x denotes the node embeddings, e denotes the edge features, 𝜙 denotes the message function, □ denotes the aggregation function, and 𝛾 denotes the update function. If the edges in the graph have no features other than connectivity, e is essentially the edge index of the graph. The superscript represents the index of the layer; when k = 1, x^(k-1) represents the input features of each node. Below I will illustrate how each function works:

  • propagate(edge_index, size=None, **kwargs):
    It takes in the edge index and other optional information, such as node features (embeddings). Calling this function will consequently call message and update.

  • message(**kwargs):
    You specify how you construct the “message” for each node pair (x_i, x_j). Since it follows the call to propagate, it can take any argument that was passed to propagate. One thing to note is that you can define the mapping from arguments to specific nodes with the “_i” and “_j” suffixes (illustrated in the sketch below). Therefore, you must be very careful when naming the arguments of this function.

  • update(aggr_out, **kwargs):
    It takes in the aggregated message and any other arguments passed into propagate, assigning a new embedding value to each node.
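
To make these three functions concrete before moving on, below is a minimal sketch of a custom layer (the class name DiffConv and its difference-based message are made up for illustration; the MessagePassing API itself is as described above):

import torch
from torch_geometric.nn import MessagePassing

class DiffConv(MessagePassing):
    def __init__(self):
        super(DiffConv, self).__init__(aggr='max')  # □ is an element-wise max

    def forward(self, x, edge_index):
        # Any keyword argument passed here is forwarded to message and update.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_i, x_j):
        # 𝜙: the "_j"/"_i" suffixes map x to the source/target node of each edge.
        return x_j - x_i

    def update(self, aggr_out):
        # 𝛾: here the new embedding is simply the aggregated message.
        return aggr_out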

Example

Let’s see how we can implement a SageConv layer from the paper “Inductive Representation Learning on Large Graphs”. The message passing formula of SageConv is defined as:
h_{\mathcal{N}(v)}^{k} \leftarrow \mathrm{AGGREGATE}_k\left(\left\{ h_u^{k-1}, \forall u \in \mathcal{N}(v) \right\}\right)
h_v^{k} \leftarrow \sigma\left( W^k \cdot \mathrm{CONCAT}\left( h_v^{k-1}, h_{\mathcal{N}(v)}^{k} \right) \right)

(https://arxiv.org/abs/1706.02216)

Here, we use max pooling as the aggregation method. Therefore, the right-hand side of the first line can be written as:
\mathrm{AGGREGATE}_k^{\mathrm{pool}} = \max\left(\left\{ \sigma\left( W_{\mathrm{pool}} h_{u_i}^{k} + b \right), \forall u_i \in \mathcal{N}(v) \right\}\right)

(https://arxiv.org/abs/1706.02216)

which illustrates how the “message” is constructed. Each neighboring node embedding is multiplied by a weight matrix, a bias is added, and the result is passed through an activation function. This can be easily done with torch.nn.Linear.

class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max')
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        
    def message(self, x_j):
        # x_j has shape [E, in_channels]

        x_j = self.lin(x_j)
        x_j = self.act(x_j)
      
        return x_j

As for the update part, the aggregated message and the current node embedding are concatenated. The result is then multiplied by another weight matrix and passed through another activation function.

class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max')
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()
        
    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]
        
        new_embedding = torch.cat([aggr_out, x], dim=1)
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        
        return new_embedding

Putting it together, we have the following SageConv layer.


import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max') #  "Max" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()
        
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Remove then re-add self-loops so each node also receives its own message.
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]

        x_j = self.lin(x_j)
        x_j = self.act(x_j)
        
        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]

        # Concatenate the aggregated message with the current node embedding,
        # then project back down; update_lin maps the concatenated vector to
        # in_channels, so the layer's output has in_channels dimensions.
        new_embedding = torch.cat([aggr_out, x], dim=1)
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)

        return new_embedding
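
As a quick sanity check (the shapes below are made up for illustration): because update_lin projects the concatenated vector back to in_channels, the layer's output embedding size equals in_channels rather than out_channels.

conv = SAGEConv(in_channels=16, out_channels=32)
x = torch.randn(4, 16)                    # 4 nodes with 16 features each
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]]) # two undirected edges
out = conv(x, edge_index)
print(out.shape)  # torch.Size([4, 16]), i.e. [N, in_channels]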