Table of Contents
- model: class RadianceFieldRenderer
- stats object
- Learning rate scheduler setup
- Visualization setup
- Datasets
- Training
- Epoch loop
- class Stats: other methods
- PyTorch3D
model: class RadianceFieldRenderer
class RadianceFieldRenderer(torch.nn.Module):
"""
Raybundler->class(orgin, direction, length, xys)
文件实现了 RadianceFieldRenderer类,集成torch.nn.Modulerendering的前向传播过程如下:
1.对于给定的输入 camera,rendering ray 通过“NeRFRaysampler” 产生。
2.在training mode下,射线是一组“射线是一组“n_rays_per_image”的图像网格的随机2D位置。
3.在evaluation mode下,光束对应的是整张图片的网格.光束被进一步分割为“chunk_size_test” 大小的块,以防止内存不足的错误.
渲染前向传递过程如下:
1) 对于给定的输入相机,渲染光线是用
`self._renderer['coarse']` 的 `NeRFRaysampler` 对象。
在训练模式下(`self.training==True`),光线是一组
图像网格的“n_rays_per_image”随机二维位置。
在评估模式下(`self.training==False`),光线对应
到完整的图像网格。光线进一步分裂为
`chunk_size_test`-大小的块,以防止内存不足错误。
2) 对于每个射线点,评估粗略的 `NeuralRadianceField` MLP。
指向此 MLP 的指针存储在 `self._implicit_function['coarse']` 中
3) 粗略的辐射场用`self._renderer['coarse']` 的 `EmissionAbsorptionNeRFRaymarcher` 对象。
4)粗raymarcher输出引导的概率分布用于精细渲染通道的重要性光线采样。这
`ProbabilisticRaysampler` 存储在 `self._renderer['fine'].raysampler` 中
实现重要性光线采样。
5) 类似于 2) `self._implicit_function['fine']` 中的精细 MLP
用占据和颜色标记射线点。
6) self._renderer['fine'].raymarcher` 生成最终的精细渲染。
7) 将精细渲染和粗糙渲染与地面实况输入图像进行比较与 PSNR 和 MSE 指标。
"""
def __init__
def __init__(
self,
image_size: Tuple[int, int],  # height * width
n_pts_per_ray: int,  # number of points sampled along each ray in the coarse rendering pass
n_pts_per_ray_fine: int,  # number of points sampled along each ray in the fine rendering pass
n_rays_per_image: int,  # number of rays per image during training
min_depth: float,  # minimum depth of a sampled ray point for the coarse pass
max_depth: float,  # maximum depth of a sampled ray point for the coarse pass
stratified: bool,  # stratify the depths of the ray points during training
stratified_test: bool,  # stratify the depths of the ray points during testing
chunk_size_test: int,  # number of rays in each chunk of image rays; only used at test time (self.training==False)
n_harmonic_functions_xyz: int = 6,  # number of harmonic functions embedding the 3D coordinates, i.e. the L of the positional encoding
n_harmonic_functions_dir: int = 4,  # number of harmonic functions embedding the directions, i.e. the L of the positional encoding
n_hidden_neurons_xyz: int = 256,  # hidden units in the fully connected MLP layers that take the 3D point locations and output the occupancy field with the intermediate features
n_hidden_neurons_dir: int = 128,  # hidden units in the fully connected MLP layers that take the intermediate features and ray directions and output the radiance field, i.e. the per-point colors (the last layer)
n_layers_xyz: int = 8,  # number of layers of the MLP that outputs the occupancy field
append_xyz: Tuple[int] = (5,),  # indices of the skip layers of the occupancy MLP; before a skip layer is evaluated, the tensor that was input to the MLP is appended to the skip layer's input
density_noise_std: float = 0.0,  # stddev of the random normal noise added to the output of the occupancy MLP; active only when `self.training==True`
visualization: bool = False,  # whether to store extra outputs for visualization
):
"""
Args:
image_size: The size of the rendered image (`[height, width]`).
n_pts_per_ray: The number of points sampled along each ray for the
coarse rendering pass.
n_pts_per_ray_fine: The number of points sampled along each ray for the
fine rendering pass.
n_rays_per_image: Number of Monte Carlo ray samples when training
(`self.training==True`).
min_depth: The minimum depth of a sampled ray-point for the coarse rendering.
max_depth: The maximum depth of a sampled ray-point for the coarse rendering.
stratified: If `True`, stratifies (=randomly offsets) the depths
of each ray point during training (`self.training==True`).
stratified_test: If `True`, stratifies (=randomly offsets) the depths
of each ray point during evaluation (`self.training==False`).
chunk_size_test: The number of rays in each chunk of image rays.
Active only when `self.training==False`.
n_harmonic_functions_xyz: The number of harmonic functions
used to form the harmonic embedding of 3D point locations.
n_harmonic_functions_dir: The number of harmonic functions
used to form the harmonic embedding of the ray directions.
n_hidden_neurons_xyz: The number of hidden units in the
fully connected layers of the MLP that accepts the 3D point
locations and outputs the occupancy field with the intermediate
features.
n_hidden_neurons_dir: The number of hidden units in the
fully connected layers of the MLP that accepts the intermediate
features and ray directions and outputs the radiance field
(per-point colors).
n_layers_xyz: The number of layers of the MLP that outputs the
occupancy field.
append_xyz: The list of indices of the skip layers of the occupancy MLP.
Prior to evaluating the skip layers, the tensor which was input to MLP
is appended to the skip layer input.
density_noise_std: The standard deviation of the random normal noise
added to the output of the occupancy MLP.
Active only when `self.training==True`.
visualization: whether to store extra output for visualization.
"""
super().__init__()
# The renderers and implicit functions are stored under the
# fine/coarse keys of ModuleDict PyTorch modules.
self._renderer = torch.nn.ModuleDict()  ## the renderers; the two keys are "coarse" and "fine"
self._implicit_function = torch.nn.ModuleDict()  ## the implicit network functions, with the same coarse/fine structure as the renderers, but holding NeuralRadianceField modules
# self._implicit_function -> n_harmonic_functions_{xyz,dir}, n_hidden_neurons_{xyz,dir}, n_layers, append_xyz
# Init the EA raymarcher used by both passes.
raymarcher = EmissionAbsorptionNeRFRaymarcher()  ## returns features and weights
# Parse out image dimensions.
image_height, image_width = image_size  # e.g. 800 x 800
for render_pass in ("coarse", "fine"):
if render_pass == "coarse":
# Initialize the coarse raysampler (regular uniform point sampling).
raysampler = NeRFRaysampler(
n_pts_per_ray=n_pts_per_ray,  ## number of points sampled along each ray
min_depth=min_depth,
max_depth=max_depth,
stratified=stratified,  # stratify the depths of the ray points during training
stratified_test=stratified_test,
n_rays_per_image=n_rays_per_image,
image_height=image_height,
image_width=image_width,
)
elif render_pass == "fine":
# Initialize the fine raysampler (importance sampling).
raysampler = ProbabilisticRaysampler(
n_pts_per_ray=n_pts_per_ray_fine,
stratified=stratified,
stratified_test=stratified_test,
)
else:
raise ValueError(f"No such rendering pass {render_pass}")
# Initialize the fine/coarse renderer.
self._renderer[render_pass] = ImplicitRenderer(
raysampler=raysampler,
raymarcher=raymarcher,
)
# Instantiate the fine/coarse NeuralRadianceField module.
self._implicit_function[render_pass] = NeuralRadianceField(
n_harmonic_functions_xyz=n_harmonic_functions_xyz,
n_harmonic_functions_dir=n_harmonic_functions_dir,
n_hidden_neurons_xyz=n_hidden_neurons_xyz,
n_hidden_neurons_dir=n_hidden_neurons_dir,
n_layers_xyz=n_layers_xyz,
append_xyz=append_xyz,
)
self._density_noise_std = density_noise_std
self._chunk_size_test = chunk_size_test
self._image_size = image_size
self.visualization = visualization
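The `n_harmonic_functions_*` arguments are the L of NeRF's positional encoding mentioned in the comments above. A minimal sketch of the harmonic embedding (my own illustration of the standard formula, not the PyTorch3D `HarmonicEmbedding` source):

import torch

def harmonic_embedding(x: torch.Tensor, n_harmonic: int) -> torch.Tensor:
    # Maps each coordinate c to [sin(2^0 c), ..., sin(2^{L-1} c), cos(2^0 c), ..., cos(2^{L-1} c)].
    freqs = 2.0 ** torch.arange(n_harmonic, dtype=x.dtype, device=x.device)
    embed = x[..., None] * freqs  # (..., dim, L)
    return torch.cat((embed.sin(), embed.cos()), dim=-1).reshape(*x.shape[:-1], -1)

With 3D inputs, L harmonic functions, and the raw coordinates appended, the xyz branch consumes 3·2·L + 3 features; the in_features=63 seen in the module dump later in these notes is consistent with L=10.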
def precache_rays
def precache_rays(
self,
cache_cameras: List[CamerasBase],  # list of the N cameras whose rays are precached
cache_camera_hashes: List[str],  # list of unique identifiers, one per camera
):
"""
Precaches the rays emitted from the list of cameras `cache_cameras`,
where each camera is uniquely identified with the corresponding hash
from `cache_camera_hashes`.
The cached rays are moved to cpu and stored in
`self._renderer['coarse']._ray_cache`.
Raises `ValueError` when caching two cameras with the same hash.
Args:
cache_cameras: A list of `N` cameras for which the rays are pre-cached.
cache_camera_hashes: A list of `N` unique identifiers for each
camera from `cameras`.
"""
self._renderer["coarse"].raysampler.precache_rays(  ## precache the rays
cache_cameras,
cache_camera_hashes,
)
def _process_ray_chunk
def _process_ray_chunk(
self,
camera_hash: Optional[str],  # unique identifier of a pre-cached camera
camera: CamerasBase,  # a batch of cameras from which the scene is rendered
image: torch.Tensor,  # ground truth, shape (batch_size, ·, ·, 3)
chunk_idx: int,  # index of the currently rendered ray chunk
) -> dict:
"""
Samples and renders a chunk of rays.
Args:
camera_hash: A unique identifier of a pre-cached camera.
If `None`, the cache is not searched and the sampled rays are
calculated from scratch.
camera: A batch of cameras from which the scene is rendered.
image: A batch of corresponding ground truth images of shape
('batch_size', ·, ·, 3).
chunk_idx: The index of the currently rendered ray chunk.
Returns:
out: `dict` containing the outputs of the rendering:
`rgb_coarse`: The result of the coarse rendering pass.
`rgb_fine`: The result of the fine rendering pass.
`rgb_gt`: The corresponding ground-truth RGB values.
"""
# Initialize the outputs of the coarse rendering to None.
coarse_ray_bundle = None  ## origins, directions, depths, 2D xy coordinates
coarse_weights = None
# First evaluate the coarse rendering pass, then the fine one.
for renderer_pass in ("coarse", "fine"):
(rgb, weights), ray_bundle_out = self._renderer[renderer_pass](
cameras=camera,
volumetric_function=self._implicit_function[renderer_pass],
chunksize=self._chunk_size_test,
chunk_idx=chunk_idx,
density_noise_std=(self._density_noise_std if self.training else 0.0),
input_ray_bundle=coarse_ray_bundle,
ray_weights=coarse_weights,
camera_hash=camera_hash,
)
if renderer_pass == "coarse":
rgb_coarse = rgb
# Store the weights and the rays of the first rendering pass
# for the ensuing importance ray-sampling of the fine render.
coarse_ray_bundle = ray_bundle_out  ## keep the coarse-pass rays for the fine pass
coarse_weights = weights
if image is not None:
# Sample the ground truth images at the xy locations of the
# rendering ray pixels.
rgb_gt = sample_images_at_mc_locs(
image[..., :3][None],
ray_bundle_out.xys,
)
else:
rgb_gt = None
elif renderer_pass == "fine":
rgb_fine = rgb
else:
raise ValueError(f"No such rendering pass {renderer_pass}")
out = {"rgb_fine": rgb_fine, "rgb_coarse": rgb_coarse, "rgb_gt": rgb_gt}
if self.visualization:
# Store the coarse rays/weights only for visualization purposes.
out["coarse_ray_bundle"] = type(coarse_ray_bundle)(
*[v.detach().cpu() for k, v in coarse_ray_bundle._asdict().items()]
)
out["coarse_weights"] = coarse_weights.detach().cpu()
## .detach() cuts the stored tensors out of the autograd graph
return out
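`sample_images_at_mc_locs`, used above, gathers ground-truth pixel values at the sampled ray locations; conceptually it is a `grid_sample` lookup at the NDC xy coordinates. A sketch of that idea (assuming PyTorch3D's NDC convention, whose xy axes are sign-flipped relative to `grid_sample`), not the library's exact source:

import torch
import torch.nn.functional as F

def sample_images_at_mc_locs_sketch(images: torch.Tensor, xys: torch.Tensor) -> torch.Tensor:
    # images: (batch, H, W, 3); xys: (batch, ..., 2) in PyTorch3D NDC coordinates.
    ba, dim = images.shape[0], images.shape[-1]
    spatial = xys.shape[1:-1]
    sampled = F.grid_sample(
        images.permute(0, 3, 1, 2),   # -> (batch, 3, H, W)
        -xys.view(ba, -1, 1, 2),      # sign flip: PyTorch3D NDC vs. grid_sample conventions
        align_corners=True,
    )
    return sampled.permute(0, 2, 3, 1).view(ba, *spatial, dim)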
def forward
def forward(
self,
camera_hash: Optional[str],
camera: CamerasBase,
image: torch.Tensor,
) -> Tuple[dict, dict]:
"""
从输入的“camera”的角度执行辐射场的粗略和精细的渲染传递。
之后,两种渲染方式都与输入的地面真实“图像”相比较。
通过评估峰值信噪比和均方误差.
The rendering result depends on the `self.training` flag:
- In the training mode (`self.training==True`),
该函数renders图像射线的随机子集(MonteCarlo rendering)。
- In evaluation mode (`self.training==False`),
该函数renders整个图像。为了防止内存不足的错误,
当 `self.training==False` 时,对光线进行采样和渲染
批量大小为“chunksize”。
参数:
camera_hash:预缓存相机的唯一标识符。
如果为“None”,则不搜索缓存并且采样光线
从头计算。
camera:渲染场景的一批相机。
image: 一批对应形状为
('batch_size', ·, ·, 3)的ground truth图像
Returns:
out: `dict` 包含渲染的输出:
`rgb_coarse`:粗略渲染过程的结果。
`rgb_fine`:精细渲染通道的结果。
`rgb_gt`:对应的ground-truth RGB值。
The shape of `rgb_coarse`, `rgb_fine`, `rgb_gt` depends on the
`self.training` flag:
If `==True`, all 3 tensors are of shape
`(batch_size, n_rays_per_image, 3)` and contain the result
of the Monte Carlo training rendering pass.
If `==False`, all 3 tensors are of shape
`(batch_size, image_size[0], image_size[1], 3)` and contain
the result of the full image rendering pass.
metrics: `dict` 包含比较细和粗渲染到实际情况的误差度量:
`mse_coarse`:粗略渲染和渲染之间的均方误差
输入`图像`
`mse_fine`:精细渲染和渲染之间的均方误差
输入`图像`
`psnr_coarse`:粗略渲染和粗糙渲染之间的峰值信噪比
输入`图像`
`psnr_fine`:精细渲染和精细渲染之间的峰值信噪比
输入`图像`
"""
if not self.training:
# Full evaluation pass.
n_chunks = self._renderer["coarse"].raysampler.get_n_chunks(  # get the number of ray chunks
self._chunk_size_test,
camera.R.shape[0],
)
else:
# During training, the Monte Carlo subset of rays fits into a single chunk.
n_chunks = 1
# Process the chunks of rays; each chunk output is a dict holding
# `rgb_coarse`, `rgb_fine`, and `rgb_gt` (see `_process_ray_chunk` above).
chunk_outputs = [
self._process_ray_chunk(
camera_hash,
camera,
image,
chunk_idx,
)
for chunk_idx in range(n_chunks)
]
if not self.training:
# For a full render pass concatenate the output chunks,
# and reshape to image size.
out = {
k: torch.cat(
[ch_o[k] for ch_o in chunk_outputs],
dim=1,
).view(-1, *self._image_size, 3)
if chunk_outputs[0][k] is not None
else None
for k in ("rgb_fine", "rgb_coarse", "rgb_gt")
}
else:
out = chunk_outputs[0]
# Calc the error metrics.
metrics = {}
if image is not None:
for render_pass in ("coarse", "fine"):
for metric_name, metric_fun in zip(
("mse", "psnr"), (calc_mse, calc_psnr)
):
metrics[f"{metric_name}_{render_pass}"] = metric_fun(
out["rgb_" + render_pass][..., :3],
out["rgb_gt"][..., :3],
)
return out, metrics
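`calc_mse` and `calc_psnr` are the usual image metrics; a minimal sketch consistent with how NeRF reports them (pixel values in [0, 1], so PSNR = -10·log10(MSE)), not necessarily the project's exact utility code:

import torch

def calc_mse(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Mean squared error over all pixels and channels.
    return torch.mean((x - y) ** 2)

def calc_psnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Peak signal-to-noise ratio, assuming a peak signal of 1.0.
    return -10.0 * torch.log10(calc_mse(x, y))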
stats object
# Init the stats object.
if stats is None:
stats = Stats(
["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
)
class Stats:
def __init__
class Stats:
"""
Stats logging object useful for gathering statistics of training a deep network in PyTorch.
Example:
```
# Init a stats structure that logs statistics 'objective' and 'top1e'.
stats = Stats( ('objective','top1e') )
network = init_net()  # init a pytorch module (=neural network)
dataloader = init_dataloader()  # init a dataloader
for epoch in range(10):
# start of epoch -> call new_epoch
stats.new_epoch()
# Iterate over batches.
for batch in dataloader:
# Run the model and save its output in the dict "output".
output = network(batch)
# stats.update() automatically parses the 'objective' and 'top1e'
# from the "output" dict and stores them in the database.
stats.update(output)
stats.print()  # prints the averages over the given epoch
# Stores the training plots into '/tmp/epoch_stats.pdf'
# and plots into a visdom server running at localhost (if running).
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
```
"""
def __init__(
self,
log_vars: List[str],
verbose: bool = False,
epoch: int = -1,
plot_file: Optional[str] = None,
) -> None:
"""
Args:
log_vars: The list of variable names to be logged.
verbose: Print status messages.
epoch: The initial epoch of the object.
plot_file: The path to the file that will hold the training plots.
"""
self.verbose = verbose
self.log_vars = log_vars
self.plot_file = plot_file
self.hard_reset(epoch=epoch)
def hard_reset(self, epoch: int = -1) -> None
self.hard_reset(epoch=epoch)
def hard_reset(self, epoch: int = -1) -> None:
"""
Deletes all the logged data.
"""
self._epoch_start = None
self.epoch = epoch
if self.verbose:
print("stats: epoch %d - hard reset" % self.epoch)
self.stats = {}
self.reset()
def reset(self) -> None
self.reset()
def reset(self) -> None:
"""
Called before an epoch to clear the current epoch buffers.
"""
stat_sets = list(self.stats.keys())
if self.verbose:
print("stats: epoch %d - reset" % self.epoch)
self.it = {k: -1 for k in stat_sets}
for stat_set in stat_sets:
for stat in self.stats[stat_set]:
self.stats[stat_set][stat].reset()
# Set a new timestamp.
self._epoch_start = time.time()
Learning rate scheduler setup
learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
# The learning rate scheduler setup.
# Following the original code, we use exponential decay of the
# learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
def lr_lambda(epoch):  ## defines the exponential learning-rate decay
return cfg.optimizer.lr_scheduler_gamma ** (
epoch / cfg.optimizer.lr_scheduler_step_size
)
# The learning rate scheduling is implemented with the LambdaLR PyTorch scheduler;
# it applies the exponential decay defined by lr_lambda to the base learning rate.
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
)
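A quick numeric check of the schedule; the base_lr/gamma/step_size values below are illustrative placeholders, not taken from this document:

base_lr, gamma, step_size = 5e-4, 0.1, 5000  # hypothetical config values

def current_lr(epoch: int) -> float:
    return base_lr * gamma ** (epoch / step_size)

for epoch in (0, 1000, 5000, 10000):
    print(epoch, current_lr(epoch))
# 0 -> 5.0e-04, 1000 -> ~3.2e-04, 5000 -> 5.0e-05, 10000 -> 5.0e-06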
Visualization setup
# Init the cache for storing variables needed for visualization.
visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)
# Init the visualization visdom env.
if cfg.visualization.visdom:
viz = Visdom(
server=cfg.visualization.visdom_server,
port=cfg.visualization.visdom_port,
use_incoming_socket=False,
)
else:
viz = None
Datasets
1. Load the training/validation data.
# Load the training/validation data.
train_dataset, val_dataset, _ = get_nerf_datasets(
dataset_name=cfg.data.dataset_name,
image_size=cfg.data.image_size,  # loads the train and val splits at the configured image size
)
def get_nerf_datasets
## returns datasets whose entries hold an image, the camera parameters, and the camera index
def get_nerf_datasets(
dataset_name: str,  # 'lego | fern', the specific scene
image_size: Tuple[int, int],  # size of the images (height, width)
data_root: str = DEFAULT_DATA_ROOT,  # root folder where the data is stored
autodownload: bool = True,  # download the data automatically if it is missing
) -> Tuple[Dataset, Dataset, Dataset]:  ## returns the training, validation, and test sets
"""
Obtains the training and validation DataSet objects for the dataset specified with the `dataset_name` argument.
Args:
dataset_name: The name of the dataset to load.
image_size: A tuple (height, width) denoting the size of the loaded dataset images.
data_root: The root folder at which the data is stored.
autodownload: Auto-download the dataset files in case they are missing.
Returns:
train_dataset: The training dataset object.
val_dataset: The validation dataset object.
test_dataset: The testing dataset object.
"""
if dataset_name not in ALL_DATASETS:
raise ValueError(f"'{dataset_name}'' does not refer to a known dataset.")
print(f"Loading dataset {dataset_name}, image size={str(image_size)} ...")
cameras_path = os.path.join(data_root, dataset_name + ".pth")  # cameras_path
image_path = cameras_path.replace(".pth", ".png")  # image_path
if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
# Automatically download the data files if missing.
download_data((dataset_name,), data_root=data_root)
train_data = torch.load(cameras_path)  ## load the training data
n_cameras = train_data["cameras"]["R"].shape[0]  ## number of cameras, here 400
_image_max_image_pixels = Image.MAX_IMAGE_PIXELS  # the pixel-count limit preset by the Python imaging library
Image.MAX_IMAGE_PIXELS = None  # The dataset image is very large ...
images = torch.FloatTensor(np.array(Image.open(image_path))) / 255.0
## gradients only flow through floating-point tensors; the 8-bit source image is normalized to [0, 1] (torch.FloatTensor)
images = torch.stack(torch.chunk(images, n_cameras, dim=0))[..., :3]  ## torch.chunk splits the tall image into per-camera chunks
## each image is [H, W, 3] (here H = W = 400)
Image.MAX_IMAGE_PIXELS = _image_max_image_pixels
# Determine the image scale factors and rescale the images accordingly.
scale_factors = [s_new / s for s, s_new in zip(images.shape[1:3], image_size)]
if abs(scale_factors[0] - scale_factors[1]) > 1e-3:
raise ValueError(
"Non-isotropic scaling is not allowed. Consider changing the 'image_size' argument."
)
scale_factor = sum(scale_factors) * 0.5
if scale_factor != 1.0:
print(f"Rescaling dataset (factor={scale_factor})")
images = torch.nn.functional.interpolate(
images.permute(0, 3, 1, 2),
size=tuple(image_size),
mode="bilinear",
).permute(0, 2, 3, 1)  ## rescale via bilinear interpolation
# Each camera is an instance of pytorch3d.renderer.cameras.PerspectiveCameras;
# the fields used here are the rotation matrix R, the translation T, the focal length, and the principal point.
cameras = [
PerspectiveCameras(
**{k: v[cami][None] for k, v in train_data["cameras"].items()}
).to("cpu") ##提取train_data信息到cameras中
for cami in range(n_cameras)##共建立n_cameras个相机实例
]
## split train_data into train, val, and test
train_idx, val_idx, test_idx = train_data["split"]
train_dataset, val_dataset, test_dataset = [
ListDataset(
[
{"image": images[i], "camera": cameras[i], "camera_idx": int(i)}
for i in idx
]
)
for idx in [train_idx, val_idx, test_idx]
]
return train_dataset, val_dataset, test_dataset
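A typical call, with hypothetical arguments:

train_dataset, val_dataset, test_dataset = get_nerf_datasets(
    dataset_name="lego",
    image_size=(400, 400),  # (height, width)
)
entry = train_dataset[0]  # {"image": ..., "camera": ..., "camera_idx": ...}
print(len(train_dataset), entry["image"].shape, entry["camera_idx"])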
class ListDataset(Dataset)
class ListDataset(Dataset):
"""
A simple dataset made of a list of entries.
"""
def __init__(self, entries: List) -> None:
"""
Args:
entries: The list of dataset entries.
"""
self._entries = entries
def __len__(
self,
) -> int:
return len(self._entries)
def __getitem__(self, index):
return self._entries[index]  ## directly return the list entry at the given index
debug
Debugger inspection (screenshots omitted): `train_data`, `images = torch.stack(torch.chunk(images, n_cameras, dim=0))[..., :3]`, and `scale_factors`.
2. Precache the projection rays
if cfg.data.precache_rays:
# Precache the projection rays.
model.eval()  # disables batchnorm and dropout
with torch.no_grad():
for dataset in (train_dataset, val_dataset):
cache_cameras = [e["camera"].to(device) for e in dataset]
cache_camera_hashes = [e["camera_idx"] for e in dataset]
# pull the camera and camera_idx out of every dataset entry for caching
model.precache_rays(cache_cameras, cache_camera_hashes)
model.precache_rays
model.precache_rays(cache_cameras, cache_camera_hashes)
def precache_rays(
self,
cache_cameras: List[CamerasBase],  # list of the N cameras whose rays are precached
cache_camera_hashes: List[str],  # list of unique identifiers, one per camera
):
"""
Precaches the rays emitted from the list of cameras `cache_cameras`,
where each camera is uniquely identified with the corresponding hash
from `cache_camera_hashes`.
The cached rays are moved to cpu and stored in
`self._renderer['coarse']._ray_cache`.
Raises `ValueError` when caching two cameras with the same hash.
Args:
cache_cameras: A list of `N` cameras for which the rays are pre-cached.
cache_camera_hashes: A list of `N` unique identifiers for each camera.
"""
self._renderer["coarse"].raysampler.precache_rays(  ## precache the rays
NeRFRaysampler
class NeRFRaysampler(torch.nn.Module)
def precache_rays
self._renderer["coarse"].raysampler.precache_rays
def precache_rays(self, cameras: List[CamerasBase], camera_hashes: List):
"""
预缓存从相机“cameras”列表中发出的光线,其中每个相机都用相应的哈希唯一标识来自`camera_hashes`。
缓存的光线被移动到 cpu 并存储在 `self._ray_cache` 中,缓存具有相同哈希的两个相机时引发 `ValueError`。
Args::
cameras:预先缓存了光线的“N”个相机的列表。
camera_hashes:每个的“N”个唯一标识符的列表,来自“cameras”的相机。
"""
print(f"Precaching {len(cameras)} ray bundles ...")
full_chunksize = ( #4096000
self._grid_raysampler._xy_grid.numel()
// 2
* self._grid_raysampler._n_pts_per_ray
)##每个chunk最多预缓存的光线数
if self.get_n_chunks(full_chunksize, 1) != 1:  ## precaching must fit into a single chunk
raise ValueError("There has to be one chunk for precaching rays!")
for camera_i, (camera, camera_hash) in enumerate(zip(cameras, camera_hashes)):
ray_bundle = self.forward(
camera,
caching=True,
chunksize=full_chunksize,
)
if camera_hash in self._ray_cache:
raise ValueError("There are redundant cameras!")
self._ray_cache[camera_hash] = RayBundle(  ## store the precached rays on the cpu
*[v.to("cpu").detach() for v in ray_bundle]
)
self._print_precaching_progress(camera_i, len(cameras))
print("")
debug
Debugger inspection (screenshots omitted): `self` and the loop over `zip(cameras, camera_hashes)`.
def get_n_chunks
self.get_n_chunks(full_chunksize, 1)
def get_n_chunks(self, chunksize: int, batch_size: int):
"""
返回“chunksize”大小的光线采样器的光线块的总数。
参数:
chunksize:每个块的光线数量。
batch_size:光线采样器的批次大小。
return:
n_chunks:块的总数。
"""
return int(
math.ceil(#向上取整 numel->返回元素个数
(self._grid_raysampler._xy_grid.numel() * 0.5 * batch_size) / chunksize
)
)
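`_xy_grid` holds an (H, W, 2) grid of xy coordinates, so `numel() * 0.5` is the pixel count. A worked example with illustrative sizes (H, W, and chunksize are assumptions, not config values from this document):

import math

H, W = 800, 800
n_pixels = H * W * 2 * 0.5            # _xy_grid.numel() * 0.5 = 640000 rays
chunksize, batch_size = 6000, 1
print(int(math.ceil(n_pixels * batch_size / chunksize)))  # 107 chunks for a full render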
def forward
ray_bundle = self.forward(
camera,
caching=True,
chunksize=full_chunksize,
)
def forward(
self,
cameras: CamerasBase,
chunksize: int = None,
chunk_idx: int = 0,
camera_hash: str = None,
caching: bool = False,
**kwargs,
) -> RayBundle:
"""
Args:
cameras: 发出射线的一批“batch_size”照相机。
chunksize: 每块射线的数量。只在下列情况下激活 `self.training==False`.
chunk_idx: 射线块的索引。号码必须在
`[0, self.get_n_chunks(chunksize, batch_size)-1]`.
只在下列情况下激活 `self.training==False`.
camera_hash: 预缓存相机的唯一标识符。 If `None`,
缓存不被搜索,光线是从头开始计算的。
caching: If `True`, 激活缓存模式,该模式返回应该存储到缓存中的“RayBundle”。
Returns:
一个名为“RayBundle”的元组,具有以下字段:
origins: A tensor of shape
`(batch_size, n_rays_per_image, 3)`
在世界坐标中表示射线原点的位置。
directions: A tensor of shape
`(batch_size, n_rays_per_image, 3)`
表示世界坐标中每一条射线的方向。
lengths: A tensor of shape
`(batch_size, n_rays_per_image, n_pts_per_ray)`
包含世界单位中每条射线的z坐标(=depth).
xys: A tensor of shape
`(batch_size, n_rays_per_image, 2)`
包含每个射线的2D图像坐标。
"""
batch_size = cameras.R.shape[0] # pyre-ignore 1
device = cameras.device
if (camera_hash is None) and (not caching) and self.training:
# Sample random rays from scratch.
ray_bundle = self._mc_raysampler(cameras)
ray_bundle = self._normalize_raybundle(ray_bundle)
else:
if camera_hash is not None:
# The case where we retrieve the rays of a pre-cached camera.
if batch_size != 1:
raise NotImplementedError(
"Ray caching works only for batches with a single camera!"
)
full_ray_bundle = self._ray_cache[camera_hash]  # fetch the rays from the ray cache
else:
# We generate a full ray grid from scratch.
full_ray_bundle = self._grid_raysampler(cameras)
full_ray_bundle = self._normalize_raybundle(full_ray_bundle)
n_pixels = full_ray_bundle.directions.shape[:-1].numel() #640000
if self.training:
# During training we randomly subsample rays.
sel_rays = torch.randperm(n_pixels, device=device)[
: self._mc_raysampler._n_rays_per_image
]
else:
# In case we test, we take only the requested chunk of rays.
if chunksize is None:
chunksize = n_pixels * batch_size
start = chunk_idx * chunksize * batch_size #0
end = min(start + chunksize, n_pixels) #640000
sel_rays = torch.arange(
start,
end,
dtype=torch.long,
device=full_ray_bundle.lengths.device,
)
# Take the "sel_rays" out of the full ray bundle.
ray_bundle = RayBundle(
*[
v.view(n_pixels, -1)[sel_rays]
.view(batch_size, sel_rays.numel() // batch_size, -1)
.to(device)
for v in full_ray_bundle
]
)
if (
(self._stratified and self.training)
or (self._stratified_test and not self.training)
) and not caching:  # make sure not to stratify when caching!
ray_bundle = self._stratify_ray_bundle(ray_bundle)
return ray_bundle
debug
self:
NeRFRaysampler(
(_grid_raysampler): NDCMultinomialRaysampler()
(_mc_raysampler): MonteCarloRaysampler()
)
full_ray_bundle = self._grid_raysampler(cameras)
def _normalize_raybundle
ray_bundle = self._normalize_raybundle(ray_bundle)
def _normalize_raybundle(self, ray_bundle: RayBundle):
"""
Normalizes the ray directions of the input `RayBundle` to unit norm.
"""
ray_bundle = ray_bundle._replace(
directions=torch.nn.functional.normalize(ray_bundle.directions, dim=-1)
)
return ray_bundle
def _print_precaching_progress
self._print_precaching_progress(camera_i, len(cameras))
def _print_precaching_progress(self, i, total, bar_len=30):
"""
Prints a progress bar for the ray precaching.
"""
position = round((i + 1) / total * bar_len)
pbar = "[" + "█" * position + " " * (bar_len - position) + "]"
print(pbar, end="\r")
3. train_dataloader and val_dataloader
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=1,
shuffle=True,
num_workers=0,
collate_fn=trivial_collate,
)
# The validation dataloader is just an endless stream of random samples.
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
batch_size=1,
num_workers=0,
collate_fn=trivial_collate,
sampler=torch.utils.data.RandomSampler(
val_dataset,
replacement=True,
num_samples=cfg.optimizer.max_epochs,
),
)
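`trivial_collate` presumably just passes the list of dataset entries through uncollated, which is why the training loop below indexes `batch[0]`; a sketch consistent with that usage, not necessarily the project's exact code:

def trivial_collate(batch):
    # Return the uncollated list of entry dicts; with batch_size=1,
    # batch[0] is the {"image", "camera", "camera_idx"} dict.
    return batch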
Training
# What follows is the standard training loop.
# Set the model to the training mode.
model.train()
Epoch loop
# Run the main training loop.
for epoch in range(start_epoch, cfg.optimizer.max_epochs):
stats.new_epoch()  # Init a new epoch.
Batch loop
1. Fetch the data
for iteration, batch in enumerate(train_dataloader):
image, camera, camera_idx = batch[0].values()  # unpack the single entry of the current batch
image = image.to(device)
camera = camera.to(device)
optimizer.zero_grad()
2. Run the model
This executes the forward pass of the RadianceFieldRenderer model
(the code is shown above).
# Run the forward pass of the model.
nerf_out, metrics = model(
camera_idx if cfg.data.precache_rays else None,
camera,
image,
)
model (self):
RadianceFieldRenderer(
(_renderer): ModuleDict(
(coarse): ImplicitRenderer(
(raysampler): NeRFRaysampler(
(_grid_raysampler): NDCMultinomialRaysampler()
(_mc_raysampler): MonteCarloRaysampler()
)
(raymarcher): EmissionAbsorptionNeRFRaymarcher()
)
(fine): ImplicitRenderer(
(raysampler): ProbabilisticRaysampler()
(raymarcher): EmissionAbsorptionNeRFRaymarcher()
)
)
(_implicit_function): ModuleDict(
(coarse): NeuralRadianceField(
(harmonic_embedding_xyz): HarmonicEmbedding()
(harmonic_embedding_dir): HarmonicEmbedding()
(mlp_xyz): MLPWithInputSkips(
(mlp): ModuleList(
(0): Sequential(
(0): Linear(in_features=63, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(1): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(2): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(3): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(4): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(5): Sequential(
(0): Linear(in_features=319, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(6): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(7): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
)
)
(intermediate_linear): Linear(in_features=256, out_features=256, bias=True)
(density_layer): Linear(in_features=256, out_features=1, bias=True)
(color_layer): Sequential(
(0): LinearWithRepeat()
(1): ReLU(inplace=True)
(2): Linear(in_features=128, out_features=3, bias=True)
(3): Sigmoid()
)
)
(fine): NeuralRadianceField(
(harmonic_embedding_xyz): HarmonicEmbedding()
(harmonic_embedding_dir): HarmonicEmbedding()
(mlp_xyz): MLPWithInputSkips(
(mlp): ModuleList(
(0): Sequential(
(0): Linear(in_features=63, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(1): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(2): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(3): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(4): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(5): Sequential(
(0): Linear(in_features=319, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(6): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
(7): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): ReLU(inplace=True)
)
)
)
(intermediate_linear): Linear(in_features=256, out_features=256, bias=True)
(density_layer): Linear(in_features=256, out_features=1, bias=True)
(color_layer): Sequential(
(0): LinearWithRepeat()
(1): ReLU(inplace=True)
(2): Linear(in_features=128, out_features=3, bias=True)
(3): Sigmoid()
)
)
)
)
Debugger inspection (screenshots omitted): `chunk_outputs` and the `metrics` dict for the coarse and fine passes.
3. Loss
# The loss is the sum of the coarse and fine MSE metrics.
loss = metrics["mse_coarse"] + metrics["mse_fine"]
# Take the training step.
loss.backward()
optimizer.step()
loss:
tensor(0.2945, device='cuda:0', grad_fn=<AddBackward0>)
4. Stats
# Update stats with the current metrics.
stats.update(
{"loss": float(loss), **metrics},
stat_set="train",
)
if iteration % cfg.stats_print_interval == 0:
stats.print(stat_set="train")  ## print once every cfg.stats_print_interval iterations (10 here)
stats
{'train':
{'loss': <nerf.stats.AverageMeter object at 0x7f394e0f2970>,
'mse_coarse': <nerf.stats.AverageMeter object at 0x7f394e104b50>,
'mse_fine': <nerf.stats.AverageMeter object at 0x7f394e0f2850>,
'psnr_coarse': <nerf.stats.AverageMeter object at 0x7f394e0f2070>,
'psnr_fine': <nerf.stats.AverageMeter object at 0x7f394e114730>,
'sec/it': <nerf.stats.AverageMeter object at 0x7f394e129310>}
}
5. Visualization update
# Update the visualization cache.
if viz is not None:  ## update the visualization cache
visuals_cache.append(
{
"camera": camera.cpu(),
"camera_idx": camera_idx,
"image": image.cpu().detach(), ##detach不需要梯度
"rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
"rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
"rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
"coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
}
)
Adjust the learning rate.
# Adjust the learning rate.
lr_scheduler.step()
Validation
# Validation
if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
# Sample a validation camera/image.
val_batch = next(val_dataloader.__iter__())
val_image, val_camera, camera_idx = val_batch[0].values()
val_image = val_image.to(device)
val_camera = val_camera.to(device)
# Activate eval mode of the model (lets us do a full rendering pass).
model.eval()
with torch.no_grad():
val_nerf_out, val_metrics = model(
camera_idx if cfg.data.precache_rays else None,
val_camera,
val_image,
)
# Update stats with the validation metrics.  ## update the logged metrics
stats.update(val_metrics, stat_set="val")
stats.print(stat_set="val")
if viz is not None:
# Plot that loss curves into visdom.
stats.plot_stats(
viz=viz,
visdom_env=cfg.visualization.visdom_env,
plot_file=None,
)
# Visualize the intermediate results.
visualize_nerf_outputs(
val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
)
# Set the model back to train mode.
model.train()
Checkpoint
# Checkpoint.
if (
epoch % cfg.checkpoint_epoch_interval == 0
and len(cfg.checkpoint_path) > 0
and epoch > 0
):
print(f"Storing checkpoint {checkpoint_path}.")
data_to_store = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"stats": pickle.dumps(stats),
}
torch.save(data_to_store, checkpoint_path)  ## save the model/optimizer state and stats to the checkpoint
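Restoring from this checkpoint would mirror the stored keys; a sketch of hypothetical resume logic (not the project's exact code):

import pickle
import torch

loaded = torch.load(checkpoint_path)
model.load_state_dict(loaded["model"])
optimizer.load_state_dict(loaded["optimizer"])
stats = pickle.loads(loaded["stats"])
start_epoch = stats.epoch + 1  # continue after the last finished epoch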
class Stats: other methods
def new_epoch
def new_epoch(self) -> None:
"""
Initializes a new epoch.
"""
if self.verbose:
print("stats: new epoch %d" % (self.epoch + 1))
self.epoch += 1 # increase epoch counter
self.reset() # zero the stats
def update
# Update stats with the current metrics.
stats.update(
{"loss": float(loss), **metrics},
stat_set="train",
)
def update(self, preds: dict, stat_set: str = "train") -> None:
"""
Update the internal logs with metrics of a training step.
Each metric is stored as an instance of AverageMeter.
Args:
preds: Dict of values to be added to the logs.
stat_set: The set of statistics to be updated (e.g. "train", "val").
"""
if self.epoch == -1: # uninitialized
warnings.warn(
"self.epoch==-1 means uninitialized stats structure"
" -> new_epoch() called"
)
self.new_epoch()
if stat_set not in self.stats:
self.stats[stat_set] = {}
self.it[stat_set] = -1
self.it[stat_set] += 1
epoch = self.epoch
it = self.it[stat_set]
for stat in self.log_vars:
if stat not in self.stats[stat_set]:
self.stats[stat_set][stat] = AverageMeter()
if stat == "sec/it": # compute speed
elapsed = time.time() - self._epoch_start
time_per_it = float(elapsed) / float(it + 1)
val = time_per_it
else:
if stat in preds:
val = self._gather_value(preds[stat])
else:
val = None
if val is not None:
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
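`AverageMeter` (seen in the stats dump above) is a small running-average helper. A minimal sketch consistent with the calls used in these notes (`update(val, epoch=..., n=...)`, `.avg`, `.count`, `get_epoch_averages()`, `reset()`); this is my own illustration, not the project's exact class:

import numpy as np

class AverageMeter:
    def __init__(self):
        self.history = []  # per-epoch lists of logged values
        self.reset()

    def reset(self):
        self.val, self.avg, self.sum, self.count = 0.0, 0.0, 0.0, 0

    def update(self, val, epoch=0, n=1):
        while len(self.history) <= epoch:  # grow the per-epoch storage
            self.history.append([])
        self.history[epoch].append(val)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def get_epoch_averages(self):
        if len(self.history) == 0:
            return None  # no data logged yet
        return [float(np.mean(h)) if len(h) > 0 else float("nan") for h in self.history]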
def print(self, max_it: Optional[int] = None, stat_set: str = "train") -> None:
"""
Print the current values of all stored stats.
Args:
max_it: Maximum iteration number to be displayed.
If None, the maximum iteration number is not displayed.
stat_set: The set of statistics to be printed.
"""
epoch = self.epoch
stats = self.stats
str_out = ""
it = self.it[stat_set]
stat_str = ""
stats_print = sorted(stats[stat_set].keys())
for stat in stats_print:
if stats[stat_set][stat].count == 0:
continue
stat_str += " {0:.12}: {1:1.3f} |".format(stat, stats[stat_set][stat].avg)
head_str = f"[{stat_set}] | epoch {epoch} | it {it}"
if max_it:
head_str += f"/ {max_it}"
str_out = f"{head_str} | {stat_str}"
print(str_out)
def _gather_value(self, val):
if isinstance(val, float):
pass
else:
val = val.data.cpu().numpy()
val = float(val.sum())
return val
def plot_stats(
self,
viz: Visdom = None,
visdom_env: Optional[str] = None,
plot_file: Optional[str] = None,
) -> None:
"""
Plot the line charts of the history of the stats.
Args:
viz: The Visdom object holding the connection to a Visdom server.
visdom_env: The visdom environment for storing the graphs.
plot_file: The path to a file with training plots.
"""
stat_sets = list(self.stats.keys())
if viz is None:
withvisdom = False
elif not viz.check_connection():
warnings.warn("Cannot connect to the visdom server! Skipping visdom plots.")
withvisdom = False
else:
withvisdom = True
lines = []
for stat in self.log_vars:
vals = []
stat_sets_now = []
for stat_set in stat_sets:
val = self.stats[stat_set][stat].get_epoch_averages()
if val is None:
continue
else:
val = np.array(val).reshape(-1)
stat_sets_now.append(stat_set)
vals.append(val)
if len(vals) == 0:
continue
vals = np.stack(vals, axis=1)
x = np.arange(vals.shape[0])
lines.append((stat_sets_now, stat, x, vals))
if withvisdom:
for tmodes, stat, x, vals in lines:
title = "%s" % stat
opts = {"title": title, "legend": list(tmodes)}
for i, (tmode, val) in enumerate(zip(tmodes, vals.T)):
update = "append" if i > 0 else None
valid = np.where(np.isfinite(val))[0]
if len(valid) == 0:
continue
viz.line(
Y=val[valid],
X=x[valid],
env=visdom_env,
opts=opts,
win=f"stat_plot_{title}",
name=tmode,
update=update,
)
if plot_file is None:
plot_file = self.plot_file
if plot_file is not None:
print("Exporting stats to %s" % plot_file)
ncol = 3
nrow = int(np.ceil(float(len(lines)) / ncol))
matplotlib.rcParams.update({"font.size": 5})
color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
fig = plt.figure(1)
plt.clf()
for idx, (tmodes, stat, x, vals) in enumerate(lines):
c = next(color)
plt.subplot(nrow, ncol, idx + 1)
for vali, vals_ in enumerate(vals.T):
c_ = c * (1.0 - float(vali) * 0.3)
valid = np.where(np.isfinite(vals_))[0]
if len(valid) == 0:
continue
plt.plot(x[valid], vals_[valid], c=c_, linewidth=1)
plt.ylabel(stat)
plt.xlabel("epoch")
plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
plt.legend(tmodes)
gcolor = np.array(mcolors.to_rgba("lightgray"))
plt.grid(
b=True, which="major", color=gcolor, linestyle="-", linewidth=0.4
)
plt.grid(
b=True, which="minor", color=gcolor, linestyle="--", linewidth=0.2
)
plt.minorticks_on()
plt.tight_layout()
plt.show()
fig.savefig(plot_file)
PyTorch3D
class RayBundle(NamedTuple)
class RayBundle(NamedTuple):
"""
RayBundle parametrizes points along projection rays by storing the ray
`origins`, the `directions` vectors, and the `lengths` at which the ray points
are sampled. Furthermore, the xy locations of the ray pixels (`xys`) are stored as well.
Note that `directions` don't have to be normalized; they define unit vectors
in the respective 1D coordinate systems; see the documentation of
:func:`ray_bundle_to_ray_points` for the conversion formula.
"""
origins: torch.Tensor
directions: torch.Tensor
lengths: torch.Tensor
xys: torch.Tensor
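The conversion the docstring refers to (`ray_bundle_to_ray_points`) is the standard p = o + t·d expansion along each ray; a sketch of that formula (my own illustration, not the library source):

def ray_bundle_to_ray_points_sketch(rays):
    # rays: a RayBundle with origins (..., 3), directions (..., 3),
    # lengths (..., n_pts_per_ray); returns points of shape (..., n_pts_per_ray, 3),
    # i.e. p = o + t * d for every sampled depth t.
    return (
        rays.origins[..., None, :]
        + rays.lengths[..., :, None] * rays.directions[..., None, :]
    )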