问题
处理自定义数据集是应用PyTorch走向工程实际的重要前提,本文将持续更新介绍自定义数据集处理一些常见方法。
方法
加载自定义数据集并获取分类数量
from torchvision.datasets import ImageFolder
train_dataset = ImageFolder('D:\\data\\FD-dataset-challenge')
class_to_idx = train_dataset.class_to_idx
num_classes = len(class_to_idx)
print(class_to_idx) # {'fire': 0, 'no_fire': 1}
print(num_classes) # 2
消除控制台不影响程序运行的警告信息
import warnings
warnings.filterwarnings('ignore')