0
点赞
收藏
分享

微信扫一扫

点云3D目标检测学习(2):pointnet++源码

彭维盛 2022-03-11 阅读 53

pointnet++源码学习

pointnet++中包含pointnet源码,因此只看pointnet++就可以了

整体流程

以debug的形式看源码,分析电源数据在网络中的变化过程,如何进行特征提取
以batch size = 2 为例

输入数据: 2, 1024, 3
->channel_first 2, 3, 1024

第一个SA

首先又给变回来了 2,1024,3
然后是MSG 三个不同的半径
利用query_ball_point()进行分组

最远点采样
选择了512个点[ 2,512,3] ,然后以这个512个点为圆心,给定的r进行分组,
进行分组
三个半径进行的操作分开做了,最后进行了合并在一起
以第一个为例

经过第一个半径的分组得到的数据为 [2,512,16,3]
然后又改为[2,3,16,512]
经过mlp进行特征提取结果为
[2,64,16,512]
解释一下:2-batchsize 64-channel 16是一个组的点的数量
再求最大值(就是对每一个组求最大值,pointnet中的maxpool一样)
得到最终的结果为[2,64,512]
这就是第一个分组进行mlp之后的结果

另外两个半径的结果分别为[2,128,512]、[2,128,512].然后将三个结果concat
最终为[2,320,512]

第二SA
送入SA的由两部分组成,
一个是经过最远点采样之后的点[2,3,512],以及MLP之后的结果为[2,320,512]
经过MLP之后的结果作为一部分特征放入到网络中,
提取之后的结果为
[2,3,128] 最远点采样的结果
[2,640,128] MLP之后的

第三次SA

[2,3,1]
[2,1024,1] MLP之后的

最后分类
针对[2,1024,1]的特征进行分类,MLP完成,比较简单。

源码

FPS 最远点采样,每一行都注释,其余代码都比较简单

首先定义了centroids和distance作为后续存储中心点和距离
随机选择一个点 作为最远点的开始,
然后求所有点到这个点的距离
得到一个新的距离矩阵dist
这个dist和distance进行比较,如果距离小,就将distance结果进行替换,
然后根据distance矩阵 选择最大值吗,也就距离最大的值。得到相应的索引
在求第三个点的时候呢,计算所有点到第二点的距离dist
和distance进行比较 注意这里的distance表示是什么意思,表示所有的点到第一个点的距离
这二者进行比较 也就说,所有的点分别到第一个点的距离 和第二点的距离,选择比较小的距离,更新到distance中。那么此时的distance就变成,所有的点到第一个点和第二点的距离中,最短的距离。这里比较关键,我一点点画图才明白
后面找距离最近的中 最大的确定为下一个点,继续这样的迭代。

def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device #
    B, N, C = xyz.shape#
    #先定义处中心点的矩阵以及距离矩阵,中心点就是我们最远点采样的点数第一次是512
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)#2*512
    distance = torch.ones(B, N).to(device) * 1e10 #1024 先定义出来的距离矩阵数字很大的
    #最远点,第一个最远点是随机选择的
    #随机的索引,比如[224,518],两个值是因为batch size为2
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)#batch里每个样本随机初始化一个最远点的索引
    # 每个batch里面的索引
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint): #第一次的npoint为512,循环512ci
        centroids[:, i] = farthest #第一个采样点选随机初始化的索引
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)#得到当前采样点的坐标 B*3 
        dist = torch.sum((xyz - centroid) ** 2, -1)#计算当前采样点与其他点的距离
        mask = dist < distance#选择距离最近的来更新距离(更新维护这个表),一个包含bool构成的mask
        distance[mask] = dist[mask]#
        farthest = torch.max(distance, -1)[1]#重新计算得到最远点索引(在更新的表中选择距离最大的那个点)
    return centroid
举报

相关推荐

0 条评论