Python将实例对象作为方法使用

前往我的个人博客,阅读体验更佳:本文链接

问题

刚接触神经网络,使用nn.Module,看到下面代码:

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub

dataset = KarateClub()


class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  # Final GNN embedding space.

        # Apply a final (linear) classifier.
        out = self.classifier(h)

        return out, h


model = GCN()
print(model)
data = dataset[0]
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')

前面都好理解,主要是第32行,model是GCN类的一个实例,也可以理解。

但是第35行model(data.x, data.edge_index)有点搞不明白了,为什么类的实例还能传参数?

查询了一些资料,发现Python通过一个特殊函数__call__()让类实例也可以变成一个可调用对象。

__init__

class A:
    def __init__(self):
        print('init函数')

    def __call__(self, param):
        print('call 函数', param)


a = A()

输出:

在这里插入图片描述

分析:a=A()进行了类的实例化,会自动调用__init__()方法。

__call__

class A:
    def __init__(self):
        print('init函数')

    def __call__(self, param):
        print('call 函数', param)


a = A()
a(1)

输出:

在这里插入图片描述

分析:a是类A的实例对象,a(1)相当于调用了实例(不知道这么说对不对,意思就是实例对象也可以被调用,后面加括号传参数),会自动调用__call__()方法。

__call__()中可以调用其它函数,如forward函数:

class A():
    def __init__(self):
        print('init函数')

    def __call__(self, param):
        print('call 函数', param)
        res = self.forward(param)
        return res

    def forward(self, input):
        print('forward 函数', input)
        output = input + 1
        return output


a = A()
b = a(1)
print('结果b =', b)

输出:

在这里插入图片描述

到这就有nn.Module那味了,下面这个例子更接近文章开头展示的内容:

import torch


class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        print('init函数')

    def forward(self, input):
        print('forward 函数', input)
        output = input + 1
        return output


a = A()
b = a(1)
print('结果b =', b)

输出:

在这里插入图片描述

这里并没有调用__call__()(甚至我们都没有实现),还是调用了forward()方法,原因是因为父类nn.Module实现了__call__()方法。

我们可以重写__call__()方法,让其不调用forward

import torch.nn


class A(torch.nn.Module):
    def __init__(self):
        super().__init__()
        print('init函数')

    def forward(self, input):
        print('forward 函数', input)
        output = input + 1
        return output

    def __call__(self, input):
        print('重写 call 函数', input)
        return input


a = A()
b = a(1)
print('结果b =', b)

输出:

在这里插入图片描述

参考链接:
https://blog.csdn.net/qq_43745026/article/details/125537774

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值