应用背景介绍
由于高分辨率遥感影像中同物异谱和同谱异物现象较为严重,尤其是针对没有固定几何形态、仅有纹理特征的地类进行建模时,提取效果较差,误提率较高。从人的视角来看,易混淆的地类之间纹理相似,某些情况下我们也需要结合周边的场景信息才能准确判断该地类的类别,这对人工智能模型来说是个极大的挑战。因此,引入场景识别模型对检测框进行过滤是个不错的选择。
方法
样本准备
负样本准备
根据目标识别结果,挑选出误识别的区域。
正样本准备
利用目标识别样本,以所绘制检测框的中心为中心向外扩展,自动裁剪出边长约为N像素的窗口(代码中N=128:检测框小于N时向四周扩展,大于N时截取其中心区域,靠近影像边缘时窗口会被截断),再将裁剪结果统一缩放为N×N。
下面的代码以shapefile格式的标签作为输入。
gen_pos.py
import os
import cv2
import numpy as np
import sys
sys.path.append('data')
from shp2imagexy import *
import glob
import matplotlib.pyplot as plt
def bbox_to_rect(bbox, color):
    """Build an unfilled matplotlib rectangle from a corner-format box.

    :param bbox: indexable box whose first four items are
        (top-left x, top-left y, bottom-right x, bottom-right y).
    :param color: edge colour handed to matplotlib.
    :return: ``plt.Rectangle`` anchored at the top-left corner, with
        width ``x_max - x_min`` and height ``y_max - y_min`` and a 2-px edge.
    """
    x_min, y_min = bbox[0], bbox[1]
    x_max, y_max = bbox[2], bbox[3]
    return plt.Rectangle(
        xy=(x_min, y_min),
        width=x_max - x_min,
        height=y_max - y_min,
        fill=False,
        edgecolor=color,
        linewidth=2,
    )
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=True, scaleup=True,
              stride=128):
    """Resize ``img`` towards ``new_shape``, padding with ``color`` if needed.

    Adapted from the YOLO letterbox helper
    (https://github.com/ultralytics/yolov3/issues/232).

    :param img: image array whose first two dimensions are (height, width).
    :param new_shape: target (height, width), or one int for a square target.
    :param color: constant value used for the padding border.
    :param auto: if True, keep the aspect ratio and pad only up to the next
        multiple of ``stride`` (minimum rectangle).
    :param scaleFill: if True (and ``auto`` is False), stretch the image to
        exactly ``new_shape`` with no padding; the aspect ratio is not kept.
    :param scaleup: if False, the uniform scale factor is capped at 1.0.
    :param stride: padding multiple used when ``auto`` is True; defaults to
        128, the value that used to be hard-coded.
    :return: ``(image, (width_ratio, height_ratio), (dw, dh))`` where dw/dh
        are the per-side padding amounts.
    :raises ValueError: if ``img`` has zero height or width.
    """
    shape = img.shape[:2]  # current shape [height, width]
    if shape[0] == 0 or shape[1] == 0:
        # An empty crop would otherwise die with ZeroDivisionError below.
        raise ValueError(f'letterbox got an empty image of shape {img.shape}')
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Uniform scale ratio (new / old) that fits the image inside new_shape
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, never up
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (width, height)
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle: pad only up to the next multiple of stride
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)
    elif scaleFill:  # stretch to the exact target size, no padding
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # split padding between the two sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize (cv2 expects (width, height))
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    # The -/+0.1 makes an odd total padding put one pixel less on the
    # top/left side and one more on the bottom/right side.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
def crop_around_box(img, box, size=128):
    """Cut a window of nominal side ``size`` centred on a bounding box.

    The box is grown (or, when it is larger than ``size``, shrunk)
    symmetrically around its centre and then clipped to the image, so crops
    near the border can be smaller than ``size`` and are empty when the box
    lies completely outside the image.

    :param img: image array indexed as ``img[row, col, ...]``.
    :param box: first four items are (x_min, y_min, x_max, y_max) in pixel
        coordinates; they may be floats or negative.
    :param size: nominal side length of the window in pixels.
    :return: a view of ``img`` covering the clipped window.
    """
    height, width = img.shape[:2]  # numpy images are (rows, cols)
    # Use signed Python ints: unsigned (uint16) arithmetic wraps negative
    # offsets around to huge values and silently yields empty crops.
    x0, y0, x1, y1 = (int(np.floor(v)) for v in box[:4])
    x0, x1 = min(x0, x1), max(x0, x1)
    y0, y1 = min(y0, y1), max(y0, y1)
    pad_x = (size - (x1 - x0)) // 2  # negative when the box is wider than size
    pad_y = (size - (y1 - y0)) // 2
    # Rows are bounded by the height, columns by the width.
    row_start = min(max(y0 - pad_y, 0), height)
    row_end = min(max(y1 + pad_y, 0), height)
    col_start = min(max(x0 - pad_x, 0), width)
    col_end = min(max(x1 + pad_x, 0), width)
    return img[row_start:row_end, col_start:col_end]


if __name__ == '__main__':
    img_pattern = 'D:/2021/7/em20210628/sheepfold2/*/*.tif'
    out_dir = 'pos'
    crop_size = 128
    # cv2.imwrite does not create folders; it just returns False.
    os.makedirs(out_dir, exist_ok=True)
    for imgPath in glob.glob(img_pattern):
        print(imgPath)
        try:
            stem = os.path.splitext(imgPath)[0]
            imgName = os.path.basename(stem)
            # Swap only the extension; str.replace('tif', 'shp') would also
            # rewrite any 'tif' that appears in directory names.
            shpPath = stem + '.shp'
            anns = shp2imagexy(imgPath, shpPath)
            boxes = [ann[:4] for ann in anns]  # drop the class id
            img = cv2.imread(imgPath, cv2.IMREAD_LOAD_GDAL)
            if img is None:
                print(f'could not read image: {imgPath}')
                continue
            for i, box in enumerate(boxes):
                crop = crop_around_box(img, box, size=crop_size)
                if crop.size == 0:  # box lies outside the image
                    continue
                crop = letterbox(crop, new_shape=(crop_size, crop_size))[0]
                savePath = os.path.join(out_dir, f'{imgName}_{i}.tif')
                if not cv2.imwrite(savePath, crop):
                    print(f'failed to write {savePath}')
        except Exception as exc:
            # Best effort: one bad image/shapefile must not stop the batch,
            # but report the problem instead of hiding it.
            print(f'skipped {imgPath}: {exc!r}')
shp2imagexy.py
# -*- coding: utf-8 -*-
from osgeo import ogr
from osgeo import gdal
from osgeo import osr
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import math
def getSRSPair(dataset):
    '''
    Return the projected and geographic spatial reference systems of a dataset.

    The projected SRS is parsed from ``dataset.GetProjection()`` (WKT); the
    geographic SRS is its underlying geographic CRS (``CloneGeogCS``).
    When the installed GDAL provides ``osr.OAMS_TRADITIONAL_GIS_ORDER``
    (GDAL >= 3), both are switched to traditional GIS axis order, so
    transformations built from them take and return (x, y) / (lon, lat)
    rather than the authority-defined order, which is (lat, lon) for EPSG:4326.

    :param dataset: GDAL dataset (anything exposing ``GetProjection()``).
    :return: (prosrs, geosrs) as ``osr.SpatialReference`` objects.
    '''
    prosrs = osr.SpatialReference()
    prosrs.ImportFromWkt(dataset.GetProjection())
    geosrs = prosrs.CloneGeogCS()
    # GDAL 3 made coordinate transformations honour the CRS axis order;
    # restore lon/lat-first behaviour. Older GDAL lacks the constant and
    # already uses this order.
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        prosrs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        geosrs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    return prosrs, geosrs
def xy_to_coor(x, y):
    '''
    Convert Miller-projection plane coordinates back to longitude/latitude.

    The plane is ``W = 2*pi*6381372`` wide and ``H = W/2`` high; x is measured
    from the left edge and y from the top edge (y grows southwards), so the
    centre ``(W/2, H/2)`` maps to (0, 0). Latitude uses the inverse Miller
    formula lat = 2.5*atan(exp(0.8*Y)) - 0.625*pi, expressed in degrees.

    :param x: horizontal plane coordinate.
    :param y: vertical plane coordinate, increasing downwards.
    :return: (lon, lat) in degrees, each rounded to 7 decimal places.
    '''
    L = 6381372 * math.pi * 2  # plane width: circumference for radius 6381372
    W = L
    H = L / 2
    mill = 2.3  # Miller projection constant
    lat = ((H / 2 - y) * 2 * mill) / (1.25 * H)
    lat = ((math.atan(math.exp(lat)) - 0.25 * math.pi) * 180) / (0.4 * math.pi)
    lon = (x - W / 2) * 360 / W
    # TODO: confirm how many decimal places the final lon/lat should keep
    return round(lon, 7), round(lat, 7)
def geo2lonlat(dataset, x, y):
    '''
    Transform a projected coordinate into geographic (lon, lat) coordinates.

    The projected and geographic reference systems are taken from
    ``getSRSPair(dataset)``; the output axis order follows that SRS pair.

    :param dataset: GDAL dataset that defines the projection.
    :param x: projected x coordinate.
    :param y: projected y coordinate.
    :return: the first two components of the transformed point.
    '''
    to_geographic = osr.CoordinateTransformation(*getSRSPair(dataset))
    return to_geographic.TransformPoint(x, y)[:2]
def lonlat2geo(dataset, lon, lat):
    '''
    Transform geographic (lon, lat) coordinates into the dataset's projection.

    This is the inverse direction of ``geo2lonlat``; both reference systems
    come from ``getSRSPair(dataset)``.

    :param dataset: GDAL dataset that defines the projection.
    :param lon: longitude.
    :param lat: latitude.
    :return: the first two components (x, y) of the projected point.
    '''
    projected_srs, geographic_srs = getSRSPair(dataset)
    to_projected = osr.CoordinateTransformation(geographic_srs, projected_srs)
    projected_point = to_projected.TransformPoint(lon, lat)
    return projected_point[:2]
def imagexy2geo(dataset, col, row):
    '''
    Map a pixel position to georeferenced coordinates using GDAL's
    six-parameter affine geotransform.

    The result is projected or geographic depending on the dataset's CRS.

    :param dataset: GDAL dataset (anything exposing ``GetGeoTransform()``).
    :param col: pixel column (x) index.
    :param row: pixel row (line, y) index.
    :return: (px, py) georeferenced coordinates of that pixel position.
    '''
    trans = dataset.GetGeoTransform()
    px = trans[0] + col * trans[1] + row * trans[2]
    py = trans[3] + col * trans[4] + row * trans[5]
    return px, py
def geo2imagexy01(dataset, x, y):
    '''
    Invert GDAL's affine geotransform with a linear solve.

    Solves, for (col, row),
        x - GT[0] = GT[1]*col + GT[2]*row
        y - GT[3] = GT[4]*col + GT[5]*row
    using ``numpy.linalg.solve`` (which raises ``LinAlgError`` if the
    geotransform matrix is singular).

    :param dataset: GDAL dataset (anything exposing ``GetGeoTransform()``).
    :param x: projected or geographic x coordinate.
    :param y: projected or geographic y coordinate.
    :return: numpy array ``[col, row]`` (fractional pixel position).
    '''
    gt = dataset.GetGeoTransform()
    coefficients = np.array([[gt[1], gt[2]], [gt[4], gt[5]]])
    offsets = np.array([x - gt[0], y - gt[3]])
    return np.linalg.solve(coefficients, offsets)
def geo2imagexy(dataset, x, y):
    '''
    Invert GDAL's affine geotransform in closed form (Cramer's rule).

    :param dataset: GDAL dataset (anything exposing ``GetGeoTransform()``).
    :param x: projected or geographic x coordinate.
    :param y: projected or geographic y coordinate.
    :return: ``[col, row]`` — fractional pixel column (x) and row (line).
    '''
    gt = dataset.GetGeoTransform()
    det = gt[1] * gt[5] - gt[2] * gt[4]  # determinant of the 2x2 linear part
    dx = x - gt[0]
    dy = y - gt[3]
    col = (gt[5] * dx - gt[2] * dy) / det
    row = (gt[1] * dy - gt[4] * dx) / det
    return [col, row]
def shp2imagexy(imgPath, shpPath):
    '''
    Convert the bounding envelope of every shapefile feature into pixel
    coordinates of the matching georeferenced image.

    :param imgPath: path of the georeferenced image (opened with GDAL).
    :param shpPath: path of the shapefile; features are expected to have a
        ``cls`` attribute field.
    :return: list of ``[x0, y0, x1, y1, cls]``, where (x0, y0) is the pixel
        (column, row) of the envelope's (minX, maxY) corner and (x1, y1) that
        of its (maxX, minY) corner — top-left / bottom-right for a north-up
        image. Features without geometry are skipped.
    :raises OSError: if the image or the shapefile cannot be opened.
    '''
    dataset = gdal.Open(imgPath)
    if dataset is None:
        raise OSError(f'Could not open image: {imgPath}')
    # Read-only (update=0): the shapefile is only read here.
    ds = ogr.Open(shpPath, 0)
    if ds is None:
        raise OSError(f'Could not open shapefile: {shpPath}')
    in_lyr = ds.GetLayer()
    in_lyr.ResetReading()
    finalResult = []
    feature = in_lyr.GetNextFeature()
    while feature is not None:
        geometry = feature.GetGeometryRef()
        if geometry is not None:  # a feature without geometry has no envelope
            cls_id = feature.GetField('cls')
            # OGR envelopes are (minX, maxX, minY, maxY)
            min_x, max_x, min_y, max_y = geometry.GetEnvelope()
            col0, row0 = geo2imagexy(dataset, min_x, max_y)
            col1, row1 = geo2imagexy(dataset, max_x, min_y)
            finalResult.append([col0, row0, col1, row1, cls_id])
        feature = in_lyr.GetNextFeature()
    return finalResult
if __name__ == '__main__':
    # Visual check: draw the shapefile boxes converted to pixel coordinates.
    img_filename = 'E:/2_1/tc/0000000001_V1/0000000001.tif'
    dst_filename = 'E:/2_1/tc/0000000001_V1/0000000001_V1_POLY.shp'
    finalResult = shp2imagexy(img_filename, dst_filename)
    img = cv.imread(img_filename, cv.IMREAD_LOAD_GDAL)
    if img is None:
        raise SystemExit(f'Could not read image: {img_filename}')
    for bbox in finalResult:
        print(bbox)
        top_left = (int(bbox[0]), int(bbox[1]))
        bottom_right = (int(bbox[2]), int(bbox[3]))
        cv.rectangle(img, top_left, bottom_right, (0, 100, 255), 5)
    # cv.imread yields BGR channel order while matplotlib expects RGB;
    # convert 3-channel images so colours display correctly.
    if img.ndim == 3 and img.shape[2] == 3:
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    plt.imshow(img)
    plt.show()
生成的图片
模型训练
模型选用图像分类模型(例如VGG、ResNet等),加载预训练模型参数后修改最后的全连接层,将其输出维度改为类别数。
模型预测
见源码
源码地址
https://github.com/SonwYang/RemoteSensingSceneRecognition