dgl分批次进行链接预测任务时,会用到dataloader类,可以参考一下代码定义自己的dataloader。
dataloader = dgl.dataloading.DistEdgeDataLoader(
# dataloader = dgl.dataloading.EdgeDataLoader()
graph, train_seeds, sampler,
negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
# negative_sampler=NegativeSampler(graph,args.negsample),
batch_size=4,
shuffle=True,
drop_last=False,
pin_memory=True,
num_workers=False)
在设置自己的负采样方法时,总是存在问题,详细看了这里的负采样方法negative_sampler,作为记录。通过一个小例子理解一下该方法的使用:
g = dgl.graph(([0, 1, 2], [1, 2, 3]))
neg_sampler = dgl.dataloading.negative_sampler.PerSourceUniform(4)
neg_sampler(g, torch.tensor([0]))
>>>(tensor([0, 0, 0, 0]), tensor([3, 0, 2, 1]))
由上述示例可以看出,neg_sampler的输出结果是一个<class ‘tuple’>,头部为源节点,尾部为目标节点。dgl.dataloading.negative_sampler.PerSourceUniform(4)中的参数N(例子中为4),是负采样的个数,neg_sampler使用时,需要输入图(g)还有选择的源节点(示例中为:torch.tensor([0]))。
主要了解到neg_sampler的结果是由两个tensor组成的tuple,所以自己在定义负采样方法的时候,也应该遵循这样的规则,否则会导致出错。