0
点赞
收藏
分享

微信扫一扫

210316-针对类别不平衡数据集PyTorch实现每个Batch中出现所有类别及数量近似(待整理)


def prepare_dataloader(X, Y, P=None, dim='2D', batch_size=32, drop_last=True,
                       balanced=False):
    """Wrap arrays in a PyTorch ``DataLoader``.

    Args:
        X: Input samples. When ``dim == '2D'`` they are reshaped to
            ``(-1, 1, 32, 32)``, so each sample must hold 1024 values.
            For any other ``dim`` value X is used as given.
        Y: Integer class labels, one per sample (converted to a LongTensor).
        P: Optional third integer array aligned with X/Y. When given, every
            batch is ``(x, y, p)``; otherwise it is ``(x, y)``.
        dim: ``'2D'`` triggers the reshape described above.
        batch_size: Number of samples per batch.
        drop_last: Passed to ``DataLoader``; drops the final short batch.
        balanced: When True, draw samples with a ``WeightedRandomSampler``
            whose per-sample weight is the inverse of its class frequency, so
            every class is drawn with roughly equal probability in each batch
            (sampling is with replacement). Requires 1-D labels. When False
            (default) the data is simply shuffled.

    Returns:
        A ``torch.utils.data.DataLoader`` over the wrapped dataset.
    """
    from collections import Counter

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset
    from torch.utils.data.sampler import WeightedRandomSampler

    class Self_Def_Dataset(Dataset):
        """In-memory dataset yielding ``(x, y)`` or ``(x, y, p)`` items."""

        def __init__(self, X, Y, P=None):
            if dim == '2D':
                X = np.reshape(X, (-1, 1, 32, 32))
            self.X = torch.Tensor(X)
            self.Y = torch.LongTensor(Y)
            # `is None` rather than `== None`: comparing a NumPy array with
            # `==` is element-wise and cannot be used as an `if` condition.
            if P is None:
                self.O = 'two_output'
                self.P = self.Y
            else:
                self.O = 'three_output'
                self.P = torch.LongTensor(P)

        def __getitem__(self, index):
            x = self.X[index]
            y = self.Y[index]
            if self.O == 'two_output':
                return x, y
            return x, y, self.P[index]

        def __len__(self):
            return len(self.X)

    dataset = Self_Def_Dataset(X, Y, P)

    if balanced:
        # Batches for an imbalanced dataset; background reading:
        # https://stackoverflow.com/questions/60812032/using-weightedrandomsampler-in-pytorch
        # https://towardsdatascience.com/pytorch-basics-sampling-samplers-2a0f29f0bf2a
        labels = dataset.Y.tolist()
        class_counts = Counter(labels)
        # Inverse class frequency: rare classes get proportionally larger
        # weights, so each class has about the same total sampling mass.
        weights = [1.0 / class_counts[label] for label in labels]
        sampler = WeightedRandomSampler(
            torch.DoubleTensor(weights),
            num_samples=len(weights),
            replacement=True,
        )
        # `sampler` and `shuffle=True` are mutually exclusive in DataLoader.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=drop_last,
            sampler=sampler,
        )

    return DataLoader(
        dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last
    )


举报

相关推荐

0 条评论