惠州市建设交易中心网站,建设局主要负责什么,咨询公司面试,登录百度app文章目录 1. 简介2. 查看PyTorch自带的数据集(可视化)3. 准备材料3.1 图片数据3.2 标签数据 4. 方法 1. 简介
尽管PyTorch提供了许多自带的数据集#xff0c;如MNIST、CIFAR-10、ImageNet等#xff0c;但它们对于没有经验的用户来说#xff0c;理解数据加载器的工作原理以及… 文章目录 1. 简介2. 查看PyTorch自带的数据集(可视化)3. 准备材料3.1 图片数据3.2 标签数据 4. 方法 1. 简介
尽管PyTorch提供了许多自带的数据集如MNIST、CIFAR-10、ImageNet等但它们对于没有经验的用户来说理解数据加载器的工作原理以及如何正确地配置数据加载器可能会有一定难度。 用户需要了解所使用的数据集包括数据集的内容、结构、标签等信息。对于一些复杂的数据集用户可能需要理解数据集的结构和标签的含义。通过定义自己的数据集类您可以更好地控制数据的加载和处理过程提高代码的灵活性、可读性和可维护性同时更好地满足模型训练的需求。
2. 查看PyTorch自带的数据集(可视化)
为了更好的定义自己的数据集我们首先查看PyTorch自带的数据集的内容代码如下
# 导入所需的库
import matplotlib.pyplot as plt # 导入Matplotlib库用于可视化
import torch # 导入PyTorch库
from torchvision.datasets import MNIST # 从torchvision中导入MNIST数据集
from torchvision import transforms # 导入transforms模块用于数据预处理
import numpy as np # 导入NumPy库# 加载MNIST数据集
train_mnist_data MNIST(root./data, # 数据集存储路径trainTrue, # 加载训练集transformtransforms.Compose([transforms.Resize(size(28, 28)), transforms.ToTensor()]), # 数据预处理操作downloadTrue) # 如果数据集不存在则自动下载# 设置要显示的样本数量
num_samples 10# 创建包含多个子图的大图窗口
fig, axes plt.subplots(1, num_samples, figsize(10, 6))# 遍历选择要显示的样本
for i in range(num_samples):# 从数据集中获取图像数据和标签image, label train_mnist_data[i]# 在子图中显示图像axes[i].imshow(image.squeeze().numpy(), cmapgray) # 使用imshow函数显示图像将张量转换为NumPy数组axes[i].set_title(fLabel: {label}) # 设置子图标题显示图像对应的标签axes[i].axis(off) # 关闭坐标轴显示# 将图像保存为PNG格式的图片文件文件名以图像的标签命名plt.imsave(f./data/mnist_images/{label}.png, image.squeeze().numpy(), cmapgray)# 显示图形窗口
plt.show()这里我们使用MNIST类加载MNIST数据集。在加载数据集时通过transform参数指定了数据预处理操作包括将图像大小调整为28x28像素并将图像转换为张量。trainTrue表示加载训练集downloadTrue表示如果数据集不存在则自动下载到指定的路径。
接下来我们选择一些样本进行可视化。我们在一个子图中显示了10个样本每个样本对应一个数字图像和其对应的标签。通过循环遍历这些样本从数据集中获取图像数据和标签并使用Matplotlib的imshow()函数将图像显示在子图中。
同时使用imsave()函数将每个图像保存为PNG格式的图片文件文件名以标签命名。最后使用plt.show()显示图形窗口显示图像的同时也会将图像保存到指定的路径中。这段代码的执行结果是显示10张MNIST数据集中的数字图像并将这些图像保存到指定路径下。保存的图片如下所示 通过上面程序可以看到数据集主要是由图片数据和对应的标签构成那么我们就可以用这两个主要构成成分来构建自己的数据集。
3. 准备材料
3.1 图片数据
这里我们就用刚才保存的十张图片即 当然你也可以准备其它的图片并给图片分别命名为“0.png, 1.png, …”。
这里十张图片的相对路径为
imgs_path ./data/mnist_images注你们要根据自己存储的路径来给定。 3.2 标签数据
创建一个txt文件为每一幅图片指定标签数据如下所示 这里txt文件的相对路径为
labels_path labels.txt4. 方法
在PyTorch中您可以通过创建一个自定义的数据集类来定义自己的数据集。这个自定义类需要继承自torch.utils.data.Dataset类并且实现两个主要的方法__len__ 和 __getitem__。__len__方法应该返回数据集的长度而__getitem__方法则根据给定的索引返回数据集中的样本。
下面我们展示如何创建一个自定义的数据集类
import os # 导入os模块用于操作文件路径
from PIL import Image # 导入PIL库中的Image模块用于图像处理
import torch # 导入PyTorch库
from torch.utils.data import Dataset # 从torch.utils.data模块导入Dataset类用于定义自定义数据集
from torchvision import transforms # 导入transforms模块用于数据预处理
import numpy as np # 导入NumPy库用于数值处理
import matplotlib.pyplot as plt # 导入Matplotlib库用于可视化class CustomDataset(Dataset):def __init__(self, image_dir, label_file, transformNone):super().__init__() # 调用父类的构造函数self.image_dir image_dir # 图像数据的路径self.label_file label_file # 标签文本的路径self.transform transform # 数据预处理操作self.samples self._load_samples() # 加载数据集样本信息def _load_samples(self):samples [] # 存储样本信息的列表with open(self.label_file, r) as f: # 打开标签文本文件for line in f: # 逐行读取标签文本文件中的内容image_name, label line.strip().split(,) # 根据逗号分隔每行内容获取图像文件名和标签image_path os.path.join(self.image_dir, image_name) # 拼接图像文件的完整路径samples.append((image_path, int(label))) # 将图像路径和标签组成元组加入样本列表return samples # 返回样本列表def __len__(self):return len(self.samples) # 返回数据集样本的数量def __getitem__(self, index):image_path, label self.samples[index] # 获取指定索引处的图像路径和标签image Image.open(image_path).convert(L) # 打开图像文件并将其转换为灰度图像if self.transform: # 如果定义了数据预处理操作image self.transform(image) # 对图像进行预处理操作return image, label # 返回预处理后的图像和标签# 设置图片数据路径和标签文本路径
image_dir ./data/mnist_images # 图像数据的路径
label_file labels.txt # 标签文本的路径# 定义数据预处理操作根据需要添加其他预处理操作
transform transforms.Compose([transforms.Resize((28, 28)), # 调整图像大小transforms.ToTensor(), # 将图像转换为张量
])# 创建自定义数据集实例
custom_dataset CustomDataset(image_dir, label_file, transformtransform)# 创建数据加载器
data_loader torch.utils.data.DataLoader(custom_dataset, batch_size1, shuffleFalse)# 遍历数据加载器中的每个批次数据
for batch_images, batch_labels in data_loader:# 使用squeeze()函数去除图像张量中的单维度将图像数据转换为NumPy数组并存储在变量image中image batch_images.squeeze().numpy()# 使用imshow()函数显示图像cmapgray指定使用灰度色彩映射plt.imshow(image, cmapgray)# 设置图像标题显示图像对应的标签使用f-string格式化字符串将batch_labels转换为Python标量并获取其值plt.title(fLabel: {batch_labels.item()})# 关闭坐标轴显示即不显示坐标轴plt.axis(off)# 显示图形窗口plt.show()
这段代码实现了加载自定义数据集并使用 PyTorch 的 DataLoader 将数据加载成批次然后逐批次地展示图像。