0
点赞
收藏
分享

微信扫一扫

LibFewShot:小样本学习与细粒度分类(二) -- 工程应用、单图测试

贵州谢高低 2022-03-16 阅读 98
学习分类

原理

已有:

       训练好的相似度比对模型 M

       新来的图片集 P 其中单张为p

       已知的类别支持集 S 其中单张为 s

方法:

       比较每一对 p 与 s ,记录相似度,针对每个 p,输出相似度最高的 s 的类别,以及置信度。

       配置model,配置emb_func,配置transforms,配置device

       读取图片,提取特征,调用模型预测。

过程

       通过从 LibFewShot 源代码中截取所需片段,并结合自行设计的流程逻辑,实现单图/多图测试应用。

代码

# -*- coding: utf-8 -*-
# write for testing a pic or a pic list with a support set
# hhr 
import sys
sys.dont_write_bytecode = True
import os
from PIL import Image
from torchvision import transforms
from core.config import Config
from core import Test
import os
import csv
import pickle
from logging import getLogger
from torch.utils.data import Dataset
from time import time
import numpy as np
import torch
from torch import nn
from torchvision import transforms
import core.model as arch
from core.data import get_dataloader
from core.utils import (
    init_logger,
    prepare_device,
    init_seed,
    AverageMeter,
    count_parameters,
    ModelType,
    TensorboardWriter,
    mean_confidence_interval,
    get_local_time,
    get_instance,
)
import copy
import numpy as np
import torch
from sklearn import metrics
from sklearn.linear_model import LogisticRegression
from torch import nn
from torch.nn import functional as F
from core.utils import accuracy
from core.model.finetuning.finetuning_model import FinetuningModel
from core.model.loss import DistillKLLoss
import sys # 测试

# Configuration: folder of query (test) images, and the folder of a finished
# training run (must contain config.yaml and checkpoints/model_best.pth).
Q_path = '/home/hhr/hrmnt/work/LibFewShot-main/FSL_Dataset/fgvc-aircraft-2013b/Q/'
PATH = "./results/RFSModel-fgvc-aircraft-2013b-resnet18-5-1-Jan-12-2022-15-16-57"
# Test-time overrides merged on top of the saved training config.
VAR_DICT = {
    "test_epoch": 5,
    "device_ids": "2",
    "n_gpu": 1,
    "test_episode": 500,
    "episode_size": 1,
    "test_way": 5,
}
# "test_way": 6,
# Per-channel normalization statistics (pixel means/stds on a 0-255 scale,
# divided down to [0, 1]) used by the test-time transforms below.
MEAN = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] # transforms: Normalize mean
STD = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0] # transforms: Normalize std

def pil_loader(path):
    """Load the image at *path* with PIL and return it converted to RGB.

    The file handle is opened explicitly instead of handing the path to
    ``Image.open`` directly, to avoid a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, "rb") as handle, Image.open(handle) as picture:
        return picture.convert("RGB")

def accimage_loader(path):
    """Load *path* with the accimage backend, falling back to PIL.

    accimage raises IOError on images it cannot decode; in that case
    the more tolerant PIL loader is used instead.
    """
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem -- fall back to PIL.Image.
        image = pil_loader(path)
    return image

def default_loader(path):
    """Dispatch image loading to whichever torchvision backend is active."""
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == "accimage"
    return accimage_loader(path) if use_accimage else pil_loader(path)



# Test entry point: rebuild the trained model from a saved run directory,
# then classify query images against a user-provided support set.
if __name__ == "__main__":
    # Load the saved training config and merge the VAR_DICT test overrides.
    config = Config(os.path.join(PATH, "config.yaml"), VAR_DICT
                ).get_config_dict()  # config dict of one past training run (base)


    # return logger
    logger = getLogger(__name__)
    logger.info(config) # log the effective configuration


    # return result_path
    result_path = PATH # directory of the training run under test


    # return viz_path, state_dict_path
    # self.viz_path, self.state_dict_path = self._init_files(config)
    # Init result_path(log_path, viz_path) from the config dict.
    if result_path is not None: 
        result_path = result_path
    else:
        # Fallback: derive the result directory name from the config fields.
        result_dir = "{}-{}-{}-{}-{}".format(
            config["classifier"]["name"],
            # you should ensure that data_root name contains its true name
            config["data_root"].split("/")[-1],
            config["backbone"]["name"],
            config["way_num"],
            config["shot_num"],
        ) # derived result directory name
        result_path = os.path.join(config["result_root"], result_dir)
    # logger.log("Result DIR: " + result_path)
    log_path = os.path.join(result_path, "log_files")
    viz_path = os.path.join(log_path, "tfboard_files") # tensorboard output dir
    init_logger(
        config["log_level"],
        log_path,
        config["classifier"]["name"],
        config["backbone"]["name"],
        is_train=False,
    )
    state_dict_path = os.path.join(result_path, "checkpoints", "model_best.pth") # best checkpoint of the run


    # Tensorboard writer (unused below, kept to mirror the library's Test class).
    writer = TensorboardWriter(viz_path)


    # return device, list_ids
    # self.device, self.list_ids = self._init_device(config)
    # Init the devices from the config file.
    init_seed(config["seed"], config["deterministic"])  # seed RNGs from the training config
    device, list_ids = prepare_device(config["device_ids"], config["n_gpu"])  # device + GPU id list


    # return model, model.model_type
    # self.model, self.model_type = self._init_model(config) # rebuild the model from the training config
    # Init model(backbone+classifier) from the config dict 
    #   and load the best checkpoint, then parallel if necessary .
    emb_func = get_instance(arch, "backbone", config) # backbone instance used as the feature extractor
    model_kwargs = {
        "way_num": config["way_num"],
        "shot_num": config["shot_num"] * config["augment_times"],
        "query_num": config["query_num"],
        "test_way": config["test_way"],
        "test_shot": config["test_shot"] * config["augment_times"],
        "test_query": config["test_query"],
        "emb_func": emb_func,
        "device": device,
    } # classifier constructor arguments
    model = get_instance(arch, "classifier", config, **model_kwargs) # classifier instance (the full model)
    logger.info(model)
    logger.info("Trainable params in the model: {}".format(count_parameters(model)))
    logger.info("load the state dict from {}.".format(state_dict_path))
    state_dict = torch.load(state_dict_path, map_location="cpu") # load checkpoint weights on CPU first
    model.load_state_dict(state_dict) # load model
    model = model.to(device) # move the model (including emb_func) to the target device
    if len(list_ids) > 1:
        # Multi-GPU: wrap the configured sub-modules in DataParallel.
        parallel_list = config["parallel_part"]
        if parallel_list is not None:
            for parallel_part in parallel_list:
                if hasattr(model, parallel_part):
                    setattr(
                        model,
                        parallel_part,
                        nn.DataParallel(
                            getattr(model, parallel_part),
                            device_ids=list_ids,
                        ),
                    )
    model_type = model.model_type # e.g. ModelType.METRIC; controls grad handling below


    # The test stage: configuration is complete, run the prediction.
    # The author's plan (kept verbatim in Chinese below): load a run folder,
    # embed the support images, adapt a classifier on the support features,
    # then embed the query images and predict their class names.
    '''
    读取某一训练文件夹,配置好 : path  config  emb_func  model
    读取支持集图片,标签ST
    提取每类的图片特征SF(emb_func),取平均值作为这一类的特征,存入dict
    读取查询图片,多张可以for读取
    提取查询图特征(emb_func)QF
    微调 classifier = self.set_forward_adaptation(SF, ST) # 支持集 部署微调? 得到一个网络
    处理QF,进行预测 output = classifier.predict(QF) # 微调过的网络 调用预测
    输出预测 对应种类名称
    ''' 
    # Switch to inference mode: puts Dropout / BatchNorm layers into eval
    # behaviour; forgetting model.eval() yields inconsistent predictions.
    model.eval() # switch to evaluate mode
    model.reverse_setting_info() # swap way/shot settings to their test-time values
    '''
    在模型中,我们通常会加上Dropout层和batch normalization层,
    在模型预测阶段,我们需要将这些层设置到预测模式,
    model.eval()就是帮我们一键搞定的,
    如果在预测的时候忘记使用model.eval(),会导致不一致的预测结果。
    '''
    # Metric-based models need no gradients at test time; other types
    # (e.g. fine-tuning models) adapt on the support set and do.
    if model_type == ModelType.METRIC:
        enable_grad = False
    else:
        enable_grad = True
    with torch.set_grad_enabled(enable_grad): # gradient context for the whole test stage
        # Image loader used to read files from disk.
        loader = default_loader

        # Test-time transforms: resize -> center crop -> tensor -> normalize.
        trfms = None
        trfms_list = [] 
        if config["image_size"] == 224:
            trfms_list.append(transforms.Resize((256, 256)))
            trfms_list.append(transforms.CenterCrop((224, 224)))
        elif config["image_size"] == 84:
            trfms_list.append(transforms.Resize((96, 96)))
            trfms_list.append(transforms.CenterCrop((84, 84)))
        # for MTL -> alternative solution: use avgpool(ks=11)
        elif config["image_size"] == 80:
            trfms_list.append(transforms.Resize((92, 92)))
            trfms_list.append(transforms.CenterCrop((80, 80)))
        else:
            raise RuntimeError
        trfms_list.append(transforms.ToTensor())
        trfms_list.append(transforms.Normalize(mean=MEAN, std=STD))
        trfms = transforms.Compose(trfms_list) # final transform pipeline

        # Read the support-set listing (support.csv) and build:
        #   image_list       -- support image file names
        #   label_list       -- integer label per image
        #   class_label_dict -- class name -> integer id, in first-seen order
        support_csv = os.path.join(config["data_root"], "support.csv")
        image_list = []
        label_list = []
        class_label_dict = dict() # {}
        with open(support_csv) as f_csv:
            f_support = csv.reader(f_csv, delimiter=",")
            for row in f_support:
                if f_support.line_num == 1: # skip the header row: filename,label
                    continue
                image_name, image_class = row # unpack filename,label
                if image_class not in class_label_dict: # first time this class is seen
                    class_label_dict[image_class] = len(class_label_dict) # number classes 0,1,2,... in order of appearance
                image_label = class_label_dict[image_class] # class name string -> integer label id
                image_list.append(image_name)
                label_list.append(image_label)
        print('============class_label_dict==============\n', class_label_dict, '\n')

        # Load every support image, apply the transforms, and collect
        # the tensors together with their integer labels.
        images = []
        labels = []
        for f, i in enumerate(image_list):
            # NOTE(review): f is the index and i the file name -- the names read swapped.
            image_path = os.path.join(config["data_root"], "images", i)
            data = loader(image_path) # one image as a PIL RGB image
            if trfms is not None:
                data = trfms(data) # -> tensor, e.g. torch.Size([3, 84, 84])
            label = label_list[f] # matching integer label
            images.append(data)
            labels.append(label)
        images_tensor = torch.stack(images) # stack to 4-D [N, 3, H, W] as emb_func expects
        ST = torch.tensor(labels) # ST: support-set labels as a 1-D tensor

        # Extract support features (SF) and adapt a classifier on them.
        images_tensor = images_tensor.to(device) # move the batch to the target device
        with torch.no_grad(): # no gradients needed for feature extraction
            support_feat = model.emb_func(images_tensor) # backbone features; input must be 4-D, at least [1, 3, H, W]
            # emb_func was initialized inside model = get_instance(...) above
        SF = support_feat # SF: support-set features
        classifier = model.set_forward_adaptation(SF, ST) # adapt on the support set; limitation: all classes must be fitted together
        # e.g. LogisticRegression(max_iter=1000, multi_class='multinomial', random_state=0);
        # with a single class it raises:
        # ValueError: This solver needs samples of at least 2 classes in the data, but the data contains only one class: 0
        # Query set: read every image in Q_path (any number of images)
        # and extract their features with the backbone.
        Q_image_Names = os.listdir(Q_path)
        Q_images = []
        for q in Q_image_Names: # read every image in the folder
            Q_image_path = os.path.join(Q_path, q)
            data = loader(Q_image_path) # one image as a PIL RGB image
            if trfms is not None:
                data = trfms(data) # -> tensor of shape [3, H, W]
            Q_images.append(data)
        Q_images_tensor = torch.stack(Q_images) # stack to 4-D [N, 3, H, W] as emb_func expects
        Q_images_tensor = Q_images_tensor.to(device) # move the batch to the target device
        with torch.no_grad(): # no gradients needed for feature extraction
            Q_feat = model.emb_func(Q_images_tensor) # backbone features for the query images

        # Predict the query classes with the adapted classifier.
        QF = Q_feat 
        QF = F.normalize(QF, p=2, dim=1).detach().cpu().numpy() # L2-normalize and convert to numpy for sklearn
        output = classifier.predict(QF) # predicted integer label ids

        # Map predicted label ids back to class name strings.
        pre_classes = []
        for o in output:
            pre_classes.append([k for k,v in class_label_dict.items() if v == o][0])

        # Build {query file name: predicted class name} and print it
        # sorted by the predicted class name.
        pre_dict = {} 
        for i, pic in enumerate(Q_image_Names):
            pre_dict[pic] = pre_classes[i]
        pre_dict_order = sorted(pre_dict.items(), key = lambda x:x[1], reverse = False)  
        # sorted by each (file, class) tuple's second element, the class name
        print('============pre_dict_order==============\n', pre_dict_order, '\n')
        # Average prediction accuracy. Ground truth is encoded in the query
        # file name: the text before the first '_' is taken as the true class.
        pre_acc_t = 0  # correct predictions
        pre_acc_f = 0  # wrong predictions
        for k, v in pre_dict.items():
            # k: query image file name, v: predicted class name
            if k[:k.find('_')] == v:
                pre_acc_t += 1
            else:
                pre_acc_f += 1
        pre_acc_num = pre_acc_t + pre_acc_f
        # Guard against an empty query folder (avoids ZeroDivisionError).
        pre_acc_avg = pre_acc_t / pre_acc_num if pre_acc_num else 0.0
        print('本次测试', pre_acc_num, '个')
        # Bug fix: the ratio is in [0, 1], so scale by 100 before appending '%'.
        print('测试平均准确率为 ' + str(pre_acc_avg * 100) + '%')

    model.reverse_setting_info() # restore the way/shot settings swapped before testing
    
举报

相关推荐

0 条评论