pytorch导入数据集
时间: 2025-04-17 13:46:16 浏览: 19
### 导入数据集
在 PyTorch 中,可以利用 `torchvision.datasets` 来加载常用的数据集,对于自定义图像数据集,则可以通过继承 `torch.utils.data.Dataset` 类来创建自己的数据集类[^2]。
为了处理特定于项目的猫狗分类问题,通常会构建一个专门的 Dataset 类。下面是一个简单的例子:
```python
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image
class CatsDogsDataset(Dataset):
"""Cats and Dogs dataset."""
def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied on a sample.
"""
self.annotations = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir,
self.annotations.iloc[idx, 0])
image = Image.open(img_name)
label = self.annotations.iloc[idx, 1]
if self.transform:
image = self.transform(image)
return image, label
```
此代码片段展示了如何通过读取 CSV 文件中的路径和标签信息来加载图片并应用转换操作。这里假设 `cats_dogs.csv` 包含两列:一列为文件名,另一列为对应的类别标签。
接着,为了让模型能够接收这些数据,在实例化上述类之后还需要将其封装到 DataLoader 对象里以便批量迭代访问:
```python
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
dataset = CatsDogsDataset(
csv_file='path/to/cats_dogs.csv',
root_dir='path/to/cats-dogs-resized/',
transform=transform
)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
```
这样就可以方便地获取批次化的输入张量用于后续训练过程了。
阅读全文
相关推荐


















