云计算、AI、云原生、大数据等一站式技术学习平台

网站首页 > 教程文章 正文

Pytorch入门-Day9:数据加载(Dataset与DataLoader)

jxf315 2025-05-28 18:08:31 教程文章 5 ℃

学习目标

完成本课程后,你将能够:

  1. 理解 PyTorch 中 Dataset 和 DataLoader 的作用。
  2. 创建并使用自定义数据集。
  3. 从网上下载并加载标准数据集(MNIST 和 CIFAR-10)。
  4. 使用 DataLoader 按批次加载数据。
  5. 可视化一个批次的数据。

关键术语

  1. Dataset(数据集):PyTorch 的 torch.utils.data.Dataset 类,用于表示数据集,提供访问样本和标签的接口。
  2. DataLoader(数据加载器):PyTorch 的 torch.utils.data.DataLoader 工具,用于批次加载、数据打乱和并行处理。
  3. Batch(批次):一次处理的样本子集,减少内存占用,加速计算。
  4. Transform(变换):数据预处理函数,如转换为张量、归一化。
  5. Tensor(张量):PyTorch 的多维数组,用于存储数据和模型参数。

前置要求

  • 安装 PyTorch 和 torchvision:bash
  • pip install torch torchvision
  • 安装 matplotlib(用于可视化):bash
  • pip install matplotlib
  • 基本 Python 和 NumPy 知识。

理解数据集格式

PyTorch 数据集是一组样本,每个样本包含:

  • 特征:输入数据(如图像张量)。
  • 标签:目标值(如分类标签)。

数据集类型

  1. 标准数据集(如 MNIST、CIFAR-10):通过 torchvision 下载,自动转换为张量。
  2. 自定义数据集:用户创建,需实现 Dataset 类,加载本地文件或生成数据。

1. 生成并使用自定义数据集

步骤 1:生成模拟数据集

我们将生成一个简单的自定义数据集,包含:

  • 100 张 28x28 像素的灰度“图像”(随机噪声,模拟手写数字)。
  • 对应的标签(0–9 的随机整数)。

python

import os
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image

# 创建目录
os.makedirs("custom_data/images", exist_ok=True)

# 生成 100 张模拟图像和标签
num_samples = 100
labels = np.random.randint(0, 10, size=num_samples)  # 随机标签 (0-9)
for i in range(num_samples):
    # 生成随机灰度图像 (28x28)
    img = np.random.rand(28, 28) * 255
    img = Image.fromarray(img.astype(np.uint8))
    img.save(f"custom_data/images/img_{i}.png")
    # 保存标签到文件
    with open("custom_data/labels.txt", "a") as f:
        f.write(f"img_{i}.png {labels[i]}\\n")

步骤 2:定义自定义数据集

实现 CustomImageDataset 类来加载生成的图像和标签。

python

from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, label_file, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # 读取标签文件
        self.labels = []
        with open(label_file, "r") as f:
            for line in f:
                img_name, label = line.strip().split()
                self.labels.append((img_name, int(label)))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_name, label = self.labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("L")  # 灰度图像
        if self.transform:
            image = self.transform(image)
        return image, label

步骤 3:加载并可视化自定义数据集

使用 DataLoader 加载数据并可视化一个批次。

python

from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 定义变换
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

# 创建数据集
custom_dataset = CustomImageDataset(
    img_dir="custom_data/images",
    label_file="custom_data/labels.txt",
    transform=transform
)

# 创建 DataLoader
custom_loader = DataLoader(custom_dataset, batch_size=16, shuffle=True)

# 获取一个批次
images, labels = next(iter(custom_loader))
print(f"自定义批次形状: {images.shape}")  # [16, 1, 28, 28]
print(f"标签: {labels[:10]}")

# 可视化前 6 张图像
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for i in range(6):
    axes[i].imshow(images[i].squeeze(), cmap="gray")
    axes[i].set_title(f"标签: {labels[i].item()}")
    axes[i].axis("off")
plt.show()

输出说明

  • 批次形状:[16, 1, 28, 28] 表示 16 张图像,1 个通道,28x28 像素。
  • 图像:随机噪声图像,模拟手写数字。
  • 标签:随机生成的 0–9 整数。

2. 从网上下载并加载标准数据集

示例 1:MNIST 数据集

MNIST 是一个手写数字数据集,包含 60,000 张训练图像和 10,000 张测试图像。

python

# 加载 MNIST
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# 获取一个批次
images, labels = next(iter(train_loader))
print(f"MNIST 批次形状: {images.shape}")  # [16, 1, 28, 28]
print(f"标签: {labels[:10]}")

# 可视化前 6 张图像
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for i in range(6):
    axes[i].imshow(images[i].squeeze(), cmap="gray")
    axes[i].set_title(f"标签: {labels[i].item()}")
    axes[i].axis("off")
plt.show()

示例 2:CIFAR-10 数据集

CIFAR-10 包含 60,000 张 32x32 像素的彩色图像,分为 10 个类别(例如猫、狗、飞机)。

python

# 定义 CIFAR-10 变换(RGB 图像需要 3 通道归一化)
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # RGB 通道归一化
])

# 加载 CIFAR-10
cifar_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=cifar_transform
)

cifar_loader = DataLoader(cifar_dataset, batch_size=16, shuffle=True)

# 获取一个批次
images, labels = next(iter(cifar_loader))
print(f"CIFAR-10 批次形状: {images.shape}")  # [16, 3, 32, 32]
print(f"标签: {labels[:10]}")

# 可视化前 6 张图像
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
class_names = ["飞机", "汽车", "鸟", "猫", "鹿", "狗", "青蛙", "马", "船", "卡车"]
for i in range(6):
    img = images[i].permute(1, 2, 0) * 0.5 + 0.5  # 反归一化并调整通道顺序
    axes[i].imshow(img.numpy())
    axes[i].set_title(f"标签: {class_names[labels[i].item()]}")
    axes[i].axis("off")
plt.show()

输出说明

  • MNIST:灰度图像,形状 [16, 1, 28, 28],标签为 0–9。
  • CIFAR-10:彩色图像,形状 [16, 3, 32, 32],标签为 0–9(对应 10 个类别)。
  • 可视化:CIFAR-10 需要反归一化和通道调整(permute)以正确显示 RGB 图像。

练习任务

  1. 修改自定义数据集: 增加样本数量到 200,并生成 32x32 像素的图像。 更改批次大小为 32,观察输出形状。
  2. 扩展标准数据集: 在 MNIST 或 CIFAR-10 的 DataLoader 中禁用 shuffle,检查样本顺序。 显示 12 张 CIFAR-10 图像(3x4 网格)。
  3. 创建真实自定义数据集: 收集 10 张本地图像(例如猫狗照片)。 创建 labels.txt(格式:image_name label),并加载测试。

资源推荐

  • PyTorch 数据加载教程
  • Torchvision 数据集
  • 自定义数据集指南
  • CIFAR-10 数据集

优化说明

  1. 结构清晰:先展示自定义数据集的生成和使用,再介绍标准数据集,逻辑更流畅。
  2. 多样性:增加了 CIFAR-10 示例,展示如何处理彩色图像。
  3. 初学者友好:代码注释详细,术语解释简洁,练习任务分层。
  4. 实用性:生成模拟数据集无需外部文件,方便直接运行。

给初学者的提示

  • 变换:灰度图像归一化使用单通道均值/标准差,RGB 图像需要三通道。
  • 批次大小:从 16 或 32 开始,避免内存问题。
  • 调试:经常打印 shape 和 labels 检查数据是否正确。
  • 系统兼容:Windows 用户若遇到 num_workers 错误,设为 0。

Tags:

最近发表
标签列表