0
点赞
收藏
分享

微信扫一扫

基于MMDetection、Labelme和Mask R-CNN的半自动标注

我阿霆哥 2022-04-14 阅读 91

github项目地址:

https://github.com/o0stinger0o/semi_automatic_labeling

首先在MMDetection上基于已有数据集训练Mask R-CNN模型（也可以换用其他实例分割模型或检测模型），笔者这里训练的是一个汽车部件（车门、车窗、车灯、保险杠等共29类）的分割模型。

模型训练的配置文件（mask_rcnn_r101_fpn_mstrain-poly_3x_coco_tag.py）和训练得到的权重（epoch_12.pth）都位于work_dir目录下。

用于半自动标注的主要文件有两个：semi_automatic_labeling.py和labelme2coco2.py。

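其中semi_automatic_labeling.py调用训练好的模型对./imgs/下的图片逐张推理，将得分不低于阈值（SCORE_THR，默认0.9）的检测框（bbox）和分割掩码转换为Labelme标注格式，并在图片所在目录下保存为同名的JSON文件。每个实例会同时生成一个矩形框（rectangle）和一个由掩码凸包得到的多边形（polygon），之后可在Labelme中打开进行人工修正。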

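labelme2coco2.py用于将（人工修正后的）Labelme标注转换为COCO格式。用法：python labelme2coco2.py <输入目录> <输出目录> --labels labels.txt，其中labels.txt每行一个类别名，且第一行必须是__ignore__。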

代码如下

semi_automatic_labeling.py

import os.path
import labelme.label_file as lf
import base64
import json
from PIL import Image
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
from skimage import measure
import numpy as np
import cv2

# Class names of the trained model. The detection loop looks names up by the
# detector's per-class result index, so this order must match the class order
# the checkpoint was trained with (presumably the CLASSES in the config file
# below -- verify whenever either is edited).
classes = [
'front_left_window',
'front_right_window',
'back_left_window',
'back_right_window',
'back_bumper',
'back_glass',
'back_left_door',
'back_left_light',
'back_right_door',
'back_right_light',
'front_bumper',
'front_glass',
'front_left_door',
'front_left_light',
'front_right_door',
'front_right_light',
'hood',
'left_mirror',
'right_mirror',
'wheel',
'tailgate',
'left_body',
'right_body',
'bus_left_window',
'bus_right_window',
'left_mid_window',
'right_mid_window',
'back_body',
'trunk',
] # the 29 classes to recognise


IMA_PATH = './imgs/' # directory of input images; the generated labelme .json files are saved here as well

# Paths to the mmdet model config and the trained checkpoint (both under work_dir/)
config_file = 'work_dir/mask_rcnn_r101_fpn_mstrain-poly_3x_coco_tag.py'
checkpoint_file = 'work_dir/epoch_12.pth'

# Build the model once, at import time, on the first CUDA GPU ('cuda:0').
model = init_detector(config_file, checkpoint_file, device='cuda:0')


# Every entry of IMA_PATH, in os.listdir's arbitrary order.
# NOTE(review): this also contains non-image entries (e.g. .json files saved
# into the same directory), so consumers must filter by file type.
file_name_list = os.listdir(IMA_PATH)

def get_height_width(img_path):
    """Return ``(height, width)`` in pixels of the image at ``img_path``.

    ``Image.open`` is lazy and keeps the file open until the image is closed,
    so it is used as a context manager to release the handle right away
    instead of leaking one per processed image.
    """
    with Image.open(img_path) as img:
        return img.height, img.width


def get_img_json(shapes, imagePath=None, imageData=None, imageHeight=512, imageWidth=512):
    """Build the top-level labelme annotation document for one image.

    Args:
        shapes: list of labelme shape dicts (as produced by ``get_shape``);
            stored by reference, not copied.
        imagePath: image file name recorded in the JSON (this script passes
            the bare file name and saves the JSON next to the image).
        imageData: base64-encoded image bytes to embed, or None.
        imageHeight: image height in pixels.
        imageWidth: image width in pixels.

    Returns:
        A new dict in labelme's JSON layout (format version '4.6.0'), ready
        for ``json.dump``. Keys are inserted in labelme's usual order.
    """
    # A plain literal replaces the former json.loads(json.dumps({})) round
    # trip, which only produced an empty dict.
    return {
        'version': '4.6.0',
        'flags': {},
        'shapes': shapes,
        'imagePath': imagePath,
        'imageData': imageData,
        'imageHeight': imageHeight,
        'imageWidth': imageWidth,
    }

def get_shape(label, points, shape_type='polygon', group_id=None):
    """Build one labelme shape dict.

    Args:
        label: class name shown in labelme.
        points: list of [x, y] vertices; this script passes the two box
            corners for 'rectangle' and the hull vertices for 'polygon'.
        shape_type: labelme shape type, e.g. 'polygon' or 'rectangle'.
        group_id: optional labelme group id; None leaves the shape ungrouped.

    Returns:
        A new dict with keys label, points, group_id, shape_type and flags.
    """
    # A plain literal replaces the former json.loads(json.dumps({})) round
    # trip, which only produced an empty dict.
    return {
        'label': label,
        'points': points,
        'group_id': group_id,
        'shape_type': shape_type,
        'flags': {},
    }




SCORE_THR = 0.9  # detection score threshold: lower-scoring instances are discarded
FIND_CONTOURS_LEVEL = 0.5  # iso-level passed to skimage.measure.find_contours on the 0/1 mask

# Extensions treated as input images. Everything else in IMA_PATH -- notably
# the .json files this script writes next to the images -- is skipped, so the
# script can safely be re-run on the same directory.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')


def _split_result(result):
    """Split an mmdet 2.x inference result into per-class boxes and masks.

    Returns ``(bbox_result, segm_result)``. Mask models return a
    ``(bbox_result, segm_result)`` tuple; anything else is assumed to be a
    plain per-class bbox list (detection-only model), giving ``segm_result``
    None. This mirrors how mmdet 2.x's ``show_result`` unpacks results.
    """
    if not isinstance(result, tuple):
        return result, None
    bbox_result, segm_result = result
    if isinstance(segm_result, tuple):
        # Some mask heads (e.g. MS R-CNN) return (masks, mask_scores).
        segm_result = segm_result[0]
    return bbox_result, segm_result


def _mask_to_polygon(mask):
    """Convex hull of the largest outline of a binary instance mask.

    Returns an (N, 2) int32 array of (x, y) vertices with N >= 3, or None
    when the mask has no outline (empty mask) or only a degenerate one.
    """
    # Transposing makes find_contours yield (x, y) instead of (row, col),
    # which is the point order labelme expects.
    contours = measure.find_contours(np.asarray(mask).astype(int).T,
                                     FIND_CONTOURS_LEVEL)
    if not contours:
        return None
    # A fragmented mask yields several contours in scan order; keep the one
    # with the most points rather than whichever happens to come first.
    largest = max(contours, key=len)
    hull = cv2.convexHull(largest.astype(np.int32)).reshape(-1, 2)
    return hull if len(hull) >= 3 else None


def _collect_shapes(result):
    """Turn one image's detection result into a list of labelme shape dicts.

    Per class (in detector order), every instance scoring at least SCORE_THR
    contributes a 'rectangle' shape from its box; when a mask is available
    and has a valid outline it also contributes a 'polygon' shape. Within a
    class, rectangles are listed before polygons.
    """
    bbox_result, segm_result = _split_result(result)
    shapes = []
    for class_id, class_bboxes in enumerate(bbox_result):
        class_name = classes[class_id]
        rectangles, polygons = [], []
        for instance_id, bbox in enumerate(class_bboxes):
            # mmdet bbox rows are [x1, y1, x2, y2, score].
            if bbox[4] < SCORE_THR:
                continue
            rectangles.append(
                get_shape(class_name, bbox[:4].reshape(2, 2).tolist(), 'rectangle'))
            if segm_result is None:
                continue
            polygon = _mask_to_polygon(segm_result[class_id][instance_id])
            if polygon is not None:
                polygons.append(get_shape(class_name, polygon.tolist(), 'polygon'))
        shapes.extend(rectangles)
        shapes.extend(polygons)
    return shapes


if __name__ == '__main__':

    for file_name in file_name_list:
        if os.path.splitext(file_name)[1].lower() not in IMAGE_EXTENSIONS:
            continue
        print('process ' + file_name)

        img_path = os.path.join(IMA_PATH, file_name)
        shape_list = _collect_shapes(inference_detector(model, img_path))

        # Embed the image as base64 so the JSON is self-contained in labelme.
        image_bytes = lf.LabelFile.load_image_file(img_path)
        image_data = base64.b64encode(image_bytes).decode("utf-8")
        image_height, image_width = get_height_width(img_path)

        img_json = get_img_json(shape_list, file_name, image_data,
                                image_height, image_width)
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b.json").
        json_path = os.path.join(IMA_PATH, os.path.splitext(file_name)[0] + '.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(img_json, f, ensure_ascii=False)






labelme2coco2.py

#!/usr/bin/env python

import argparse
import collections
import datetime
import glob
import json
import os
import os.path as osp
import sys
import uuid

import imgviz
import numpy as np

import labelme

try:
    import pycocotools.mask
except ImportError:
    print("Please install pycocotools:\n\n    pip install pycocotools\n")
    sys.exit(1)


def _read_class_names(labels_path):
    """Return the stripped, non-blank lines of the labels file, in order."""
    with open(labels_path) as f:
        return [line.strip() for line in f if line.strip()]


def _shape_to_coco_polygon(points, shape_type):
    """Convert a labelme shape's points to a flat COCO polygon [x1, y1, ...].

    A rectangle becomes its four corners. A circle, stored by labelme as
    [centre, point on circumference], is approximated by a regular polygon.
    Any other shape is flattened as-is.
    """
    if shape_type == "rectangle":
        (x1, y1), (x2, y2) = points
        x1, x2 = sorted([x1, x2])
        y1, y2 = sorted([y1, y2])
        points = [x1, y1, x2, y1, x2, y2, x1, y2]
    elif shape_type == "circle":
        (x1, y1), (x2, y2) = points
        r = np.linalg.norm([x2 - x1, y2 - y1])
        # r(1-cos(a/2))<x, a=2*pi/N => N>pi/arccos(1-x/r)
        # x: tolerance of the gap between the arc and the line segment (1 px)
        # For r <= 1 the bound is already below the 12-point floor, and for
        # r < 0.5 it is NaN (r == 0 divides by zero), which made int() raise.
        if r > 1:
            n_points_circle = max(int(np.pi / np.arccos(1 - 1 / r)), 12)
        else:
            n_points_circle = 12
        i = np.arange(n_points_circle)
        x = x1 + r * np.sin(2 * np.pi / n_points_circle * i)
        y = y1 + r * np.cos(2 * np.pi / n_points_circle * i)
        return np.stack((x, y), axis=1).flatten().tolist()
    return np.asarray(points).flatten().tolist()


def main():
    """Convert a directory of labelme JSON files into a COCO-style dataset.

    Usage: labelme2coco2.py INPUT_DIR OUTPUT_DIR --labels LABELS [--noviz]

    LABELS lists one class name per line. Its first non-blank line must be
    ``__ignore__``; the following classes get category ids 0, 1, 2, ... in
    order, and shapes whose label is not listed are dropped.

    The function writes OUTPUT_DIR/annotations.json. It also re-encodes each
    embedded image to OUTPUT_DIR/JPEGImages/<name>.jpg and, unless --noviz is
    given, writes an overlay to OUTPUT_DIR/Visualization/<name>.jpg.

    It exits with status 1 if OUTPUT_DIR already exists, and with an argparse
    error if the labels file is malformed.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("input_dir", help="input annotated directory")
    parser.add_argument("output_dir", help="output dataset directory")
    parser.add_argument("--labels", help="labels file", required=True)
    parser.add_argument(
        "--noviz", help="no visualization", action="store_true"
    )
    args = parser.parse_args()

    # Validate the labels file before touching the disk, so a bad labels file
    # does not leave behind an output directory that blocks the next run.
    # (An explicit error instead of assert, which `python -O` strips.)
    class_names = _read_class_names(args.labels)
    if not class_names or class_names[0] != "__ignore__":
        parser.error("the first line of the labels file must be __ignore__")

    if osp.exists(args.output_dir):
        print("Output directory already exists:", args.output_dir)
        sys.exit(1)
    os.makedirs(args.output_dir)
    os.makedirs(osp.join(args.output_dir, "JPEGImages"))
    if not args.noviz:
        os.makedirs(osp.join(args.output_dir, "Visualization"))
    print("Creating dataset:", args.output_dir)

    now = datetime.datetime.now()

    data = dict(
        info=dict(
            description=None,
            url=None,
            version=None,
            year=now.year,
            contributor=None,
            date_created=now.strftime("%Y-%m-%d %H:%M:%S.%f"),
        ),
        licenses=[dict(url=None, id=0, name=None,)],
        images=[
            # license, url, file_name, height, width, date_captured, id
        ],
        type="instances",
        annotations=[
            # segmentation, area, iscrowd, image_id, bbox, category_id, id
        ],
        categories=[
            # supercategory, id, name
        ],
    )

    # Category ids start at 0 with the first class after __ignore__.
    class_name_to_id = {}
    for class_id, class_name in enumerate(class_names[1:]):
        class_name_to_id[class_name] = class_id
        data["categories"].append(
            dict(supercategory=None, id=class_id, name=class_name,)
        )

    out_ann_file = osp.join(args.output_dir, "annotations.json")
    # Sorted so image ids are reproducible regardless of filesystem order.
    label_files = sorted(glob.glob(osp.join(args.input_dir, "*.json")))
    for image_id, filename in enumerate(label_files):
        print("Generating dataset from:", filename)

        label_file = labelme.LabelFile(filename=filename)

        base = osp.splitext(osp.basename(filename))[0]
        out_img_file = osp.join(args.output_dir, "JPEGImages", base + ".jpg")

        # Decode the image embedded in the labelme JSON and save it as JPEG.
        img = labelme.utils.img_data_to_arr(label_file.imageData)
        imgviz.io.imsave(out_img_file, img)
        data["images"].append(
            dict(
                license=0,
                url=None,
                file_name=osp.relpath(out_img_file, osp.dirname(out_ann_file)),
                height=img.shape[0],
                width=img.shape[1],
                date_captured=None,
                id=image_id,
            )
        )

        # Both keyed by instance = (label, group_id).
        masks = {}  # union of the instance's shape masks, for area and bbox
        segmentations = collections.defaultdict(list)  # flat polygons per instance
        for shape in label_file.shapes:
            points = shape["points"]
            label = shape["label"]
            group_id = shape.get("group_id")
            shape_type = shape.get("shape_type", "polygon")
            mask = labelme.utils.shape_to_mask(
                img.shape[:2], points, shape_type
            )

            if group_id is None:
                # Ungrouped shapes are independent instances: unique key each.
                group_id = uuid.uuid1()

            instance = (label, group_id)

            if instance in masks:
                masks[instance] = masks[instance] | mask
            else:
                masks[instance] = mask

            segmentations[instance].append(
                _shape_to_coco_polygon(points, shape_type)
            )
        segmentations = dict(segmentations)

        for instance, mask in masks.items():
            cls_name, group_id = instance
            if cls_name not in class_name_to_id:
                continue
            cls_id = class_name_to_id[cls_name]

            mask = np.asfortranarray(mask.astype(np.uint8))
            mask = pycocotools.mask.encode(mask)
            area = float(pycocotools.mask.area(mask))
            bbox = pycocotools.mask.toBbox(mask).flatten().tolist()

            data["annotations"].append(
                dict(
                    # 1-based: pycocotools' COCOeval stores matched GT ids in
                    # its match matrix and treats 0 as "unmatched", so a GT
                    # with id 0 would turn its true positive into a false one.
                    id=len(data["annotations"]) + 1,
                    image_id=image_id,
                    category_id=cls_id,
                    segmentation=segmentations[instance],
                    area=area,
                    bbox=bbox,
                    iscrowd=0,
                )
            )

        if not args.noviz:
            known = [
                (class_name_to_id[cnm], cnm, msk)
                for (cnm, _), msk in masks.items()
                if cnm in class_name_to_id
            ]
            viz = img
            # Test the filtered list: when no shape has a known label,
            # unpacking zip(*[]) into three names would raise ValueError.
            if known:
                labels, captions, viz_masks = zip(*known)
                viz = imgviz.instances2rgb(
                    image=img,
                    labels=labels,
                    masks=viz_masks,
                    captions=captions,
                    font_size=15,
                    line_width=2,
                )
            out_viz_file = osp.join(
                args.output_dir, "Visualization", base + ".jpg"
            )
            imgviz.io.imsave(out_viz_file, viz)

    with open(out_ann_file, "w") as f:
        json.dump(data, f)


if __name__ == "__main__":
    main()
举报

相关推荐

0 条评论