网站首页 > 教程文章 正文
学习目标
完成本课程后,你将能够:
- 理解 PyTorch 中 Dataset 和 DataLoader 的作用。
- 创建并使用自定义数据集。
- 从网上下载并加载标准数据集(MNIST 和 CIFAR-10)。
- 使用 DataLoader 按批次加载数据。
- 可视化一个批次的数据。
关键术语
- Dataset(数据集):PyTorch 的 torch.utils.data.Dataset 类,用于表示数据集,提供访问样本和标签的接口。
- DataLoader(数据加载器):PyTorch 的 torch.utils.data.DataLoader 工具,用于批次加载、数据打乱和并行处理。
- Batch(批次):一次处理的样本子集,减少内存占用,加速计算。
- Transform(变换):数据预处理函数,如转换为张量、归一化。
- Tensor(张量):PyTorch 的多维数组,用于存储数据和模型参数。
前置要求
- 安装 PyTorch 和 torchvision:bash
- pip install torch torchvision
- 安装 matplotlib(用于可视化):bash
- pip install matplotlib
- 基本 Python 和 NumPy 知识。
理解数据集格式
PyTorch 数据集是一组样本,每个样本包含:
- 特征:输入数据(如图像张量)。
- 标签:目标值(如分类标签)。
数据集类型
- 标准数据集(如 MNIST、CIFAR-10):通过 torchvision 下载,自动转换为张量。
- 自定义数据集:用户创建,需实现 Dataset 类,加载本地文件或生成数据。
1. 生成并使用自定义数据集
步骤 1:生成模拟数据集
我们将生成一个简单的自定义数据集,包含:
- 100 张 28x28 像素的灰度“图像”(随机噪声,模拟手写数字)。
- 对应的标签(0–9 的随机整数)。
python
import os
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
# 创建目录
os.makedirs("custom_data/images", exist_ok=True)
# 生成 100 张模拟图像和标签
num_samples = 100
labels = np.random.randint(0, 10, size=num_samples) # 随机标签 (0-9)
for i in range(num_samples):
# 生成随机灰度图像 (28x28)
img = np.random.rand(28, 28) * 255
img = Image.fromarray(img.astype(np.uint8))
img.save(f"custom_data/images/img_{i}.png")
# 保存标签到文件
with open("custom_data/labels.txt", "a") as f:
f.write(f"img_{i}.png {labels[i]}\\n")
步骤 2:定义自定义数据集
实现 CustomImageDataset 类来加载生成的图像和标签。
python
from torchvision import transforms
class CustomImageDataset(Dataset):
def __init__(self, img_dir, label_file, transform=None):
self.img_dir = img_dir
self.transform = transform
# 读取标签文件
self.labels = []
with open(label_file, "r") as f:
for line in f:
img_name, label = line.strip().split()
self.labels.append((img_name, int(label)))
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
img_name, label = self.labels[idx]
img_path = os.path.join(self.img_dir, img_name)
image = Image.open(img_path).convert("L") # 灰度图像
if self.transform:
image = self.transform(image)
return image, label
步骤 3:加载并可视化自定义数据集
使用 DataLoader 加载数据并可视化一个批次。
python
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 定义变换
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化
])
# 创建数据集
custom_dataset = CustomImageDataset(
img_dir="custom_data/images",
label_file="custom_data/labels.txt",
transform=transform
)
# 创建 DataLoader
custom_loader = DataLoader(custom_dataset, batch_size=16, shuffle=True)
# 获取一个批次
images, labels = next(iter(custom_loader))
print(f"自定义批次形状: {images.shape}") # [16, 1, 28, 28]
print(f"标签: {labels[:10]}")
# 可视化前 6 张图像
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for i in range(6):
axes[i].imshow(images[i].squeeze(), cmap="gray")
axes[i].set_title(f"标签: {labels[i].item()}")
axes[i].axis("off")
plt.show()
输出说明
- 批次形状:[16, 1, 28, 28] 表示 16 张图像,1 个通道,28x28 像素。
- 图像:随机噪声图像,模拟手写数字。
- 标签:随机生成的 0–9 整数。
2. 从网上下载并加载标准数据集
示例 1:MNIST 数据集
MNIST 是一个手写数字数据集,包含 60,000 张训练图像和 10,000 张测试图像。
python
# 加载 MNIST
train_dataset = datasets.MNIST(
root="./data",
train=True,
download=True,
transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# 获取一个批次
images, labels = next(iter(train_loader))
print(f"MNIST 批次形状: {images.shape}") # [16, 1, 28, 28]
print(f"标签: {labels[:10]}")
# 可视化前 6 张图像
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for i in range(6):
axes[i].imshow(images[i].squeeze(), cmap="gray")
axes[i].set_title(f"标签: {labels[i].item()}")
axes[i].axis("off")
plt.show()
示例 2:CIFAR-10 数据集
CIFAR-10 包含 60,000 张 32x32 像素的彩色图像,分为 10 个类别(例如猫、狗、飞机)。
python
# 定义 CIFAR-10 变换(RGB 图像需要 3 通道归一化)
cifar_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # RGB 通道归一化
])
# 加载 CIFAR-10
cifar_dataset = datasets.CIFAR10(
root="./data",
train=True,
download=True,
transform=cifar_transform
)
cifar_loader = DataLoader(cifar_dataset, batch_size=16, shuffle=True)
# 获取一个批次
images, labels = next(iter(cifar_loader))
print(f"CIFAR-10 批次形状: {images.shape}") # [16, 3, 32, 32]
print(f"标签: {labels[:10]}")
# 可视化前 6 张图像
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
class_names = ["飞机", "汽车", "鸟", "猫", "鹿", "狗", "青蛙", "马", "船", "卡车"]
for i in range(6):
img = images[i].permute(1, 2, 0) * 0.5 + 0.5 # 反归一化并调整通道顺序
axes[i].imshow(img.numpy())
axes[i].set_title(f"标签: {class_names[labels[i].item()]}")
axes[i].axis("off")
plt.show()
输出说明
- MNIST:灰度图像,形状 [16, 1, 28, 28],标签为 0–9。
- CIFAR-10:彩色图像,形状 [16, 3, 32, 32],标签为 0–9(对应 10 个类别)。
- 可视化:CIFAR-10 需要反归一化和通道调整(permute)以正确显示 RGB 图像。
练习任务
- 修改自定义数据集: 增加样本数量到 200,并生成 32x32 像素的图像。 更改批次大小为 32,观察输出形状。
- 扩展标准数据集: 在 MNIST 或 CIFAR-10 的 DataLoader 中禁用 shuffle,检查样本顺序。 显示 12 张 CIFAR-10 图像(3x4 网格)。
- 创建真实自定义数据集: 收集 10 张本地图像(例如猫狗照片)。 创建 labels.txt(格式:image_name label),并加载测试。
资源推荐
- PyTorch 数据加载教程
- Torchvision 数据集
- 自定义数据集指南
- CIFAR-10 数据集
优化说明
- 结构清晰:先展示自定义数据集的生成和使用,再介绍标准数据集,逻辑更流畅。
- 多样性:增加了 CIFAR-10 示例,展示如何处理彩色图像。
- 初学者友好:代码注释详细,术语解释简洁,练习任务分层。
- 实用性:生成模拟数据集无需外部文件,方便直接运行。
给初学者的提示
- 变换:灰度图像归一化使用单通道均值/标准差,RGB 图像需要三通道。
- 批次大小:从 16 或 32 开始,避免内存问题。
- 调试:经常打印 shape 和 labels 检查数据是否正确。
- 系统兼容:Windows 用户若遇到 num_workers 错误,设为 0。
猜你喜欢
- 2025-05-28 21-Python-文件操作
- 2025-05-28 为你的python程序上锁:软件序列号生成器
- 2025-05-28 用Python做个“冰墩墩雪容融”桌面部件(好玩又有趣)
- 2025-05-28 Dify工具使用全场景:通过文本生成word的指南(功能篇·第4期)
- 2025-05-28 2025年必学的Python自动化办公的15个实用脚本
- 2025-05-28 自学Python第二天
- 2025-05-28 ScalersTalk 成长会 Python 小组第 9 周学习笔记
- 2025-05-28 怎么做到的?用python制作九宫格图片,太棒了
- 2025-05-28 利用Dask构建端到端数据处理:从数据摄取到数据库加载的实战指南
- 2025-05-28 每日自动备份文件
- 最近发表
- 标签列表
-
- location.href (44)
- document.ready (36)
- git checkout -b (34)
- 跃点数 (35)
- 阿里云镜像地址 (33)
- qt qmessagebox (36)
- mybatis plus page (35)
- vue @scroll (38)
- 堆栈区别 (33)
- 什么是容器 (33)
- sha1 md5 (33)
- navicat导出数据 (34)
- 阿里云acp考试 (33)
- 阿里云 nacos (34)
- redhat官网下载镜像 (36)
- srs服务器 (33)
- pico开发者 (33)
- https的端口号 (34)
- vscode更改主题 (35)
- 阿里云资源池 (34)
- os.path.join (33)
- redis aof rdb 区别 (33)
- 302跳转 (33)
- http method (35)
- js array splice (33)