0
点赞
收藏
分享

微信扫一扫

PyTorch Dataset

Pytorch Dataset_数据

'''
Dataset:
提供读取数据和其标签的方式:
- 获取每条数据和标签
- 告诉我们总共有多少条数据
'''
from torch.utils.data import Dataset
from PIL import Image
import os


class DataSet(Dataset):
    """Image dataset for one class folder of the hymenoptera training set.

    Each sample is a ``(PIL image, label)`` pair where the label is simply the
    folder name passed to the constructor (e.g. ``'ants'`` or ``'bees'``).

    Args:
        type: Name of the class sub-directory under the training folder.
    """

    def __init__(self, type):
        self.root_path = os.getcwd()  # dataset is located relative to the working directory
        self.type = type  # 'ants' or 'bees'
        # Join the components individually so the path also resolves on
        # platforms whose separator is not a backslash.
        self.paths = os.path.join(
            self.root_path,
            '数据集', 'hymenoptera_data', 'hymenoptera_data', 'train',
            self.type,
        )
        # Filled on first use; see _list_files().
        self._file_names = None

    def _list_files(self):
        """Return the sorted file names of the class folder, listing it once."""
        if self._file_names is None:
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping stable between calls and runs.
            self._file_names = sorted(os.listdir(self.paths))
        return self._file_names

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the range of available files.
        """
        img_path = os.path.join(self.paths, self._list_files()[idx])
        img = Image.open(img_path)
        label = self.type
        return img, label

    def __len__(self):
        """Return the number of files in the class folder."""
        return len(self._list_files())


# Build one dataset per class folder, then concatenate them with `+`
# into a single dataset covering both classes.
ant_dataset, bee_dataset = (DataSet(kind) for kind in ('ants', 'bees'))
data = ant_dataset + bee_dataset

Pytorch Dataset_深度学习_02

'''
Dataset:
提供读取数据和其标签的方式:
- 获取每条数据和标签
- 告诉我们总共有多少条数据
'''
from torch.utils.data import Dataset
from PIL import Image
import os


class DataSet(Dataset):
    """Image dataset whose labels are read from per-image text files.

    Images live in ``<cwd>/数据集/练手数据集/train/<type>``. The label of an image
    is the stripped content of ``<prefix>_label/<image stem>.txt`` in the same
    training folder, where ``<prefix>`` is the part of ``type`` before the
    first underscore (``'bees_image'`` -> ``'bees_label'``).

    Args:
        type: Name of the image sub-directory, e.g. ``'bees_image'``.
    """

    def __init__(self, type):
        self.root_path = os.getcwd()  # dataset is located relative to the working directory
        self.type = type  # e.g. 'ants_image' or 'bees_image'
        # Join the components individually so the path also resolves on
        # platforms whose separator is not a backslash.
        self._train_dir = os.path.join(self.root_path, '数据集', '练手数据集', 'train')
        self.paths = os.path.join(self._train_dir, self.type)
        # Filled on first use; see _list_files().
        self._file_names = None

    def _list_files(self):
        """Return the sorted file names of the image folder, listing it once."""
        if self._file_names is None:
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping stable between calls and runs.
            self._file_names = sorted(os.listdir(self.paths))
        return self._file_names

    def _label_path(self, img_name):
        """Return the path of the text file holding the label for ``img_name``."""
        # splitext removes only the extension. str.strip('.jpg') would strip any
        # of the characters '.', 'j', 'p', 'g' from both ends of the name.
        stem = os.path.splitext(img_name)[0]
        label_dir = '{}_label'.format(self.type.split('_')[0])
        return os.path.join(self._train_dir, label_dir, '{}.txt'.format(stem))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the range of available files.
            FileNotFoundError: If the matching label file does not exist.
        """
        img_name = self._list_files()[idx]
        img = Image.open(os.path.join(self.paths, img_name))
        # The label is stored in a text file named after the image.
        with open(self._label_path(img_name), mode='rt', encoding='utf8') as f:
            label = f.read().strip()
        return img, label

    def __len__(self):
        """Return the number of files in the image folder."""
        return len(self._list_files())


# Load the 'bees_image' folder and print its first (image, label) sample.
ant_dataset = DataSet(type='bees_image')
print(ant_dataset[0])


举报

相关推荐

0 条评论