pointnet++源码学习
pointnet++中包含pointnet源码,因此只看pointnet++就可以了
整体流程
以debug的形式看源码,分析电源数据在网络中的变化过程,如何进行特征提取
以batch size = 2 为例
输入数据: 2, 1024, 3
->channel_first 2, 3, 1024
第一个SA
首先又给变回来了 2,1024,3
然后是MSG 三个不同的半径
利用query_ball_point()进行分组
最远点采样
选择了512个点[ 2,512,3] ,然后以这个512个点为圆心,给定的r进行分组,
进行分组
三个半径进行的操作分开做了,最后进行了合并在一起
以第一个为例
经过第一个半径的分组得到的数据为 [2,512,16,3]
然后又改为[2,3,16,512]
经过mlp进行特征提取结果为
[2,64,16,512]
解释一下:2-batchsize 64-channel 16是一个组的点的数量
再求最大值(就是对每一个组求最大值,pointnet中的maxpool一样)
得到最终的结果为[2,64,512]
这就是第一个分组进行mlp之后的结果
另外两个半径的结果分别为[2,128,512]、[2,128,512].然后将三个结果concat
最终为[2,320,512]
第二SA
送入SA的由两部分组成,
一个是经过最远点采样之后的点[2,3,512],以及MLP之后的结果为[2,320,512]
经过MLP之后的结果作为一部分特征放入到网络中,
提取之后的结果为
[2,3,128] 最远点采样的结果
[2,640,128] MLP之后的
第三次SA
[2,3,1]
[2,1024,1] MLP之后的
最后分类
针对[2,1024,1]的特征进行分类,MLP完成,比较简单。
源码
FPS 最远点采样,每一行都注释,其余代码都比较简单
首先定义了centroids和distance作为后续存储中心点和距离
随机选择一个点 作为最远点的开始,
然后求所有点到这个点的距离
得到一个新的距离矩阵dist
这个dist和distance进行比较,如果距离小,就将distance结果进行替换,
然后根据distance矩阵 选择最大值吗,也就距离最大的值。得到相应的索引
在求第三个点的时候呢,计算所有点到第二点的距离dist
和distance进行比较 注意这里的distance表示是什么意思,表示所有的点到第一个点的距离
这二者进行比较 也就说,所有的点分别到第一个点的距离 和第二点的距离,选择比较小的距离,更新到distance中。那么此时的distance就变成,所有的点到第一个点和第二点的距离中,最短的距离。这里比较关键,我一点点画图才明白
后面找距离最近的中 最大的确定为下一个点,继续这样的迭代。
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device #
B, N, C = xyz.shape#
#先定义处中心点的矩阵以及距离矩阵,中心点就是我们最远点采样的点数第一次是512
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)#2*512
distance = torch.ones(B, N).to(device) * 1e10 #1024 先定义出来的距离矩阵数字很大的
#最远点,第一个最远点是随机选择的
#随机的索引,比如[224,518],两个值是因为batch size为2
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)#batch里每个样本随机初始化一个最远点的索引
# 每个batch里面的索引
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint): #第一次的npoint为512,循环512ci
centroids[:, i] = farthest #第一个采样点选随机初始化的索引
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)#得到当前采样点的坐标 B*3
dist = torch.sum((xyz - centroid) ** 2, -1)#计算当前采样点与其他点的距离
mask = dist < distance#选择距离最近的来更新距离(更新维护这个表),一个包含bool构成的mask
distance[mask] = dist[mask]#
farthest = torch.max(distance, -1)[1]#重新计算得到最远点索引(在更新的表中选择距离最大的那个点)
return centroid