0
点赞
收藏
分享

微信扫一扫

test pytorch 测试模型绘图到图片

暮晨夜雪 2022-01-17 阅读 71
python
#coding=utf-8
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt
import pandas as pd
from model import myconvNet
import os
import cv2
import time
from dataloader import tempDataset

import os
import yaml
import matplotlib.pyplot as plt
from importlib.abc import Loader
from PIL import Image
import cv2
import glob
import numpy as np
from torchvision import transforms


def transfunc(image):
    """Resize *image* to 224x224 and convert it to a tensor.

    Applies ``transforms.Resize([224, 224])`` followed by
    ``transforms.ToTensor()`` and returns the result.
    """
    resize = transforms.Resize([224, 224])
    to_tensor = transforms.ToTensor()
    return to_tensor(resize(image))

def _load_labels(data_path):
    """Return every annotation entry under *data_path* with six keypoints.

    Scans ``<data_path>/*/*.yaml`` and reads the ``image_data`` list of each
    file; entries whose ``keypoints`` list does not have exactly six items
    are dropped.
    """
    labels = []
    for yaml_path in glob.glob(data_path + "/*/*.yaml"):
        # NOTE(review): yaml.FullLoader is kept from the original code;
        # consider yaml.safe_load if label files can come from untrusted
        # sources.
        with open(yaml_path) as stream:
            annotations = yaml.load(stream, Loader=yaml.FullLoader)
        labels.extend(
            entry for entry in annotations['image_data']
            if len(entry['keypoints']) == 6
        )
    return labels


def test_plot(data_path, model_path='../save_model/Iter_100_myconvnet.pt',
              output_dir='./plot_result'):
    """Run the trained network on labelled images and save prediction plots.

    For every annotation entry with six keypoints, the referenced image is
    read in grayscale, passed through ``myconvNet`` and the twelve predicted
    coordinates are drawn as red dots on top of the image. One PNG per image
    is written to ``<output_dir>/<index>_<total>_result.png``.

    Args:
        data_path: Root directory of the labelled data. YAML files are looked
            up at ``<data_path>/*/*.yaml`` and images at
            ``<data_path>/<image_id>``.
        model_path: Checkpoint passed to ``torch.load`` and then to
            ``load_state_dict``.
        output_dir: Directory that receives the PNG files; created if it
            does not exist.

    Returns:
        None. Does nothing when no matching annotation entry is found.

    Note:
        The model and the input tensors are moved to CUDA, so a GPU is
        required.
    """
    label_list = _load_labels(data_path)
    if not label_list:
        # Nothing to plot: avoid loading the model and touching the disk.
        return

    num = len(label_list)
    os.makedirs(output_dir, exist_ok=True)

    # Build the network and load the weights once; doing it per image only
    # repeats identical, expensive work.
    net = myconvNet(nb_out=12)
    net.float().cuda()
    net.eval()
    net.load_state_dict(torch.load(model_path))

    # Per-coordinate scale (x, y repeated for six points). Presumably the
    # network output is normalised to [-0.5, 0.5] relative to a 640x480
    # image -- TODO confirm against the training code.
    varvalue = np.array([[640, 480] * 6])

    for index, label in enumerate(label_list):
        image_path = data_path + '/' + label['image_id']
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            print("skip unreadable image {}".format(image_path))
            continue

        vis_demo_img = Image.fromarray(image)
        test_image = transfunc(vis_demo_img).unsqueeze(0).cuda()

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred_points = net(test_image).cpu().numpy()
        pred_points = np.multiply(pred_points + 0.5, varvalue)

        plt.imshow(vis_demo_img, cmap='gray')
        # Even indices are x coordinates, odd indices are y coordinates.
        plt.scatter(pred_points[0][::2], pred_points[0][1::2], c='#DC143C')  # red
        filepath = "{}/{}_{}_result.png".format(output_dir, index, num)
        plt.savefig(filepath)
        # Clear the figure so the next image does not draw over this one.
        plt.clf()
        print("have saved {}".format(filepath))

if __name__ == "__main__":
    # Script entry point: plot predictions for the labelled training set.
    test_plot(data_path="../data_labelled/train")
举报

相关推荐

0 条评论