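# 1. Assume the ground-truth (GT) boxes are given as below (the original figure is not
#    included): two boxes in pixel coordinates (x1, y1, x2, y2) on a 448 x 448 image,
#    with class labels 6 and 14.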
# 2. The encode code:
import torch
# Example ground truth: two boxes in pixel (x1, y1, x2, y2) corners and their class ids.
boxes = torch.tensor(
    [
        [89, 100, 387, 284],
        [23, 33, 110, 287],
    ],
    dtype=torch.float32,
)
label = torch.tensor([6, 14], dtype=torch.long)

# Image size in pixels.
w = h = 448.0

# Normalize the corners to [0, 1] in place; broadcasting divides every row by (w, h, w, h).
boxes /= torch.tensor([w, h, w, h], dtype=torch.float32)
def encode(boxes, labels, *, grid_size=7, num_boxes=2, num_classes=20):
    """Encode box coordinates and class labels as one YOLO-v1 style target tensor.

    Args:
        boxes: (tensor) [[x1, y1, x2, y2]_obj1, ...], normalized from 0.0 to 1.0
            w.r.t. image width/height.
        labels: (tensor) [c_obj1, c_obj2, ...], integer class indices in
            ``[0, num_classes)``.
        grid_size: number of grid cells per side (S). Defaults to 7.
        num_boxes: number of box slots per cell (B). Defaults to 2.
        num_classes: number of object classes (C). Defaults to 20.

    Returns:
        A float tensor sized [S, S, 5 x B + C], indexed as [row (y), column (x), channel].
        In every cell that contains a box centre, each of the B slots holds
        (x, y, w, h, conf): x, y are the centre offset inside the cell (0.0-1.0),
        w, h are the box size relative to the whole image, and conf = 1.0.
        The last C channels are a one-hot class vector. Cells without a box are all
        zeros. If several boxes fall into the same cell, the last one wins.

    Raises:
        ValueError: if a label is outside ``[0, num_classes)``.
    """
    S, B, C = grid_size, num_boxes, num_classes
    N = 5 * B + C
    target = torch.zeros(S, S, N)
    # Coordinates are already normalized, so the image is treated as 1 x 1 and
    # a single cell spans 1/S in both directions.
    cell_size = 1.0 / float(S)
    # Convert (x1, y1, x2, y2) corners into sizes (w, h) and centres (x, y), each [n, 2].
    boxes_wh = boxes[:, 2:] - boxes[:, :2]
    boxes_xy = (boxes[:, 2:] + boxes[:, :2]) / 2.0
    for b in range(boxes.size(0)):
        xy, wh, label = boxes_xy[b], boxes_wh[b], int(labels[b])
        if not 0 <= label < C:
            raise ValueError(f"label {label} is outside [0, {C})")
        # ceil() - 1 turns the position into a 0-based cell index. Clamp so a centre
        # lying exactly on the image border (0.0 -> -1, or float error just past
        # 1.0 -> S) stays inside the grid instead of wrapping via negative indexing.
        ij = ((xy / cell_size).ceil() - 1.0).clamp(0, S - 1)
        i, j = int(ij[0]), int(ij[1])  # i: column (x) index, j: row (y) index.
        # Top-left corner of that cell, in normalized image coordinates.
        x0y0 = ij * cell_size
        # Re-normalize the box centre relative to its own cell (0.0-1.0).
        xy_normalized = (xy - x0y0) / cell_size
        # TBM, remove redundant dimensions from target tensor.
        # To remove these, loss implementation also has to be modified.
        for k in range(B):
            s = 5 * k
            target[j, i, s:s + 2] = xy_normalized
            target[j, i, s + 2:s + 4] = wh
            target[j, i, s + 4] = 1.0  # Objectness: 1 when the cell contains an object.
        # The one-hot class vector follows the 5*B box channels. Clear it first so that
        # a later box overwriting this cell does not leave the earlier class bit set.
        target[j, i, 5 * B:] = 0.0
        target[j, i, 5 * B + label] = 1.0
    return target
# Build the S x S x (5B + C) target for the example boxes and show it.
# The module-level name `tagert` (sic) is kept unchanged for compatibility.
tagert = encode(boxes=boxes, labels=label)
print(tagert)