# ISIC2017 数据集读取 (ISIC 2017 dataset loading)

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
from collections import Counter

class ISIC2017Dataset(Dataset):
    """Map-style dataset over ISIC 2017 lesion images and their class labels.

    Each sample is ``(image, label)``. ``image`` is the RGB PIL image, or
    whatever ``transform`` returns for it. ``label`` is a 0-d ``torch.long``
    tensor:
    0 = melanoma, 1 = seborrheic keratosis, 2 = anything else (nevus /
    benign pigmented lesion).
    """

    def __init__(self, root_dir, labels_df, transform=None):
        """
        Args:
            root_dir (str): Directory containing the ``<image_id>.jpg`` files.
            labels_df (pd.DataFrame): One row per image. Column 0 holds the
                image id (file name without ``.jpg``). The columns after it
                hold indicator flags; the first is the melanoma flag and the
                second is the seborrheic-keratosis flag.
            transform (callable, optional): Applied to each PIL image before
                it is returned.
        """
        self.root_dir = root_dir
        self.labels_df = labels_df
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in ``labels_df``."""
        return len(self.labels_df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label_tensor)``."""
        if torch.is_tensor(idx):
            # Allow indexing with a 0-d tensor (e.g. from a sampler).
            idx = idx.tolist()

        image_id = self.labels_df.iloc[idx, 0]
        # f-string instead of ``+`` so non-str ids do not raise TypeError.
        img_path = os.path.join(self.root_dir, f"{image_id}.jpg")
        # The context manager releases the file handle deterministically.
        # convert() returns an independent in-memory image, so it remains
        # valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        flags = self.labels_df.iloc[idx, 1:].values.astype('float32')
        label = self._decode_label(flags)

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

    @staticmethod
    def _decode_label(flags):
        """Map indicator flags ``[melanoma, seborrheic_keratosis, ...]`` to a class id.

        The melanoma flag is checked first, so it wins if both flags are set.

        Returns:
            int: 0 for melanoma, 1 for seborrheic keratosis, 2 otherwise.
        """
        if flags[0] == 1:
            return 0  # melanoma
        if flags[1] == 1:
            return 1  # seborrheic_keratosis
        return 2  # nevus / benign pigmented lesion

def get_class_distribution(labels_df):
    """Count samples per class in an ISIC 2017 Part 3 label table.

    A row counts as melanoma if ``melanoma == 1``. Otherwise it counts as
    seborrheic keratosis if ``seborrheic_keratosis == 1``. Everything else
    counts as nevus. This is the same precedence ISIC2017Dataset uses when
    assigning labels.

    Args:
        labels_df (pd.DataFrame): Must contain ``melanoma`` and
            ``seborrheic_keratosis`` columns.

    Returns:
        dict[str, int]: Counts keyed ``"melanoma"``, ``"seborrheic_keratosis"``,
        ``"nevus"`` (in that order).
    """
    is_melanoma = labels_df["melanoma"] == 1
    # Exclude melanoma rows so a row with both flags set is counted only once.
    is_sk = (labels_df["seborrheic_keratosis"] == 1) & ~is_melanoma
    # Cast to plain int so printed dicts show "5", not "np.int64(5)".
    n_melanoma = int(is_melanoma.sum())
    n_sk = int(is_sk.sum())
    return {
        "melanoma": n_melanoma,
        "seborrheic_keratosis": n_sk,
        "nevus": len(labels_df) - n_melanoma - n_sk,
    }


# The script body runs under a __main__ guard because DataLoader workers
# (num_workers > 0) may re-import this module, e.g. under the "spawn" start
# method used on Windows, and must not re-run the script when they do.
if __name__ == "__main__":
    # Dataset location: set the ISIC2017_ROOT environment variable to override
    # the default path below.
    dataset_root = os.environ.get(
        "ISIC2017_ROOT", r'C:\Users\10596\Desktop\dataset\ISIC2017\ISIC2017'
    )
    train_image_dir = os.path.join(dataset_root, 'ISIC-2017_Training_Data')  # training images
    test_image_dir = os.path.join(dataset_root, 'ISIC-2017_Test_v2_Data')  # test images
    training_labels_path = os.path.join(
        dataset_root, 'ISIC-2017_Training_Part3_GroundTruth.csv'
    )  # training labels CSV
    testing_labels_path = os.path.join(
        dataset_root, 'ISIC-2017_Test_v2_Part3_GroundTruth.csv'
    )  # test labels CSV

    # Read the per-image labels from the CSV files.
    training_labels = pd.read_csv(training_labels_path)
    testing_labels = pd.read_csv(testing_labels_path)

    # Print the first few rows to confirm the files were read correctly.
    print("Training labels:")
    print(training_labels.head())
    print("Testing labels:")
    print(testing_labels.head())

    # Preprocessing: resize to 224x224, convert to a tensor, and normalise with
    # the ImageNet channel mean/std commonly used with torchvision models.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Training and test datasets.
    train_dataset = ISIC2017Dataset(root_dir=train_image_dir, labels_df=training_labels, transform=transform)
    test_dataset = ISIC2017Dataset(root_dir=test_image_dir, labels_df=testing_labels, transform=transform)

    # Data loaders (test_loader is built for later evaluation; unused below).
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    # Class distribution of the training set.
    train_distribution = get_class_distribution(training_labels)
    print("Training set class distribution:")
    print(train_distribution)

    # Class distribution of the test set.
    test_distribution = get_class_distribution(testing_labels)
    print("Testing set class distribution:")
    print(test_distribution)

    # Example: iterate over the DataLoader.
    for images, labels in train_loader:
        print(f"Images batch shape: {images.size()}")
        print(f"Labels batch shape: {labels.size()}")
        break  # show only the first batch; real training iterates over all batches

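# NOTE: the lines below are residue from the web page this script was copied
# from (comment box and "red packet" tipping widget). They are not part of the
# program and are kept only as comments so the file remains valid Python.
# 评论
# 添加红包

# 请填写红包祝福语或标题

# 红包个数最小为10个

# 红包金额最低5元

# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0

# 抵扣说明:

# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

# 余额充值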