0
点赞
收藏
分享

微信扫一扫

nerf-pytorch3D 代码详解

钎探穗 2022-04-13 阅读 41
python

目录

model:class RadianceFieldRenderer

class RadianceFieldRenderer(torch.nn.Module):
    """
    Raybundler->class(orgin, direction, length, xys)

    文件实现了 RadianceFieldRenderer类,集成torch.nn.Modulerendering的前向传播过程如下:
    1.对于给定的输入 camera,rendering ray 通过“NeRFRaysampler” 产生。
    2.在training mode下,射线是一组“射线是一组“n_rays_per_image”的图像网格的随机2D位置。
    3.在evaluation mode下,光束对应的是整张图片的网格.光束被进一步分割为“chunk_size_test” 大小的块,以防止内存不足的错误.

    渲染前向传递过程如下:
        1) 对于给定的输入相机,渲染光线是用
            `self._renderer['coarse']` 的 `NeRFRaysampler` 对象。
            在训练模式下(`self.training==True`),光线是一组
                图像网格的“n_rays_per_image”随机二维位置。
            在评估模式下(`self.training==False`),光线对应
                到完整的图像网格。光线进一步分裂为
                `chunk_size_test`-大小的块,以防止内存不足错误。
        2) 对于每个射线点,评估粗略的 `NeuralRadianceField` MLP。
            指向此 MLP 的指针存储在 `self._implicit_function['coarse']` 中
        3) 粗略的辐射场用`self._renderer['coarse']` 的 `EmissionAbsorptionNeRFRaymarcher` 对象。
        4)粗raymarcher输出引导的概率分布用于精细渲染通道的重要性光线采样。这
            `ProbabilisticRaysampler` 存储在 `self._renderer['fine'].raysampler` 中
            实现重要性光线采样。
        5) 类似于 2) `self._implicit_function['fine']` 中的精细 MLP
            用占据和颜色标记射线点。
        6) self._renderer['fine'].raymarcher` 生成最终的精细渲染。
        7) 将精细渲染和粗糙渲染与地面实况输入图像进行比较与 PSNR 和 MSE 指标。
    """

def init

    def __init__(
        self,
        image_size: Tuple[int, int],#高*宽
        n_pts_per_ray: int,#在粗渲染过程中,沿着每条光线采样的点数
        n_pts_per_ray_fine: int,#在精细渲染过程中,沿着每条光线采样的点数
        n_rays_per_image: int,#训练时每张图片的光束数量
        min_depth: float,#用于粗渲染的采样射线点的最小深度
        max_depth: float,#用于粗渲染的采样射线点的最大深度
        stratified: bool,#训练时分层每个光束点的深度
        stratified_test: bool,#测试时分层每个光束点的深度
        chunk_size_test: int,#每个成像射线块中的射线数,目前此处有些问题 应该是测试时有效?
        n_harmonic_functions_xyz: int = 6,#谐波函数嵌入时的3D坐标 就是Positional encode的L
        n_harmonic_functions_dir: int = 4,#谐波函数嵌入时的方向值 就是Positional encode的L
        n_hidden_neurons_xyz: int = 256,#在MLP全连接层隐藏单元数量,该层接受3D点位置并输出具有中间特征的占用域
        n_hidden_neurons_dir: int = 128,#在MLP的完全连接层中,接受中间特征和光线方向并输出亮度场(逐点颜色)的隐藏单元数。(最后一层)
        n_layers_xyz: int = 8,##MLP的占用域的输出层数
        append_xyz: Tuple[int] = (5,),#占用 MLP 的跳过层的索引列表。在评估跳过层之前,输入到 MLP 的张量附加到跳过层输入。
        density_noise_std: float = 0.0,#将随机正态噪声的标准差加到占用MLP的输出上。仅在' self.training==True '时激活。
        visualization: bool = False,#是否为可视化存储额外的输出
    ):
        """
        Args:
            image_size: The size of the rendered image (`[height, width]`).
            n_pts_per_ray: The number of points sampled along each ray for the
                coarse rendering pass.
            n_pts_per_ray_fine: The number of points sampled along each ray for the
                fine rendering pass.
            n_rays_per_image: Number of Monte Carlo ray samples when training
                (`self.training==True`).
            min_depth: The minimum depth of a sampled ray-point for the coarse rendering.
            max_depth: The maximum depth of a sampled ray-point for the coarse rendering.
            stratified: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during training (`self.training==True`).
            stratified_test: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during evaluation (`self.training==False`).
            chunk_size_test: The number of rays in each chunk of image rays.
                Active only when `self.training==True`.
            n_harmonic_functions_xyz: The number of harmonic functions
                used to form the harmonic embedding of 3D point locations.
            n_harmonic_functions_dir: The number of harmonic functions
                used to form the harmonic embedding of the ray directions.
            n_hidden_neurons_xyz: The number of hidden units in the
                fully connected layers of the MLP that accepts the 3D point
                locations and outputs the occupancy field with the intermediate
                features.
            n_hidden_neurons_dir: The number of hidden units in the
                fully connected layers of the MLP that accepts the intermediate
                features and ray directions and outputs the radiance field
                (per-point colors).
            n_layers_xyz: The number of layers of the MLP that outputs the
                occupancy field.
            append_xyz: The list of indices of the skip layers of the occupancy MLP.
                Prior to evaluating the skip layers, the tensor which was input to MLP
                is appended to the skip layer input.
            density_noise_std: The standard deviation of the random normal noise
                added to the output of the occupancy MLP.
                Active only when `self.training==True`.
            visualization: whether to store extra output for visualization.
        """

        super().__init__()

        # The renderers and implicit functions are stored under the fine/coarse
        # keys in ModuleDict PyTorch modules.
        # 渲染器和隐式函数存储在fine/coarse下
        # ModuleDict PyTorch 模块中的键
        
        self._renderer = torch.nn.ModuleDict()##定义render(分coarse和fine) 两个key分别是coarse和fine
        self._implicit_function = torch.nn.ModuleDict()##定义网络隐层函数(分coarse和fine), 结构和renderer一致,不过对应的是NeuralRadiencefunction
        #self._implicit_function->n_harmonic_function_xyz_dir hidden_neurons_xyz_dir, n_layers, append_xyz
        # Init the EA raymarcher used by both passes.
        raymarcher = EmissionAbsorptionNeRFRaymarcher()##返回的是特征和权重

        # Parse out image dimensions.
        image_height, image_width = image_size#800 800

        for render_pass in ("coarse", "fine"):
            if render_pass == "coarse":
                # Initialize the coarse raysampler.
                #正常的均匀点采样
                raysampler = NeRFRaysampler(
                    n_pts_per_ray=n_pts_per_ray,##沿每条光线的采样点数
                    min_depth=min_depth,
                    max_depth=max_depth,
                    stratified=stratified,#训练时分层每个光束点的深度
                    stratified_test=stratified_test,
                    n_rays_per_image=n_rays_per_image,
                    image_height=image_height,
                    image_width=image_width,
                )
            elif render_pass == "fine":
                #重要性采样
                # Initialize the fine raysampler.
                raysampler = ProbabilisticRaysampler(
                    n_pts_per_ray=n_pts_per_ray_fine,
                    stratified=stratified,
                    stratified_test=stratified_test,
                )
            else:
                raise ValueError(f"No such rendering pass {render_pass}")

            # Initialize the fine/coarse renderer.初始化fine/coarse render
            self._renderer[render_pass] = ImplicitRenderer(
                raysampler=raysampler,
                raymarcher=raymarcher,
            )

            # Instantiate the fine/coarse NeuralRadianceField module.网络实例化
            self._implicit_function[render_pass] = NeuralRadianceField(
                n_harmonic_functions_xyz=n_harmonic_functions_xyz,
                n_harmonic_functions_dir=n_harmonic_functions_dir,
                n_hidden_neurons_xyz=n_hidden_neurons_xyz,
                n_hidden_neurons_dir=n_hidden_neurons_dir,
                n_layers_xyz=n_layers_xyz,
                append_xyz=append_xyz,
            )

        self._density_noise_std = density_noise_std
        self._chunk_size_test = chunk_size_test
        self._image_size = image_size
        self.visualization = visualization

def precache_rays

    def precache_rays(
        self,
        cache_cameras: List[CamerasBase], #precache ray的n个camera的列表
        cache_camera_hashes: List[str], #每个相机独特标识符的列表
    ):
        """
        预缓存从相机“cache_cameras”列表中发出的光线,其中每个相机都用相应的哈希唯一标识来自`cache_camera_hashes`。
        缓存的光线被移动到 cpu 并存储在 `self._renderer['coarse']._ray_cache`。
        缓存具有相同哈希的两个相机时引发 `ValueError`。

        参数:
            cache_cameras:预先缓存了光线的“N”个摄像机的列表。
            cache_camera_hashes:每个唯一标识符的“N”个列表

        Precaches the rays emitted from the list of cameras `cache_cameras`,
        where each camera is uniquely identified with the corresponding hash
        from `cache_camera_hashes`.

        The cached rays are moved to cpu and stored in
        `self._renderer['coarse']._ray_cache`.

        Raises `ValueError` when caching two cameras with the same hash.

        Args:
            cache_cameras: A list of `N` cameras for which the rays are pre-cached.
            cache_camera_hashes: A list of `N` unique identifiers for each
                camera from `cameras`.
        """
        self._renderer["coarse"].raysampler.precache_rays(##预缓存光线
            cache_cameras,
            cache_camera_hashes,
        )

def _process_ray_chunk

    def _process_ray_chunk(
        self,
        camera_hash: Optional[str],# pre-cached camera的唯一标识符
        camera: CamerasBase,# 一批场景被渲染的cameara
        image: torch.Tensor,# ground truth , shape(batch_size,,3)
        chunk_idx: int, # 当前射线块的索引。 
    ) -> dict:
        """
        Samples and renders a chunk of rays.

        Args:
            camera_hash: A unique identifier of a pre-cached camera.
                If `None`, the cache is not searched and the sampled rays are
                calculated from scratch.
            camera: A batch of cameras from which the scene is rendered.
            image: A batch of corresponding ground truth images of shape
                ('batch_size', ·, ·, 3).
            chunk_idx: The index of the currently rendered ray chunk.
        Returns:
            out: `dict` containing the outputs of the rendering:
                `rgb_coarse`: The result of the coarse rendering pass.
                `rgb_fine`: The result of the fine rendering pass.
                `rgb_gt`: The corresponding ground-truth RGB values.
        return :
             out: `dict` 包含渲染的输出:
                  `rgb_coarse`:粗略渲染过程的结果。
                  `rgb_fine`:精细渲染通道的结果。
                  `rgb_gt`:对应的ground-truth RGB值。
        """
        # Initialize the outputs of the coarse rendering to None.
        coarse_ray_bundle = None##orgin direction depth xy二维坐标
        coarse_weights = None

        # First evaluate the coarse rendering pass, then the fine one.
        for renderer_pass in ("coarse", "fine"):
            (rgb, weights), ray_bundle_out = self._renderer[renderer_pass](
                cameras=camera,
                volumetric_function=self._implicit_function[renderer_pass],
                chunksize=self._chunk_size_test,
                chunk_idx=chunk_idx,
                density_noise_std=(self._density_noise_std if self.training else 0.0),
                input_ray_bundle=coarse_ray_bundle,
                ray_weights=coarse_weights,
                camera_hash=camera_hash,
            )

            if renderer_pass == "coarse":
                rgb_coarse = rgb
                # Store the weights and the rays of the first rendering pass
                # for the ensuing importance ray-sampling of the fine render.
                #存储第一个渲染通道的权重和光线,用于精细渲染的后续重要性光线采样。
                coarse_ray_bundle = ray_bundle_out##第一次粗渲染的结果保存用于精细渲染
                coarse_weights = weights
                if image is not None:
                    # Sample the ground truth images at the xy locations of the
                    # rendering ray pixels.
                    # 在xy位置对ground truth图像进行采样,渲染光线像素。
                    rgb_gt = sample_images_at_mc_locs(
                        image[..., :3][None],
                        ray_bundle_out.xys,
                    )
                else:
                    rgb_gt = None

            elif renderer_pass == "fine":
                rgb_fine = rgb

            else:
                raise ValueError(f"No such rendering pass {renderer_pass}")

        out = {"rgb_fine": rgb_fine, "rgb_coarse": rgb_coarse, "rgb_gt": rgb_gt}
        if self.visualization:
            # Store the coarse rays/weights only for visualization purposes.
            # 存储粗射线/权重仅用于可视化目的。
            out["coarse_ray_bundle"] = type(coarse_ray_bundle)(
                *[v.detach().cpu() for k, v in coarse_ray_bundle._asdict().items()]
            )
            out["coarse_weights"] = coarse_weights.detach().cpu()
            ##切断梯度
        return out

def forward

    def forward(
        self,
        camera_hash: Optional[str],
        camera: CamerasBase,
        image: torch.Tensor,
    ) -> Tuple[dict, dict]:
        """
        从输入的“camera”的角度执行辐射场的粗略和精细的渲染传递。
        之后,两种渲染方式都与输入的地面真实“图像”相比较。
        通过评估峰值信噪比和均方误差.

        The rendering result depends on the `self.training` flag:
            - In the training mode (`self.training==True`), 
            该函数renders图像射线的随机子集(MonteCarlo rendering)。
            - In evaluation mode (`self.training==False`), 
            该函数renders整个图像。为了防止内存不足的错误,
            当 `self.training==False` 时,对光线进行采样和渲染
              批量大小为“chunksize”。

       参数:
            camera_hash:预缓存相机的唯一标识符。
                如果为“None”,则不搜索缓存并且采样光线
                从头计算。
            camera:渲染场景的一批相机。
            image: 一批对应形状为
            ('batch_size', ·, ·, 3)的ground truth图像
            
        Returns:
            out: `dict` 包含渲染的输出:
                `rgb_coarse`:粗略渲染过程的结果。
                `rgb_fine`:精细渲染通道的结果。
                `rgb_gt`:对应的ground-truth RGB值。

                The shape of `rgb_coarse`, `rgb_fine`, `rgb_gt` depends on the
                `self.training` flag:
                    If `==True`, all 3 tensors are of shape
                    `(batch_size, n_rays_per_image, 3)` and contain the result
                    of the Monte Carlo training rendering pass.
                    If `==False`, all 3 tensors are of shape
                    `(batch_size, image_size[0], image_size[1], 3)` and contain
                    the result of the full image rendering pass.
                    
            metrics: `dict` 包含比较细和粗渲染到实际情况的误差度量:
                `mse_coarse`:粗略渲染和渲染之间的均方误差
                    输入`图像`
                `mse_fine`:精细渲染和渲染之间的均方误差
                    输入`图像`
                `psnr_coarse`:粗略渲染和粗糙渲染之间的峰值信噪比
                    输入`图像`
                `psnr_fine`:精细渲染和精细渲染之间的峰值信噪比
                    输入`图像`
        """
      
        if not self.training:
            # Full evaluation pass.
            n_chunks = self._renderer["coarse"].raysampler.get_n_chunks(#获取chunk数即块数
                self._chunk_size_test,
                camera.R.shape[0],
            )
        else:
            # 如果在测试的话完整的渲染.
            n_chunks = 1

        # Process the chunks of rays.
        '''return :
             out: `dict` 包含渲染的输出:
                  `rgb_coarse`:粗略渲染过程的结果。
                  `rgb_fine`:精细渲染通道的结果。
                  `rgb_gt`:对应的ground-truth RGB值。'''
        chunk_outputs = [
            self._process_ray_chunk(
                camera_hash,
                camera,
                image,
                chunk_idx,
            )
            for chunk_idx in range(n_chunks)
        ]

        if not self.training:
            # For a full render pass concatenate the output chunks,
            # and reshape to image size.
            # 对于完整的渲染过程,连接输出块,并重塑为图像大小。
            out = {
                k: torch.cat(
                    [ch_o[k] for ch_o in chunk_outputs],
                    dim=1,
                ).view(-1, *self._image_size, 3)
                if chunk_outputs[0][k] is not None
                else None
                for k in ("rgb_fine", "rgb_coarse", "rgb_gt")
            }
        else:
            out = chunk_outputs[0]

        # Calc the error metrics.
        metrics = {}
        if image is not None:
            for render_pass in ("coarse", "fine"):
                for metric_name, metric_fun in zip(
                    ("mse", "psnr"), (calc_mse, calc_psnr)
                ):
                    metrics[f"{metric_name}_{render_pass}"] = metric_fun(
                        out["rgb_" + render_pass][..., :3],
                        out["rgb_gt"][..., :3],
                    )

        return out, metrics

stats object

    # Init the stats object.
    if stats is None:
        stats = Stats(
            ["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
        )

class Stats:

def init

class Stats:
    """
	统计日志对象,用于收集在PyTorch中训练深度网络的统计数据。

    Example:
        ```
        # 初始化统计数据结构,用于记录统计数据“objective”和“top1e”.
        stats = Stats( ('objective','top1e') )

        network = init_net()  # init a pytorch module (=neural network)
        dataloader = init_dataloader()  # init a dataloader

        for epoch in range(10):

            # start of epoch -> call new_epoch
            stats.new_epoch()

            # Iterate over batches.
            for batch in dataloader:
                # 运行模型并将其保存到输出变量的“output”中。
                output = network(batch)

                # stats.update() 自动解析 the 'objective' and 'top1e'
                # from the "output" dict 并把这个存储在数据库里。
                stats.update(output)
                stats.print() # 打印给定epoch的平均值

            # Stores the training plots into '/tmp/epoch_stats.pdf'
            # 并绘制成运行在localhost上的Visdom服务器(如果运行的话)。
            stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
        ```
    """

    def __init__(
        self,
        log_vars: List[str],
        verbose: bool = False,
        epoch: int = -1,
        plot_file: Optional[str] = None,
    ) -> None:
        """
        Args:
            log_vars: 要记录的变量名列表。
            verbose: 打印状态信息。
            epoch: The initial epoch of the object.
            plot_file: The path to the file that will hold the training plots.
        """
        self.verbose = verbose
        self.log_vars = log_vars
        self.plot_file = plot_file
        self.hard_reset(epoch=epoch)

在这里插入图片描述

def hard_reset(self, epoch: int = -1) -> None

self.hard_reset(epoch=epoch)

def hard_reset(self, epoch: int = -1) -> None:
        """
        删除所有记录的数据。
        """
        self._epoch_start = None
        self.epoch = epoch
        if self.verbose:
            print("stats: epoch %d - hard reset" % self.epoch)
        self.stats = {}
        self.reset()

def reset(self) -> None

self.reset()

def reset(self) -> None:
        """
        在一个epoch 之前调用以清除当前的epoch缓冲区。
        """
        stat_sets = list(self.stats.keys())
        if self.verbose:
            print("stats: epoch %d - reset" % self.epoch)
        self.it = {k: -1 for k in stat_sets}
        for stat_set in stat_sets:
            for stat in self.stats[stat_set]:
                self.stats[stat_set][stat].reset()

        # 设置一个新的时间戳。
        self._epoch_start = time.time()

在这里插入图片描述

学习率scheduler 设置

learning rate: current_lr = base_lr * gamma ** (epoch / step_size)

    # 学习速率调度器设置.

    # 在原始代码之后, we use exponential decay of the#指数衰减
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):  ##定义学习率指数衰减
        return cfg.optimizer.lr_scheduler_gamma ** (
                epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    # 利用学习率调度器实现lr——lambda的指数衰减学习率
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
    )

可视化设置

    # 初始化缓存,以存储可视化所需的变量
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

在这里插入图片描述

数据集

1、加载training/validation数据。

    # Load the training/validation data.
    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=cfg.data.image_size,  # 加载初始化中的训练集和验证集
    )

def get_nerf_datasets

##返回图像, 相机参数, 相机编号3个组成的数据结构
def get_nerf_datasets(
    dataset_name: str,  # 'lego | fern' 给定具体的场景
    image_size: Tuple[int, int], #图像的尺寸(height,width)
    data_root: str = DEFAULT_DATA_ROOT, #数据的网络连接
    autodownload: bool = True, #根据网络连接对数据直接进行下载
) -> Tuple[Dataset, Dataset, Dataset]: ##返回训练集,验证集,测试集
    """
    获取使用“dataset_name”参数指定的数据集的training and validation DataSet对象

    Args:
        dataset_name: 要加载的数据集的名称。
        image_size: 表示加载的数据集图像大小的元组(高度、宽度)。
        data_root: 存储数据的根文件夹。
        autodownload: 自动下载数据集文件,以防它们丢失。

    Returns:
        train_dataset: The training dataset object.
        val_dataset: The validation dataset object.
        test_dataset: The testing dataset object.
    """

    if dataset_name not in ALL_DATASETS:
        raise ValueError(f"'{dataset_name}'' does not refer to a known dataset.")

    print(f"Loading dataset {dataset_name}, image size={str(image_size)} ...")

    cameras_path = os.path.join(data_root, dataset_name + ".pth")#cameras_path
    image_path = cameras_path.replace(".pth", ".png")#image_path

    if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
        # Automatically download the data files if missing.自动下载缺失的数据
        download_data((dataset_name,), data_root=data_root)

    train_data = torch.load(cameras_path)##获得训练数据
    n_cameras = train_data["cameras"]["R"].shape[0]##获得相机个数 400

    _image_max_image_pixels = Image.MAX_IMAGE_PIXELS #python image 库中设定好的image像素的最大值
    Image.MAX_IMAGE_PIXELS = None  # The dataset image is very large ...
    images = torch.FloatTensor(np.array(Image.open(image_path))) / 255.0 
    ##求梯度只针对于浮点型tensor,8bit原始图像归一化到[0,1]的(torch.FloatTensor)
    images = torch.stack(torch.chunk(images, n_cameras, dim=0))[..., :3]##torch.chunk分块
    ##[H, W, 3] (本例取H=W=400)
    Image.MAX_IMAGE_PIXELS = _image_max_image_pixels

    #确定图像的放缩比例 并调整图像
    scale_factors = [s_new / s for s, s_new in zip(images.shape[1:3], image_size)]
    if abs(scale_factors[0] - scale_factors[1]) > 1e-3:
        raise ValueError(
            "Non-isotropic scaling is not allowed. Consider changing the 'image_size' argument."
        )
    scale_factor = sum(scale_factors) * 0.5

    if scale_factor != 1.0:
        print(f"Rescaling dataset (factor={scale_factor})")
        images = torch.nn.functional.interpolate(
            images.permute(0, 3, 1, 2),
            size=tuple(image_size),
            mode="bilinear",
        ).permute(0, 2, 3, 1)##通过插值进行缩放

    #camera是pytorch3d.renderer.cameras.PerspectiveCameras的实例
    #主要用到的是选择矩阵旋转R,平移矩阵T,焦距和主视点 
    cameras = [
        PerspectiveCameras(
            **{k: v[cami][None] for k, v in train_data["cameras"].items()}
        ).to("cpu") ##提取train_data信息到cameras中
        for cami in range(n_cameras)##共建立n_cameras个相机实例
    ]

    ##将train_data划分为train, val, test
    train_idx, val_idx, test_idx = train_data["split"]

    train_dataset, val_dataset, test_dataset = [
        ListDataset(
            [
                {"image": images[i], "camera": cameras[i], "camera_idx": int(i)}
                for i in idx
            ]
        )
        for idx in [train_idx, val_idx, test_idx]
    ]

    return train_dataset, val_dataset, test_dataset

class ListDataset(Dataset)

class ListDataset(Dataset):
    """
    由entries(条目)列表组成的简单数据集。
    """

    def __init__(self, entries: List) -> None:
        """
        Args:
            entries: The list of dataset entries.
        """
        self._entries = entries

    def __len__(
        self,
    ) -> int:
        return len(self._entries)

    def __getitem__(self, index):
        return self._entries[index] ##直接获取该索引对应的列表内容

debug

在这里插入图片描述
在这里插入图片描述
train_data
在这里插入图片描述
images = torch.stack(torch.chunk(images, n_cameras, dim=0))[…, :3]##torch.chunk分块
在这里插入图片描述
在这里插入图片描述
scale_factors在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2、 预缓存投影光线

    if cfg.data.precache_rays:
        # Precache the projection rays.预缓存投影光线
        model.eval()  # 不启用 batchnorm and dropout
        with torch.no_grad():
            for dataset in (train_dataset, val_dataset):
                cache_cameras = [e["camera"].to(device) for e in dataset]
                cache_camera_hashes = [e["camera_idx"] for e in dataset]
                # 把 dataset中的camera和camera_idx拿出来进行缓存
                model.precache_rays(cache_cameras, cache_camera_hashes)

在这里插入图片描述
在这里插入图片描述

model.precache_rays

model.precache_rays(cache_cameras, cache_camera_hashes)

    def precache_rays(
        self,
        cache_cameras: List[CamerasBase], #precache ray的n个camera的列表
        cache_camera_hashes: List[str], #每个相机独特标识符的列表
    ):
        """
        预缓存从相机“cache_cameras”列表中发出的光线,其中每个相机都用相应的哈希唯一标识来自`cache_camera_hashes`。
        缓存的光线被移动到 cpu 并存储在 `self._renderer['coarse']._ray_cache`。
        缓存具有相同哈希的两个相机时引发 `ValueError`。

        参数:
            cache_cameras:预先缓存了光线的“N”个摄像机的列表。
            cache_camera_hashes:每个唯一标识符的“N”个列表
        """
        self._renderer["coarse"].raysampler.precache_rays(##预缓存光线
            cache_cameras,
            cache_camera_hashes,
        )

NeRFRaysampler

在这里插入图片描述

class NeRFRaysampler(torch.nn.Module)
    def precache_rays(self, cameras: List[CamerasBase], camera_hashes: List):

def precache_rays

self._renderer[“coarse”].raysampler.precache_rays

    def precache_rays(self, cameras: List[CamerasBase], camera_hashes: List):
        """        
        预缓存从相机“cameras”列表中发出的光线,其中每个相机都用相应的哈希唯一标识来自`camera_hashes`。
        缓存的光线被移动到 cpu 并存储在 `self._ray_cache` 中,缓存具有相同哈希的两个相机时引发 `ValueError`。
        Args::
            cameras:预先缓存了光线的“N”个相机的列表。
            camera_hashes:每个的“N”个唯一标识符的列表,来自“cameras”的相机。
        """
        print(f"Precaching {len(cameras)} ray bundles ...")
        full_chunksize = ( #4096000
            self._grid_raysampler._xy_grid.numel()
            // 2
            * self._grid_raysampler._n_pts_per_ray
        )##每个chunk最多预缓存的光线数
        if self.get_n_chunks(full_chunksize, 1) != 1:##只能是一个chunk来预缓存光线
            raise ValueError("There has to be one chunk for precaching rays!")
        for camera_i, (camera, camera_hash) in enumerate(zip(cameras, camera_hashes)):
            ray_bundle = self.forward(
                camera,
                caching=True,
                chunksize=full_chunksize,
            )
            if camera_hash in self._ray_cache:
                raise ValueError("There are redundant cameras!")
            self._ray_cache[camera_hash] = RayBundle( ##将光线预缓存到列表cpu
                *[v.to("cpu").detach() for v in ray_bundle]
            )
            self._print_precaching_progress(camera_i, len(cameras))
        print("")

debug

self在这里插入图片描述

for camera_i, (camera, camera_hash) in enumerate(zip(cameras, camera_hashes)):

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

def get_n_chunks

self.get_n_chunks(full_chunksize, 1)

    def get_n_chunks(self, chunksize: int, batch_size: int):
        """
        返回“chunksize”大小的光线采样器的光线块的总数。
        参数:
            chunksize:每个块的光线数量。
            batch_size:光线采样器的批次大小。

        return:
            n_chunks:块的总数。
        """
        return int(
            math.ceil(#向上取整 numel->返回元素个数
                (self._grid_raysampler._xy_grid.numel() * 0.5 * batch_size) / chunksize
            )
        )

def forword

ray_bundle = self.forward(
                camera,
                caching=True,
                chunksize=full_chunksize,
            )

在这里插入图片描述

    def forward(
        self,
        cameras: CamerasBase,
        chunksize: int = None,
        chunk_idx: int = 0,
        camera_hash: str = None,
        caching: bool = False,
        **kwargs,
    ) -> RayBundle:
        """
        Args:
            cameras: 发出射线的一批“batch_size”照相机。
            chunksize: 每块射线的数量。只在下列情况下激活 `self.training==False`.
            chunk_idx: 射线块的索引。号码必须在
                `[0, self.get_n_chunks(chunksize, batch_size)-1]`.
                只在下列情况下激活  `self.training==False`.
            camera_hash: 预缓存相机的唯一标识符。 If `None`,
                缓存不被搜索,光线是从头开始计算的。
            caching: If `True`, 激活缓存模式,该模式返回应该存储到缓存中的“RayBundle”。
        Returns:
            一个名为“RayBundle”的元组,具有以下字段:
                origins: A tensor of shape
                    `(batch_size, n_rays_per_image, 3)`
                    在世界坐标中表示射线原点的位置。
                directions: A tensor of shape
                    `(batch_size, n_rays_per_image, 3)`
                    表示世界坐标中每一条射线的方向。
                lengths: A tensor of shape
                    `(batch_size, n_rays_per_image, n_pts_per_ray)`
                    包含世界单位中每条射线的z坐标(=depth).
                xys: A tensor of shape
                    `(batch_size, n_rays_per_image, 2)`
                    包含每个射线的2D图像坐标。
        """
        batch_size = cameras.R.shape[0]  # pyre-ignore 1
        device = cameras.device

        if (camera_hash is None) and (not caching) and self.training:
            # Sample random rays from scratch.从头开始采集光线
            ray_bundle = self._mc_raysampler(cameras)
            ray_bundle = self._normalize_raybundle(ray_bundle)
        else:
            if camera_hash is not None:
                # 我们从缓存中取回相机的情况.
                if batch_size != 1:
                    raise NotImplementedError(
                        "Ray caching works only for batches with a single camera!"
                    )
                full_ray_bundle = self._ray_cache[camera_hash]#从光线缓存中取出光线
            else:
                # 我们从零开始生成一个完整的射线网格。
                full_ray_bundle = self._grid_raysampler(cameras)
                full_ray_bundle = self._normalize_raybundle(full_ray_bundle)

            n_pixels = full_ray_bundle.directions.shape[:-1].numel() #640000

            if self.training:
                # During training we randomly subsample rays.对光线进行二次采样
                sel_rays = torch.randperm(n_pixels, device=device)[
                    : self._mc_raysampler._n_rays_per_image
                ]
            else:
                # 如果我们进行测试,我们只接受所请求的块.直接取chunk
                if chunksize is None:
                    chunksize = n_pixels * batch_size
                start = chunk_idx * chunksize * batch_size #0
                end = min(start + chunksize, n_pixels)      #640000
                sel_rays = torch.arange(
                    start,
                    end,
                    dtype=torch.long,
                    device=full_ray_bundle.lengths.device,
                )

            # 把整束射线中的“sel_rays”取出来。
            ray_bundle = RayBundle(
                *[
                    v.view(n_pixels, -1)[sel_rays]
                    .view(batch_size, sel_rays.numel() // batch_size, -1)
                    .to(device)
                    for v in full_ray_bundle
                ]
            )

        if (
            (self._stratified and self.training)
            or (self._stratified_test and not self.training)
        ) and not caching:  # 请确保缓存时不要分层!
            ray_bundle = self._stratify_ray_bundle(ray_bundle)

        return ray_bundle

debug

self:

NeRFRaysampler(
  (_grid_raysampler): NDCMultinomialRaysampler()
  (_mc_raysampler): MonteCarloRaysampler()
)

在这里插入图片描述

full_ray_bundle = self._grid_raysampler(cameras)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

def _normalize_raybundle

ray_bundle = self._normalize_raybundle(ray_bundle)
    def _normalize_raybundle(self, ray_bundle: RayBundle):
        """
        将输入“RayBundle”的光线方向标准化为单位范数。
        """
        ray_bundle = ray_bundle._replace(
            directions=torch.nn.functional.normalize(ray_bundle.directions, dim=-1)
        )
        return ray_bundle

def _print_precaching_progress

self._print_precaching_progress(camera_i, len(cameras))
    def _print_precaching_progress(self, i, total, bar_len=30):
        """
        打印光线缓存的进度条
        """
        position = round((i + 1) / total * bar_len)
        pbar = "[" + "█" * position + " " * (bar_len - position) + "]"
        print(pbar, end="\r")

3、train_dataloader、val_dataloader

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # 验证数据器只是一个无休止的随机样本流。
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        collate_fn=trivial_collate,
        sampler=torch.utils.data.RandomSampler(
            val_dataset,
            replacement=True,
            num_samples=cfg.optimizer.max_epochs,
        ),
    )

训练

    # 接下来就是正常的训练过程
    # Set the model to the training mode.
    model.train()

epoch循环

    # Run the main training loop.
    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
        stats.new_epoch()  # Init a new epoch.建立一个新的epoch训练

batch循环

1、提取数据

for iteration, batch in enumerate(train_dataloader):
            image, camera, camera_idx = batch[0].values()  # 第一个epoch对应的值
            image = image.to(device)
            camera = camera.to(device)

            optimizer.zero_grad()

在这里插入图片描述
在这里插入图片描述

2、运行模型

执行RadianceFieldRenderer模型的forword文件
代码在上面

            # Run the forward pass of the model.
            nerf_out, metrics = model(
                camera_idx if cfg.data.precache_rays else None,
                camera,
                image,
            )

在这里插入图片描述

model. self

RadianceFieldRenderer(
  (_renderer): ModuleDict(
    (coarse): ImplicitRenderer(
      (raysampler): NeRFRaysampler(
        (_grid_raysampler): NDCMultinomialRaysampler()
        (_mc_raysampler): MonteCarloRaysampler()
      )
      (raymarcher): EmissionAbsorptionNeRFRaymarcher()
    )
    (fine): ImplicitRenderer(
      (raysampler): ProbabilisticRaysampler()
      (raymarcher): EmissionAbsorptionNeRFRaymarcher()
    )
  )
  (_implicit_function): ModuleDict(
    (coarse): NeuralRadianceField(
      (harmonic_embedding_xyz): HarmonicEmbedding()
      (harmonic_embedding_dir): HarmonicEmbedding()
      (mlp_xyz): MLPWithInputSkips(
        (mlp): ModuleList(
          (0): Sequential(
            (0): Linear(in_features=63, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (1): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (2): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (3): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (4): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (5): Sequential(
            (0): Linear(in_features=319, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (6): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (7): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
        )
      )
      (intermediate_linear): Linear(in_features=256, out_features=256, bias=True)
      (density_layer): Linear(in_features=256, out_features=1, bias=True)
      (color_layer): Sequential(
        (0): LinearWithRepeat()
        (1): ReLU(inplace=True)
        (2): Linear(in_features=128, out_features=3, bias=True)
        (3): Sigmoid()
      )
    )
    (fine): NeuralRadianceField(
      (harmonic_embedding_xyz): HarmonicEmbedding()
      (harmonic_embedding_dir): HarmonicEmbedding()
      (mlp_xyz): MLPWithInputSkips(
        (mlp): ModuleList(
          (0): Sequential(
            (0): Linear(in_features=63, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (1): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (2): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (3): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (4): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (5): Sequential(
            (0): Linear(in_features=319, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (6): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
          (7): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): ReLU(inplace=True)
          )
        )
      )
      (intermediate_linear): Linear(in_features=256, out_features=256, bias=True)
      (density_layer): Linear(in_features=256, out_features=1, bias=True)
      (color_layer): Sequential(
        (0): LinearWithRepeat()
        (1): ReLU(inplace=True)
        (2): Linear(in_features=128, out_features=3, bias=True)
        (3): Sigmoid()
      )
    )
  )
)

chunk_outputs

在这里插入图片描述

在这里插入图片描述

metrics

coarse
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
fine在这里插入图片描述

3、loss

            # The loss is a sum of coarse and fine MSEs metrics指标
            loss = metrics["mse_coarse"] + metrics["mse_fine"]

            # Take the training step.
            loss.backward()
            optimizer.step()

loss:

tensor(0.2945, device='cuda:0', grad_fn=<AddBackward0>)

4、stats

            # Update stats with the current metrics.
            stats.update(
                {"loss": float(loss), **metrics},
                stat_set="train",
            )

            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")  ##每隔10print出来一次

stats

{'train': 
{'loss': <nerf.stats.AverageMeter object at 0x7f394e0f2970>, 
'mse_coarse': <nerf.stats.AverageMeter object at 0x7f394e104b50>, 
'mse_fine': <nerf.stats.AverageMeter object at 0x7f394e0f2850>, 
'psnr_coarse': <nerf.stats.AverageMeter object at 0x7f394e0f2070>, 
'psnr_fine': <nerf.stats.AverageMeter object at 0x7f394e114730>,
 'sec/it': <nerf.stats.AverageMeter object at 0x7f394e129310>}
}

在这里插入图片描述

5、可视化更新

            # Update the visualization cache.
            if viz is not None:  ##可视化更新
                visuals_cache.append(
                    {
                        "camera": camera.cpu(),
                        "camera_idx": camera_idx,
                        "image": image.cpu().detach(),  ##detach不需要梯度
                        "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
                        "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
                        "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
                        "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
                    }
                )

调整学习速度。

        # Adjust the learning rate.
        lr_scheduler.step()

验证

        # Validation
        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:

            # Sample a validation camera/image.
            val_batch = next(val_dataloader.__iter__())
            val_image, val_camera, camera_idx = val_batch[0].values()
            val_image = val_image.to(device)
            val_camera = val_camera.to(device)

            # 激活模型的val模式(让我们完成一个完整的渲染传递)。
            model.eval()
            with torch.no_grad():
                val_nerf_out, val_metrics = model(
                    camera_idx if cfg.data.precache_rays else None,
                    val_camera,
                    val_image,
                )

            # Update stats with the validation metrics. ##更新指标 即损失函数的加权
            stats.update(val_metrics, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                # Plot that loss curves into visdom.
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                # 可视化中间结果。
                visualize_nerf_outputs(
                    val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
                )

            # Set the model back to train mode.
            model.train()

Checkpoint

        # Checkpoint.
        if (
                epoch % cfg.checkpoint_epoch_interval == 0
                and len(cfg.checkpoint_path) > 0
                and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            torch.save(data_to_store, checkpoint_path)  ##保存训练模型参数数据到checkpoint

class Stats 其他函数:

def new_epoch

    def new_epoch(self) -> None:
        """
        Initializes a new epoch.
        """
        if self.verbose:
            print("stats: new epoch %d" % (self.epoch + 1))
        self.epoch += 1  # increase epoch counter
        self.reset()  # zero the stats

update

# Update stats with the current metrics.
stats.update(
    {"loss": float(loss), **metrics},
    stat_set="train",
)
    def update(self, preds: dict, stat_set: str = "train") -> None:
        """
        Update the internal logs with metrics of a training step.
        每个指标都存储为AverageMeter的实例。

        Args:
            preds: 要添加到日志中的值的Dict
            stat_set: 待更新的统计数据集 (e.g. "train", "val").
        """

        if self.epoch == -1:  # uninitialized
            warnings.warn(
                "self.epoch==-1 means uninitialized stats structure"
                " -> new_epoch() called"
            )
            self.new_epoch()

        if stat_set not in self.stats:
            self.stats[stat_set] = {}
            self.it[stat_set] = -1

        self.it[stat_set] += 1

        epoch = self.epoch
        it = self.it[stat_set]

        for stat in self.log_vars:

            if stat not in self.stats[stat_set]:
                self.stats[stat_set][stat] = AverageMeter()

            if stat == "sec/it":  # compute speed
                elapsed = time.time() - self._epoch_start
                time_per_it = float(elapsed) / float(it + 1)
                val = time_per_it
            else:
                if stat in preds:
                    val = self._gather_value(preds[stat])
                else:
                    val = None

            if val is not None:
                self.stats[stat_set][stat].update(val, epoch=epoch, n=1)

    def print(self, max_it: Optional[int] = None, stat_set: str = "train") -> None:
        """
        Print the current values of all stored stats.

        Args:
            max_it: Maximum iteration number to be displayed.
                If None, the maximum iteration number is not displayed.
            stat_set: The set of statistics to be printed.
        """

        epoch = self.epoch
        stats = self.stats

        str_out = ""

        it = self.it[stat_set]
        stat_str = ""
        stats_print = sorted(stats[stat_set].keys())
        for stat in stats_print:
            if stats[stat_set][stat].count == 0:
                continue
            stat_str += " {0:.12}: {1:1.3f} |".format(stat, stats[stat_set][stat].avg)

        head_str = f"[{stat_set}] | epoch {epoch} | it {it}"
        if max_it:
            head_str += f"/ {max_it}"

        str_out = f"{head_str} | {stat_str}"

        print(str_out)
    def _gather_value(self, val):
        if isinstance(val, float):
            pass
        else:
            val = val.data.cpu().numpy()
            val = float(val.sum())
        return val

    def plot_stats(
        self,
        viz: Visdom = None,
        visdom_env: Optional[str] = None,
        plot_file: Optional[str] = None,
    ) -> None:
        """
        Plot the line charts of the history of the stats.

        Args:
            viz: The Visdom object holding the connection to a Visdom server.
            visdom_env: The visdom environment for storing the graphs.
            plot_file: The path to a file with training plots.
        """

        stat_sets = list(self.stats.keys())

        if viz is None:
            withvisdom = False
        elif not viz.check_connection():
            warnings.warn("Cannot connect to the visdom server! Skipping visdom plots.")
            withvisdom = False
        else:
            withvisdom = True

        lines = []

        for stat in self.log_vars:
            vals = []
            stat_sets_now = []
            for stat_set in stat_sets:
                val = self.stats[stat_set][stat].get_epoch_averages()
                if val is None:
                    continue
                else:
                    val = np.array(val).reshape(-1)
                    stat_sets_now.append(stat_set)
                vals.append(val)

            if len(vals) == 0:
                continue

            vals = np.stack(vals, axis=1)
            x = np.arange(vals.shape[0])

            lines.append((stat_sets_now, stat, x, vals))

        if withvisdom:
            for tmodes, stat, x, vals in lines:
                title = "%s" % stat
                opts = {"title": title, "legend": list(tmodes)}
                for i, (tmode, val) in enumerate(zip(tmodes, vals.T)):
                    update = "append" if i > 0 else None
                    valid = np.where(np.isfinite(val))
                    if len(valid) == 0:
                        continue
                    viz.line(
                        Y=val[valid],
                        X=x[valid],
                        env=visdom_env,
                        opts=opts,
                        win=f"stat_plot_{title}",
                        name=tmode,
                        update=update,
                    )

        if plot_file is None:
            plot_file = self.plot_file

        if plot_file is not None:
            print("Exporting stats to %s" % plot_file)
            ncol = 3
            nrow = int(np.ceil(float(len(lines)) / ncol))
            matplotlib.rcParams.update({"font.size": 5})
            color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
            fig = plt.figure(1)
            plt.clf()
            for idx, (tmodes, stat, x, vals) in enumerate(lines):
                c = next(color)
                plt.subplot(nrow, ncol, idx + 1)
                for vali, vals_ in enumerate(vals.T):
                    c_ = c * (1.0 - float(vali) * 0.3)
                    valid = np.where(np.isfinite(vals_))
                    if len(valid) == 0:
                        continue
                    plt.plot(x[valid], vals_[valid], c=c_, linewidth=1)
                plt.ylabel(stat)
                plt.xlabel("epoch")
                plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
                plt.legend(tmodes)
                gcolor = np.array(mcolors.to_rgba("lightgray"))
                plt.grid(
                    b=True, which="major", color=gcolor, linestyle="-", linewidth=0.4
                )
                plt.grid(
                    b=True, which="minor", color=gcolor, linestyle="--", linewidth=0.2
                )
                plt.minorticks_on()

            plt.tight_layout()
            plt.show()
            fig.savefig(plot_file)

pytorch3D

class RayBundle(NamedTuple)

class RayBundle(NamedTuple):
    """
    RayBundle通过存储射线“origins”,沿投影射线参数化点,
    `directions` vectors and `lengths` at which the ray-points are sampled.
    此外,射线像素的XY位置(‘xys’)也被存储。
    请注意,“directions”不必标准化;
    它们在各自的一维坐标系中定义单位向量; 
    有关转换公式,请参见:func:“Ray_bundle_to_ray_points”的文档。
    """

    origins: torch.Tensor
    directions: torch.Tensor
    lengths: torch.Tensor
    xys: torch.Tensor
举报

相关推荐

0 条评论