前往我的个人博客,阅读体验更佳:本文链接
问题
刚接触神经网络,使用nn.Module
,看到下面代码:
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
torch.manual_seed(12345)
self.conv1 = GCNConv(dataset.num_features, 4)
self.conv2 = GCNConv(4, 4)
self.conv3 = GCNConv(4, 2)
self.classifier = Linear(2, dataset.num_classes)
def forward(self, x, edge_index):
h = self.conv1(x, edge_index)
h = h.tanh()
h = self.conv2(h, edge_index)
h = h.tanh()
h = self.conv3(h, edge_index)
h = h.tanh() # Final GNN embedding space.
# Apply a final (linear) classifier.
out = self.classifier(h)
return out, h
model = GCN()
print(model)
data = dataset[0]
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')
前面都好理解,主要是第32行,model是GCN类的一个实例,也可以理解。
但是第35行model(data.x, data.edge_index)
有点搞不明白了,为什么类的实例还能传参数?
查询了一些资料,发现Python通过一个特殊函数__call__()
让类实例也可以变成一个可调用对象。
__init__
class A:
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
输出:
分析:a=A()
进行了类的实例化,会自动调用__init__()
方法。
__call__
class A:
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
a(1)
输出:
分析:a
是类A
的实例对象,a(1)
相当于调用了实例(不知道这么说对不对,意思就是实例对象也可以被调用,后面加括号传参数),会自动调用__call__()
方法。
__call__()
中可以调用其它函数,如forward
函数:
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res
def forward(self, input):
print('forward 函数', input)
output = input + 1
return output
a = A()
b = a(1)
print('结果b =', b)
输出:
到这就有nn.Module
那味了,下面这个例子更接近文章开头展示的内容:
import torch
class A(torch.nn.Module):
def __init__(self):
super().__init__()
print('init函数')
def forward(self, input):
print('forward 函数', input)
output = input + 1
return output
a = A()
b = a(1)
print('结果b =', b)
输出:
这里并没有调用__call__()
(甚至我们都没有实现),还是调用了forward()
方法,原因是因为父类nn.Module
实现了__call__()
方法。
我们可以重写__call__()
方法,让其不调用forward
:
import torch.nn
class A(torch.nn.Module):
def __init__(self):
super().__init__()
print('init函数')
def forward(self, input):
print('forward 函数', input)
output = input + 1
return output
def __call__(self, input):
print('重写 call 函数', input)
return input
a = A()
b = a(1)
print('结果b =', b)
输出:
参考链接:
https://blog.csdn.net/qq_43745026/article/details/125537774