求知 文章 文库 Lib 视频 iPerson 课程 认证 咨询 工具 讲座 Modeler   Code  
会员   
 


业务架构设计
4月18-19日 在线直播



基于UML和EA进行系统分析设计
4月25-26日 北京+在线



AI 智能化软件测试方法与实践
5月23-24日 上海+在线
 
追随技术信仰

随时听讲座
每天看新闻
 
 
PyTorch 教程
1. PyTorch 简介
2. PyTorch 安装
3. PyTorch 基础
4. PyTorch 张量
5. PyTorch 神经网络基础
6. PyTorch 第一个神经网络
7. PyTorch 数据处理与加载
8. PyTorch 线性回归
9. PyTorch 卷积神经网络
 

 
目录
PyTorch 数据处理与加载
24 次浏览
3次  

在 PyTorch 中,处理和加载数据是深度学习训练过程中的关键步骤。

为了高效地处理数据,PyTorch 提供了强大的工具,包括 torch.utils.data.Dataset 和 torch.utils.data.DataLoader,帮助我们管理数据集、批量加载和数据增强等任务。

PyTorch 数据处理与加载的介绍:

  • 自定义 Dataset:通过继承 torch.utils.data.Dataset 来加载自己的数据集。
  • DataLoader:DataLoader 按批次加载数据,支持多线程加载并进行数据打乱。
  • 数据预处理与增强:使用 torchvision.transforms 进行常见的图像预处理和增强操作,提高模型的泛化能力。
  • 加载标准数据集:torchvision.datasets 提供了许多常见的数据集,简化了数据加载过程。
  • 多个数据源:通过组合多个 Dataset 实例来处理来自不同来源的数据。

自定义 Dataset

torch.utils.data.Dataset 是一个抽象类,允许你从自己的数据源中创建数据集。

我们需要继承该类并实现以下两个方法:

  • __len__(self):返回数据集中的样本数量。
  • __getitem__(self, idx):通过索引返回一个样本。

假设我们有一个简单的 CSV 文件或一些列表数据,我们可以通过继承 Dataset 类来创建自己的数据集。

import torch
from torch.utils.data import Dataset

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, X_data, Y_data):
        """
        初始化数据集,X_data 和 Y_data 是两个列表或数组
        X_data: 输入特征
        Y_data: 目标标签
        """

        self.X_data = X_data
        self.Y_data = Y_data

    def __len__(self):
        """返回数据集的大小"""
        return len(self.X_data)

    def __getitem__(self, idx):
        """返回指定索引的数据"""
        x = torch.tensor(self.X_data[idx], dtype=torch.float32)  # 转换为 Tensor
        y = torch.tensor(self.Y_data[idx], dtype=torch.float32)
        return x, y

# 示例数据
X_data = [[1, 2], [3, 4], [5, 6], [7, 8]]  # 输入特征
Y_data = [1, 0, 1, 0]  # 目标标签

# 创建数据集实例
dataset = MyDataset(X_data, Y_data)

使用 DataLoader 加载数据

DataLoader 是 PyTorch 提供的一个重要工具,用于从 Dataset 中按批次(batch)加载数据。

DataLoader 允许我们批量读取数据并进行多线程加载,从而提高训练效率。

from torch.utils.data import DataLoader

# 创建 DataLoader 实例,batch_size 设置每次加载的样本数量
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 打印加载的数据
for epoch in range(1):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f'Batch {batch_idx + 1}:')
        print(f'Inputs: {inputs}')
        print(f'Labels: {labels}')
  • batch_size: 每次加载的样本数量。
  • shuffle: 是否对数据进行洗牌,通常训练时需要将数据打乱。
  • drop_last: 如果数据集中的样本数不能被 batch_size 整除,设置为 True 时,丢弃最后一个不完整的 batch。

输出:

Batch 1:
Inputs: tensor([[3., 4.], [1., 2.]])
Labels: tensor([0., 1.])
Batch 2:
Inputs: tensor([[7., 8.], [5., 6.]])
Labels: tensor([0., 1.])

每次循环中,DataLoader 会返回一个批次的数据,包括输入特征(inputs)和目标标签(labels)。


预处理与数据增强

数据预处理和增强对于提高模型的性能至关重要。

PyTorch 提供了 torchvision.transforms 模块来进行常见的图像预处理和增强操作,如旋转、裁剪、归一化等。

常见的图像预处理操作:

import torchvision.transforms as transforms
from PIL import Image

# 定义数据预处理的流水线
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # 将图像调整为 128x128
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])

# 加载图像
image = Image.open('image.jpg')

# 应用预处理
image_tensor = transform(image)
print(image_tensor.shape)  # 输出张量的形状
  • transforms.Compose():将多个变换操作组合在一起。
  • transforms.Resize():调整图像大小。
  • transforms.ToTensor():将图像转换为 PyTorch 张量,值会被归一化到 [0, 1] 范围。
  • transforms.Normalize():标准化图像数据,通常使用预训练模型时需要进行标准化处理。

图像数据增强

数据增强技术通过对训练数据进行随机变换,增加数据的多样性,帮助模型更好地泛化。例如,随机翻转、旋转、裁剪等。

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomRotation(30),  # 随机旋转 30 度
    transforms.RandomResizedCrop(128),  # 随机裁剪并调整为 128x128
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

这些数据增强方法可以通过 transforms.Compose() 组合使用,保证每个图像在训练时具有不同的变换。


加载图像数据集

对于图像数据集,torchvision.datasets 提供了许多常见数据集(如 CIFAR-10、ImageNet、MNIST 等)以及用于加载图像数据的工具。

加载 MNIST 数据集:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义预处理操作
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 对灰度图像进行标准化
])

# 下载并加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 迭代训练数据
for inputs, labels in train_loader:
    print(inputs.shape)  # 每个批次的输入数据形状
    print(labels.shape)  # 每个批次的标签形状
  • datasets.MNIST() 会自动下载 MNIST 数据集并加载。
  • transform 参数允许我们对数据进行预处理。
  • train=True 和 train=False 分别表示训练集和测试集。

用多个数据源(Multi-source Dataset)

如果你的数据集由多个文件、多个来源(例如多个图像文件夹)组成,可以通过继承 Dataset 类自定义加载多个数据源。

PyTorch 提供了 ConcatDataset 和 ChainDataset 等类来连接多个数据集。

例如,假设我们有多个图像文件夹的数据,可以将它们合并为一个数据集:

from torch.utils.data import ConcatDataset

# 假设 dataset1 和 dataset2 是两个 Dataset 对象
combined_dataset = ConcatDataset([dataset1, dataset2])
combined_loader = DataLoader(combined_dataset, batch_size=64, shuffle=True)

 


您可以捐助,支持我们的公益事业。

1元 10元 50元





认证码: 验证码,看不清楚?请点击刷新验证码 必填



24 次浏览
3次