报错batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
时间: 2025-03-07 11:16:20 浏览: 141
### 解决 PyTorch DataLoader 类型错误
当使用 PyTorch 的 `DataLoader` 加载数据时,如果遇到如下错误:
```
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
```
这通常意味着自定义的数据集中返回的数据不是张量或其他支持的类型。为了修正这个问题,可以采取以下措施。
#### 数据预处理函数调整
确保在构建数据集类时,在 `__getitem__()` 方法内将图像转换成张量。可以通过调用 torchvision.transforms.ToTensor() 来实现这一点[^3]。
```python
from torchvision import transforms
transform = transforms.Compose([
transforms.ToTensor(),
])
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, ...):
self.transform = transform
def __getitem__(self, index):
img_path = ...
label = ...
image = Image.open(img_path).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, label
```
通过上述方法,任何传给默认收集器(`default_collate`)的对象都将被正确地转化为张量形式,从而避免了原始类型的不兼容问题[^4]。
#### 自定义 collate_fn 函数
另一种解决方案是在创建 `DataLoader` 实例的时候提供一个自定义的 `collate_fn` 参数。这个参数允许用户指定如何组合一批样本成为一个批次(batch)[^1]。
```python
def my_collate(batch):
data = [item[0] for item in batch]
target = [item[1] for item in batch]
# 假设data是一个列表中的图片对象(PIL),target是标签
transformed_data = torch.stack([transforms.ToTensor()(d) for d in data])
return (transformed_data, target)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=64,
shuffle=True,
collate_fn=my_collate # 使用自定义的collate函数
)
```
这种方法提供了更大的灵活性,特别是在处理复杂或特殊结构化输入的情况下非常有用。
#### 验证代码片段
下面是一段简单的验证代码,用于确认修改后的逻辑能够正常工作并打印出每个迭代周期内的输入尺寸和标签数量[^5]。
```python
for epoch in range(2):
for i, data in enumerate(train_loader):
inputs, labels = data
print(f"Epoch {epoch}, Batch {i}: Inputs shape={inputs.shape} Labels count={len(labels)}")
```
阅读全文
相关推荐















