一、背景
BEV方案中,将图像视角转换到BEV视角的方法对模型性能影响较大,FastBEV的速度较快,但投影效果上限不高,LSS投影上限较高,但速度较慢 (耗时相对较高)。是否有折中的方案,在耗时增加相对较少的情况下,提升模型的上限(中高算力平台下,提升模型能力)?
二、视角转换关键算子-----gridsample
这是pytorch官网对gridsample算子使用方法说明,其支持4-D(FastBEV/IMP)和5-D(LSS)采样,将图像特征提取到对应的BEV特征中,完成相机视角转换:https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
 
 5-D gridsample相比4-D gridsample耗时剧增,假如在某智驾芯片上,4-D gridsample耗时是2ms,相同条件下5-D gridsample的耗时可能是200ms(具体耗时受特征图通道数影响),这种耗时急剧上升的方案,很难在智驾中落地应用。
三、LSS投影优化
1.先来对比4-D gridsample和5-D gridsample的输入输出关系:
4-D gridsample
 input: (N, C, H_in, W_in);
 bev_grid: (N, H_out, W_out, 2), 这里的2表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y);
 output: (N, C, H_out, W_out)
5-D gridsample
 input: (N, C, H_in, W_in);
 for循环提取每个C通道的输入特征进行softmax处理input_i:(N, D, H_in, W_in),按照dim=1堆叠起来,得到深度输入input_2:(N, C, D, H_in, W_in), 这里的D表示深度估计的通道数;
 bev_grid: (N, Z_out, H_out, W_out, 3), 这里的3表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y,d), d为深度估计;
 output: (N, C, Z_out, H_out, W_out);
由于获取深度信息需要用到5-D gridsample,想要降低耗时,考虑减少特征图通道对耗时的影响,即做5-D gridsample时,将通道C设为1;
2.具体方法-----拆解5-D gridsample
将5-D gridsample拆解为一个4-D gridsample和一个单通道(C=1)的5-D gridsample,4-D gridsample负责提取多通道特征信息,单通道5-D gridsample负责提取深度特征信息,最后将两个特征信息相乘,得到多通道下的深度信息,等效变换过程如下:
step1:
4-D gridsample
 input: (N, C, H_in, W_in);
 bev_grid: (N, Z_out, H_out, W_out, 2), 这里的2表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y);
 for循环提取每个Z_out下的bev_grid_i: (N, Z_out, H_out, W_out, 2),通过4-D gridsample分别得到输出特征图output_i: (N, C, H_out, W_out),按照dim=2堆叠起来,得到最终的BEV特征图output_1(没有深度概率信息):
 output_1: (N, C, Z_out, H_out, W_out)
step2:
单通道5-D gridsample
 input: (N, C, H_in, W_in);
 input经过softmax处理后的特征图input_2: (N, D, H_in, W_in),这里的D表示深度估计的通道数;将input_2在dim=1上扩展一个维度,得到input_3:(N, 1, D, H_in, W_in)
 bev_grid: (N, Z_out, H_out, W_out, 3), 这里的3表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y,d), d为深度估计;
 output_2: (N, 1, Z_out, H_out, W_out);
step3:
将output_1和output_2相乘得到有深度概率信息的BEV特征图
 output = outptu_1 * output_2 = (N, C, Z_out, H_out, W_out) * (N, 1, Z_out, H_out, W_out) = (N, 1, Z_out, H_out, W_out)
四、部分代码
1.IPM的BEV网格坐标索引
class UpdateIndicesIPM:
    def __init__(self, height, range, voxel_size, feature_size, downsample):
        self.height = height
        self.range = range
        self.voxel_size = voxel_size
        self.feature_size = feature_size
        self.ds_matrix = np.eye(4)
        self.ds_matrix[:2] /= downsample
    def __call__(self, data):
        num = len(data["cam2egoes"])
        ego2feats = torch.zeros((num, 4, 4), dtype=torch.float32)
        for i in range(num):
            ego2cam = np.linalg.inv(data["cam2egoes"][i])
            tmp = np.eye(4)
            tmp[:3, :3] = data["cam_intrinsics"][i]
            ego2feats[i] = torch.tensor(self.ds_matrix @ tmp @ ego2cam)
        
        grid = torch.stack(torch.meshgrid([
                torch.arange(self.range[0], self.range[3], self.voxel_size[0]),
                torch.arange(self.range[1], self.range[4], self.voxel_size[1]),
                torch.tensor(self.height), torch.tensor(1.0)
                ], indexing="ij")) # [4, 188, 64, 4, 1]
        grid_h, grid_w = grid.shape[1:3]
        grid = grid.view(1, 4, -1).expand(num, 4, -1) # [7, 4, 192512] 
        points_2d = torch.bmm(ego2feats[:, :3, :], grid)
        x = (points_2d[:, 0] / points_2d[:, 2]).round().long()  
        y = (points_2d[:, 1] / points_2d[:, 2]).round().long() 
        z = points_2d[:, 2]
        valid = ~((x >= 0) & (y >= 0) & (x < self.feature_size[1]) & 
                  (y < self.feature_size[0]) & (z > 0))
        x[valid] = 0
        y[valid] = 0
        x = (x.float() / self.feature_size[1] * 2.) - 1.0
        y = (y.float() / self.feature_size[0] * 2.) - 1.0
        indices = torch.cat([x.unsqueeze(2), y.unsqueeze(2)], dim=2)
        indices = indices.reshape(-1, grid_h, grid_w, len(self.height), 2) # batch, num_img, bev_w, bev_h, num_height, 2
        data["indices"] = indices
        return data
2.FastBEV
class FastBevTransform(nn.Module):
    def __init__(self, feats_channels, num_height):
        super().__init__()
        self._num_height = num_height
        self._conv = nn.Conv2d(feats_channels * num_height, feats_channels, kernel_size=1)
        self._grid_sample = GridSample(mode="nearest",
                                            padding_mode="zeros",
                                            align_corners=True)
        self._cat = Concat(dim=1)
    def forward(self, feats, indices):
        # feats: (7B, C, H, W), indices: (7B, Hg, Wg, Z, 2)
        bev_feats = []
        for i in range(self._num_height):
            output = self._grid_sample(feats, indices[:,:,:,i])
            bev_feats.append(output)
        bev_feats = self._cat(bev_feats)  # (7B, Z*C, Hg, Wg)
        bev_feats = self._conv(bev_feats)  # (7B, C, Hg, Wg)
        return bev_feats
3.LSS的BEV网格坐标索引
class UpdateIndicesLSS:
    def __init__(self, height, range, voxel_size, feature_size,
                 resolution, max_num_depth, downsample):
        self.height = height
        self.range = range
        self.voxel_size = voxel_size
        self.feature_size = feature_size
        self.resolution = resolution
        self.max_num_depth = max_num_depth
        self.ds = np.eye(3)
        self.ds[:2] /= downsample
    
    def __call__(self, data):
        num = len(data["cam2egoes"])
        ego2cams = torch.zeros((num, 4, 4), dtype=torch.float32)
        cam2feats = torch.zeros((num, 3, 3), dtype=torch.float32)
        for i in range(num):
            ego2cams[i] = torch.tensor(np.linalg.inv(data["cam2egoes"][i]))
            cam2feats[i] = torch.tensor(self.ds @ data["cam_intrinsics"][i])
        grid = torch.stack(torch.meshgrid([
                torch.arange(self.range[0], self.range[3], self.voxel_size[0]),
                torch.arange(self.range[1], self.range[4], self.voxel_size[1]),
                torch.tensor(self.height), torch.tensor(1.0)
                ], indexing="ij")) # [4, 188, 64, 4, 1]
        grid_h, grid_w = grid.shape[1:3]
        grid4 = grid.view(1, 4, -1).expand(num, 4, -1) # [7, 4, 192512] 
        points_2d = torch.bmm(ego2cams[:, :3, :], grid4)
        x = (points_2d[:, 0] / points_2d[:, 2])   # [7, 48128]
        y = (points_2d[:, 1] / points_2d[:, 2])   # [7, 48128]
        z = points_2d[:, 2]                       # [7, 48128]
        r = points_2d.norm(dim=1)                 # [B*N, Hg*Wg]
        d = torch.floor(r / self.resolution)
        distortions = torch.tensor(np.array(data["cam_distortions"]).T)
        k1,k2,k3,p1,p2,k4,k5,k6 = distortions[:,:,None]
        fovs = torch.tensor(data['crop_fovs']).unsqueeze(-1) / 2.0
        in_fov = np.abs(np.arctan2(points_2d[:, 0], z)) < fovs
        r2 = x**2 + y**2
        ratio = (1 + k1 * r2 + k2 * r2**2 + k3 * r2**3) / (1 + k4 * r2 + k5 * r2**2 + k6 * r2**3)
        x_undist = x * ratio + 2 * p1 * x * y + p2 * (r2 + 2 * x**2)
        y_undist = y * ratio + p1 * (r2 + 2 * y**2) + 2 * p2 * x * y
        x = cam2feats[:, 0, [0]] * x_undist + cam2feats[:, 0, [2]]
        y = cam2feats[:, 1, [1]] * y_undist + cam2feats[:, 1, [2]]
        valid = ~((x >= 0) & (y >= 0) & (x < self.feature_size[1]) & \
                  (y < self.feature_size[0]) & (z > 0) & in_fov & \
                  (d >= 0) & (d < self.max_num_depth))   # [7, 48128]
        x[valid], y[valid], d[valid] = -1, -1, -1
        x = (x.float() / self.feature_size[1] * 2.) - 1.0
        y = (y.float() / self.feature_size[0] * 2.) - 1.0
        d = (d.float() / self.max_num_depth * 2.) - 1.0
        indices = torch.cat([x[:,:,None], y[:,:,None], d[:,:,None]], dim=2)  # [7, 48128, 3]
        indices = indices.reshape(-1, grid_h, grid_w, len(self.height), 3)   # batch*num_img, bev_w, bev_h, num_height, 3(x, y, d)
        data["indices"] = indices.permute(0, 3, 1, 2, 4) # batch*num_img, num_height, bev_w, bev_h, 3(x, y, d)
        return data
4.LSS的BEV投影
class LssBevTransform(nn.Module):
    def __init__(self, feats_channels, num_height, max_num_depth):
        super().__init__()
        self._num_height = num_height
        self._max_num_depth = max_num_depth
        self.ms_cam = MS_CAM(feats_channels * num_height)
        self._depth_proj = nn.Sequential(
            nn.Conv2d(feats_channels, max_num_depth, kernel_size=3, padding=1),
            nn.Softmax(dim=1)
        )
        self._grid_sample = GridSample(mode="nearest",
                                            padding_mode="zeros",
                                            align_corners=True)
        self._cat = Concat(dim=1)
        self._blocks = nn.Sequential(
            nn.Conv2d(feats_channels * num_height, feats_channels, kernel_size=1),
            nn.BatchNorm2d(feats_channels),
            nn.ReLU(inplace=True)
        )
        
    def simplify_bev(self, feats, indices):
        depths = self._depth_proj(feats)[:, None]
        import ipdb
        ipdb.set_trace()
        pass
    def forward(self, feats, indices):
        # feats: (B*N, C, H, W)
        # indices: (B*N, Z, X, Y, 3) where 3 dims represent (w, h, d).
        bev_feats = self._sample_bev_feats(feats, indices[..., :2])  # (B*N, C, Z, X, Y)
        depth_feats = self._sample_depth_feats(feats, indices)  # (B*N, 1, Z, X, Y)
        final_feats = bev_feats * depth_feats  # (B*N, C, Z, Y, X)
        N, C, Z, Y, X = final_feats.shape
        final_feats = final_feats.view(N, C * Z, Y, X)  # (B*N, Z*C, Hg, Wg)
        final_feats = final_feats*self.ms_cam(final_feats)
         
        final_feats = self._blocks(final_feats)  # (B*N, C, Hg, Wg)        
        return final_feats
    def _sample_bev_feats(self, feats, indices):
        bev_feats = [self._grid_sample(feats, indices[:, i]) for i in range(self._num_height)]
        return torch.stack(bev_feats, dim=2)  # (B*N, C, Z, Y, X)  
    def _sample_depth_feats(self, feats, indices):
        depths = self._depth_proj(feats)[:, None]  # (B*N, 1, D, H, W)
        return self._grid_sample(depths, indices) # (B*N, 1, Z, X, Y)
五、展望
LSS投影时将input_3:(N, 1, D, H_in, W_in)中D和H_in进行reshape合并后得(N, 1, D*H_in, W_in),可以完全通过4-D gridsample提取特征,耗时进一步降低,等效替代测试代码如下:
#!/usr/bin/env python3
import unittest
import torch
import torch.nn.functional as F
class GridSampleTest(unittest.TestCase):
    def test_grid_sample_equivalence(self):
        D, H, W = 100, 144, 256
        Y, X = 64, 128
        # Generate random features.
        feats_5d = torch.randn(1, 1, D, H, W)
        # Generate random indices.
        d = torch.randint(high=D, size=(Y, X))
        h = torch.randint(high=H, size=(Y, X))
        w = torch.randint(high=W, size=(Y, X))
        # Prepare grid for 5D grid_sample.
        indices_5d = torch.stack([
            2.0 * w / (W - 1) - 1.0,
            2.0 * h / (H - 1) - 1.0,
            2.0 * d / (D - 1) - 1.0
        ], dim=-1).view(1, 1, Y, X, 3)
        bev_feats_5d = F.grid_sample(
            feats_5d, indices_5d, mode="nearest", align_corners=True
        ).view(Y, X)
        # Flatten D and H dimensions and prepare grid for 4D grid_sample.
        dh = d * H + h
        indices_4d = torch.stack([
            2.0 * w / (W - 1) - 1.0,
            2.0 * dh / (D * H - 1) - 1.0
        ], dim=-1).view(1, Y, X, 2)
        feats_4d = feats_5d.view(1, 1, D * H, W)
        bev_feats_4d = F.grid_sample(
            feats_4d, indices_4d, mode="nearest", align_corners=True
        ).view(Y, X)
        # Check if the results are close.
        self.assertTrue(torch.allclose(bev_feats_5d, bev_feats_4d, atol=1e-6))
if __name__ == "__main__":
    unittest.main()










