0
点赞
收藏
分享

微信扫一扫

用 PairData 求法向量

九点韶留学 2022-01-16 阅读 36

作用：改写了 PyTorch Geometric 中的三个变换类（Center、Delaunay、GenerateMeshNormals），使它们可以直接作用于 geometric 中的 PairData：传入一个 PairData，即可分别对其中 A、B 两组点的坐标做中心化（减去均值），对两组点分别做 Delaunay 三角剖分，并求出 A 和 B 各节点的法向量。

class Center(object):
    r"""Shift both point sets of a pair so each is centred on the origin.

    ``data.pos_A`` and ``data.pos_B`` are each replaced by themselves minus
    their mean over dimension ``-2`` (the point axis for an ``[N, D]``
    tensor). The same ``data`` object is returned.
    """

    def __call__(self, data):
        for key in ('pos_A', 'pos_B'):
            points = getattr(data, key)
            centroid = points.mean(dim=-2, keepdim=True)
            setattr(data, key, points - centroid)
        return data

    def __repr__(self):
        return f'{self.__class__.__name__}()'

class Delaunay(object):
    r"""Computes the Delaunay triangulation of both point sets of a pair.

    Applied independently to ``data.pos_A`` and ``data.pos_B``. For a point
    set with ``N`` points:

    * ``N < 2``  -> an empty ``[2, 0]`` edge index,
    * ``N == 2`` -> a single undirected edge ``[[0, 1], [1, 0]]``,
    * ``N == 3`` -> a single face ``[[0], [1], [2]]``,
    * ``N > 3``  -> the simplices of ``scipy.spatial.Delaunay`` (computed with
      the ``QJ`` joggle option), stored transposed as
      ``[vertices_per_simplex, num_simplices]``.

    Faces go to ``face_A`` / ``face_B`` and edge indices to
    ``edge_index_A`` / ``edge_index_B``. For backward compatibility the
    shared ``data.edge_index`` is still written exactly as before (chain B's
    value wins when both chains produce one), but it cannot represent both
    chains at once -- prefer the per-chain attributes.

    NOTE(review): for 3-D positions scipy returns tetrahedra, so the face
    tensor has 4 rows rather than 3 -- confirm that downstream consumers
    expect this.
    """

    @staticmethod
    def _triangulate(pos):
        """Return ``(edge_index, face)`` for one point set; either may be None."""
        num_nodes = pos.size(0)
        device = pos.device
        if num_nodes < 2:
            return torch.empty((2, 0), dtype=torch.long, device=device), None
        if num_nodes == 2:
            edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long,
                                      device=device)
            return edge_index, None
        if num_nodes == 3:
            face = torch.tensor([[0], [1], [2]], dtype=torch.long,
                                device=device)
            return None, face
        tri = scipy.spatial.Delaunay(pos.cpu().numpy(), qhull_options='QJ')
        face = torch.from_numpy(tri.simplices)
        return None, face.t().contiguous().to(device, torch.long)

    def __call__(self, data):
        # Triangulate both chains first so a failure on chain B does not
        # leave data partially updated.
        edge_index_A, face_A = self._triangulate(data.pos_A)
        edge_index_B, face_B = self._triangulate(data.pos_B)

        if face_A is not None:
            data.face_A = face_A
        if face_B is not None:
            data.face_B = face_B
        # Per-chain edge indices: the old code wrote both chains into the
        # single `edge_index`, so one chain's result overwrote the other's.
        if edge_index_A is not None:
            data.edge_index_A = edge_index_A
        if edge_index_B is not None:
            data.edge_index_B = edge_index_B

        # Legacy attribute, kept for existing readers; same last-writer order
        # as before (B overrides A).
        legacy = edge_index_B if edge_index_B is not None else edge_index_A
        if legacy is not None:
            data.edge_index = legacy

        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

class GenerateMeshNormals(object):
    r"""Generate a normal vector for every node of both point sets of a pair
    based on the faces each node belongs to.

    Reads ``data.pos_A`` / ``data.face_A`` and ``data.pos_B`` /
    ``data.face_B`` and writes unit node normals to ``data.norm_A`` and
    ``data.norm_B``. Both ``face_A`` and ``face_B`` must be present (an
    ``AssertionError`` is raised otherwise, before ``data`` is modified).
    Nodes that belong to no face get a zero vector.

    NOTE(review): only the first three rows of each face tensor are used; if
    faces come from a 3-D Delaunay triangulation (4-vertex simplices) the
    fourth vertex is ignored -- confirm this is intended.
    """

    @staticmethod
    def _node_normals(pos, face):
        """Return unit node normals for one point set (``[N, 3]`` for 3-D pos).

        Each face normal is the normalized cross product of its two edges
        from vertex 0; a node normal is the normalized sum of the normals of
        all faces that reference the node.
        """
        vec1 = pos[face[1]] - pos[face[0]]
        vec2 = pos[face[2]] - pos[face[0]]
        # Explicit dim: without it torch.cross uses the FIRST dimension of
        # size 3, which is the face axis whenever there are exactly 3 faces.
        face_norm = F.normalize(torch.cross(vec1, vec2, dim=-1), p=2,
                                dim=-1)  # [F, 3]

        idx = torch.cat([face[0], face[1], face[2]], dim=0)
        face_norm = face_norm.repeat(3, 1)

        # Accumulate each face normal onto its three vertices (equivalent to
        # scatter_add along dim 0 with dim_size=N).
        norm = torch.zeros(pos.size(0), face_norm.size(-1),
                           dtype=face_norm.dtype, device=face_norm.device)
        norm.index_add_(0, idx, face_norm)
        return F.normalize(norm, p=2, dim=-1)  # [N, 3]

    def __call__(self, data):
        # Validate both chains up front so a missing face_B does not leave
        # data half-updated with only norm_A.
        assert 'face_A' in data
        assert 'face_B' in data

        data.norm_A = self._node_normals(data.pos_A, data.face_A)
        data.norm_B = self._node_normals(data.pos_B, data.face_B)

        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
举报

相关推荐

0 条评论