求知 文章 文库 Lib 视频 iPerson 课程 认证 咨询 工具 讲座 Modeler   Code  
会员   
 


业务架构设计
4月18-19日 在线直播



基于UML和EA进行系统分析设计
4月25-26日 北京+在线



AI 智能化软件测试方法与实践
5月23-24日 上海+在线
 
追随技术信仰

随时听讲座
每天看新闻
 
 
PyTorch 教程
1. PyTorch 简介
2. PyTorch 安装
3. PyTorch 基础
4. PyTorch 张量
5. PyTorch 神经网络基础
6. PyTorch 第一个神经网络
7. PyTorch 数据处理与加载
8. PyTorch 线性回归
9. PyTorch 卷积神经网络
10. PyTorch 循环神经网络
11. PyTorch 数据集
12. PyTorch 数据转换
 

 
目录
PyTorch 数据转换
43 次浏览
 

在 PyTorch 中,数据转换(Data Transformation) 是一种在加载数据时对数据进行处理的机制,将原始数据转换成适合模型训练的格式,主要通过 torchvision.transforms 提供的工具完成。

数据转换不仅可以实现基本的数据预处理(如归一化、大小调整等),还能帮助进行数据增强(如随机裁剪、翻转等),提高模型的泛化能力。

为什么需要数据转换?

数据预处理:

  • 调整数据格式、大小和范围,使其适合模型输入。
  • 例如,图像需要调整为固定大小、张量格式并归一化到 [0,1]。

数据增强:

  • 在训练时对数据进行变换,以增加多样性。
  • 例如,通过随机旋转、翻转和裁剪增加数据样本的变种,避免过拟合。

灵活性:

  • 通过定义一系列转换操作,可以动态地对数据进行处理,简化数据加载的复杂度。
  • 在 PyTorch 中,torchvision.transforms 模块提供了多种用于图像处理的变换操作。

基础变换操作

变换函数名称 描述 实例
transforms.ToTensor() 将PIL图像或NumPy数组转换为PyTorch张量,并自动将像素值归一化到 [0, 1]。 transform = transforms. ToTensor()
transforms.Normalize(mean, std) 对图像进行标准化,使数据符合零均值和单位方差。 transform = transforms. Normalize (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transforms.Resize(size) 调整图像尺寸,确保输入到网络的图像大小一致。 transform = transforms. Resize((256, 256))
transforms.CenterCrop(size) 从图像中心裁剪指定大小的区域。 transform = transforms. CenterCrop(224)

 

1、ToTensor

将 PIL 图像或 NumPy 数组转换为 PyTorch 张量。

同时将像素值从 [0, 255] 归一化为 [0, 1]。

from torchvision import transforms

transform = transforms.ToTensor()

2、Normalize

对数据进行标准化,使其符合特定的均值和标准差。

通常用于图像数据,将其像素值归一化为零均值和单位方差。

transform = transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化到 [-1, 1]

3、Resize

调整图像的大小。

transform = transforms.Resize((128, 128))  # 将图像调整为 128x128

4、CenterCrop

从图像中心裁剪指定大小的区域。

transform = transforms.CenterCrop(128)  # 裁剪 128x128 的区域

数据增强操作

变换函数名称 描述 实例
transforms.RandomHorizontalFlip(p) 随机水平翻转图像。 transform = transforms. RandomHorizontalFlip(p=0.5)
transforms.RandomRotation(degrees) 随机旋转图像。 transform = transforms. RandomRotation(degrees=45)
transforms.ColorJitter(brightness, contrast, saturation, hue) 调整图像的亮度、对比度、饱和度和色调。 transform = transforms. ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
transforms.RandomCrop(size) 随机裁剪指定大小的区域。 transform = transforms. RandomCrop(224)
transforms.RandomResizedCrop(size) 随机裁剪图像并调整到指定大小。 transform = transforms. RandomResizedCrop(224)

1、RandomCrop

从图像中随机裁剪指定大小。

transform = transforms.RandomCrop(128)

2、RandomHorizontalFlip

以一定概率水平翻转图像。

transform = transforms.RandomHorizontalFlip(p=0.5)  # 50% 概率翻转

3、RandomRotation

随机旋转一定角度。

transform = transforms.RandomRotation(degrees=30)  # 随机旋转 -30 到 +30 度

4、ColorJitter

随机改变图像的亮度、对比度、饱和度或色调。

transform = transforms.ColorJitter(brightness=0.5, contrast=0.5)

组合变换

变换函数名称 描述 实例
transforms.Compose() 将多个变换组合在一起,按照顺序依次应用。 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.Resize((256, 256))])

通过 transforms.Compose 将多个变换组合起来。

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

自定义转换

如果 transforms 提供的功能无法满足需求,可以通过自定义类或函数实现。

class CustomTransform:
    def __call__(self, x):
        # 这里可以自定义任何变换逻辑
        return x * 2

transform = CustomTransform()

实例

对图像数据集应用转换

加载 MNIST 数据集,并应用转换。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义转换
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# 加载数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# 使用 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

# 查看转换后的数据
for images, labels in train_loader:
    print("图像张量大小:", images.size())  # [batch_size, 1, 128, 128]
    break

输出结果为:

图像张量大小: torch.Size([32, 1, 128, 128])

可视化转换效果

以下代码展示了原始数据和经过转换后的数据对比。

import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import datasets, transforms


# 原始和增强后的图像可视化
transform_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor()
])

# 加载数据集
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_augment)

# 显示图像
def show_images(dataset):
    fig, axs = plt.subplots(1, 5, figsize=(15, 5))
    for i in range(5):
        image, label = dataset[i]
        axs[i].imshow(image.squeeze(0), cmap='gray')  # 将 (1, H, W) 转为 (H, W)
        axs[i].set_title(f"Label: {label}")
        axs[i].axis('off')
    plt.show()

show_images(dataset)

显示如下所示:


您可以捐助,支持我们的公益事业。

1元 10元 50元





认证码: 验证码,看不清楚?请点击刷新验证码 必填



43 次浏览