torch.utils.data.DataLoader
本节讲述collate_fn使用。
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
torch._C._log_api_usage_once("python.data_loader"))
collate_fn官网介绍:
merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.意思是对传入的dataset进行额外处理。
如下官网的一个collate模块,从中可以看出,对数据做各种转化。
default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel

本文详细解析了PyTorch中DataLoader组件及其collate_fn参数的使用方法,通过实例展示了如何自定义collate_fn以处理不同大小的图像批次,并提供了将数据集转换为Tensor和创建mini-batch的具体步骤。
最低0.47元/天 解锁文章
5073

被折叠的 条评论
为什么被折叠?



