0
点赞
收藏
分享

微信扫一扫

pytorch DataLoader 随机

使用 PyTorch 的 DataLoader 加载数据时,可以通过设置 shuffle 参数为 True 来实现随机打乱数据集中的样本顺序。这样可以确保模型在训练过程中不会受到数据样本顺序的影响,提高模型的泛化能力。

下面是一个更详细的示例代码,展示如何使用 DataLoader 随机打乱数据集中的样本顺序:

import torch
from torch.utils.data import DataLoader, Dataset

# 创建自定义的数据集类
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]

# 创建数据集
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
dataset = CustomDataset(data)

# 创建 DataLoader,并设置 shuffle=True
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 遍历 DataLoader
for i, batch in enumerate(dataloader):
    print(f"Batch {i}: {batch}")

# 输出结果

在上面的代码中,我们首先定义了一个包含数据 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 的数据集对象 dataset。然后我们创建了一个 DataLoader 对象 dataloader,并将 shuffle 参数设置为 True,这样在每个 epoch 中都会随机打乱数据集中的样本顺序。最后我们遍历 DataLoader,打印出每个 batch 中的样本。可以看到每个 batch 中的样本顺序是随机的;

通过设置 shuffle 参数为 True,我们可以确保模型在训练过程中不会受到数据样本顺序的影响,提高模型的泛化能力。

举报

相关推荐

0 条评论