【dgl学习】dgl.dataloader中链接预测任务的负采样方法neg_sampler详解

dgl分批次进行链接预测任务时,会用到dataloader类,可以参考一下代码定义自己的dataloader。

dataloader = dgl.dataloading.DistEdgeDataLoader(
# dataloader = dgl.dataloading.EdgeDataLoader()
    graph, train_seeds, sampler,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
    # negative_sampler=NegativeSampler(graph,args.negsample),
    batch_size=4,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
    num_workers=False)

在设置自己的负采样方法时,总是存在问题,详细看了这里的负采样方法negative_sampler,作为记录。通过一个小例子理解一下该方法的使用:

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
neg_sampler = dgl.dataloading.negative_sampler.PerSourceUniform(4)
neg_sampler(g, torch.tensor([0]))
>>>(tensor([0, 0, 0, 0]), tensor([3, 0, 2, 1]))

由上述示例可以看出,neg_sampler的输出结果是一个<class ‘tuple’>,头部为源节点,尾部为目标节点。dgl.dataloading.negative_sampler.PerSourceUniform(4)中的参数N(例子中为4),是负采样的个数,neg_sampler使用时,需要输入图(g)还有选择的源节点(示例中为:torch.tensor([0]))。
主要了解到neg_sampler的结果是由两个tensor组成的tuple,所以自己在定义负采样方法的时候,也应该遵循这样的规则,否则会导致出错。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值