在做图像处理时,有时需要分块调整图像亮度或者颜色,为了保持空间的平滑性,在块间会做平滑过渡,就是空域插值,同时,值域上,一般是1维LUT,往往也会进行插值,这样就会有3维插值。HDRNet就是一个典型的3D插值,空域划分成了16x16,值域上划分成8个区间,作者提供的代码是编译成了so,把插值的过程封装成库了,调用时也依赖库编译的环境,通用性并不好。其实,pytorch里提供了一个函数,grid_sample函数可以实现3D插值,可以直接调用这个函数即可实现HDRNet里的slice操作,当然,可能还是稍微有点区别。
代码具体如下:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class Slice(nn.Module):
def __init__(self):
super(Slice, self).__init__()
def forward(self, grid, guidemap, align_corners=True):
N, _, H, W = guidemap.shape
hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
if torch.cuda.is_available():
hg = hg.type(torch.cuda.FloatTensor).repeat(N, 1, 1).unsqueeze(3) / (H - 1) * 2 - 1
wg = wg.type(torch.cuda.FloatTensor).repeat(N, 1, 1).unsqueeze(3) / (W - 1) * 2 - 1
else:
hg = hg.type(torch.FloatTensor).repeat(N, 1, 1).unsqueeze(3) / (H - 1) * 2 - 1
wg = wg.type(torch.FloatTensor).repeat(N, 1, 1).unsqueeze(3) / (W - 1) * 2 - 1
guidemap = guidemap * 2 - 1
# guidemap = guidemap.permute(0, 2, 3, 1).contiguous() # 深拷贝
guidemap = guidemap.permute(0, 2, 3, 1)
guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1)
coeff = F.grid_sample(grid, guidemap_guide, padding_mode='border', align_corners=align_corners)
return coeff.squeeze(2)
def TestVal1():
slicef = Slice()
input = torch.zeros((1, 1, 2, 2))
input[0, 0, 0, 0] = 0.25
input[0, 0, 0, 1] = 0.5
input[0, 0, 1, 0] = 0.75
input[0, 0, 1, 1] = 1.0
grid = torch.zeros((1, 1, 5, 1, 1))
grid[0, 0, 0, 0, 0] = 0.0
grid[0, 0, 1, 0, 0] = 0.25
grid[0, 0, 2, 0, 0] = 0.5
grid[0, 0, 3, 0, 0] = 0.75
grid[0, 0, 4, 0, 0] = 1.0
coeff = slicef(grid, input, True)
print(coeff)
def TestPos1():
slicef = Slice()
input = torch.zeros((1, 1, 3, 5))
input[0, 0, 0, 0] = 1.0
input[0, 0, 0, 1] = 1.0
input[0, 0, 0, 2] = 1.0
input[0, 0, 0, 3] = 1.0
input[0, 0, 0, 4] = 1.0
input[0, 0, 1, 0] = 1.0
input[0, 0, 1, 1] = 1.0
input[0, 0, 1, 2] = 1.0
input[0, 0, 1, 3] = 1.0
input[0, 0, 1, 4] = 1.0
input[0, 0, 2, 0] = 1.0
input[0, 0, 2, 1] = 1.0
input[0, 0, 2, 2] = 1.0
input[0, 0, 2, 3] = 1.0
input[0, 0, 2, 4] = 1.0
grid = torch.zeros((1, 1, 3, 2, 3))
grid[0, 0, 0, 0, 0] = 0.0
grid[0, 0, 1, 0, 0] = 0.5
grid[0, 0, 2, 0, 0] = 1.0
grid[0, 0, 0, 0, 1] = 0.0
grid[0, 0, 1, 0, 1] = 0.3
grid[0, 0, 2, 0, 1] = 0.6
grid[0, 0, 0, 0, 2] = 0.0
grid[0, 0, 1, 0, 2] = 0.4
grid[0, 0, 2, 0, 2] = 0.8
grid[0, 0, 0, 1, 0] = 0.0
grid[0, 0, 1, 1, 0] = 0.4
grid[0, 0, 2, 1, 0] = 0.8
grid[0, 0, 0, 1, 1] = 0.0
grid[0, 0, 1, 1, 1] = 0.3
grid[0, 0, 2, 1, 1] = 0.6
grid[0, 0, 0, 1, 2] = 0.0
grid[0, 0, 1, 1, 2] = 0.5
grid[0, 0, 2, 1, 2] = 1.0
coeff = slicef(grid, input, True)
print(coeff)
def Test2():
slicef = Slice()
input = torch.zeros((1, 1, 2, 2))
input[0, 0, 0, 0] = 1/16
input[0, 0, 0, 1] = 0.5
input[0, 0, 1, 0] = 0.75
input[0, 0, 1, 1] = 1.0
grid = torch.zeros((1, 1, 4, 1, 1))
grid[0, 0, 0, 0, 0] = 1/8
grid[0, 0, 1, 0, 0] = 3/8
grid[0, 0, 2, 0, 0] = 5/8
grid[0, 0, 3, 0, 0] = 7/8
coeff = slicef(grid, input, False)
print(coeff)
if __name__ == '__main__':
TestVal1()
TestPos1()
Test2()
其中,grid为网格的划分,是5D的,第一为batchsize,即N,第二维是输出系数大小,比如HDRnet里的12,第三维为值域的划分,比如HDRNet里的8,最后两维为空域划分的HW,比如HDRNet里的16x16,HDRNet里grid对应的大小应该为grid[N, 12, 8, 16, 16]。
grid_sample函数中比较重要的参数是align_corners,当align_corners=True时,padding_mode参数没什么意义,当align_corners=False时,padding_mode就有意义了。当align_corners=True时,一般3D划分的数目需要为奇数,否则划分的块或者段,并不是N,而是N-1,如下图所示:
这是空域上划分为3x4个块,值域上划分成4段,但grid的参数对应的是5x4x5,即grid为grid[N, 1, 5, 4, 5]。每个块的中心在各个块交点上。
当align_corners=False时,如下图所示:
可以看到,其实划分的块或段数目是一样的,但每个块的中心位置变了,不是边角了,而是块的中心了, 这时grid为grid[N, 1, 4, 3, 4],可以发现维度少了一个。不过,此时边界和四个角上的用的输出系数和邻近的中点的,就是边界处理方式不一样,所以此时,padding_mode就有意义了,代码里默认写成了border,就是用边界的,而不是0,grid_sample函数里,默认参数时zeros。HDRNet里应该是用的align_corners=False,padding_mode为border的情况。
当然,需要说明下,grid_sample函数在早期的版本1.3.0之前是没有align_corners参数的,用的是align_corners=True的情况,需要特别注意下。当然,这个函数还有改进的地方,就是空域上使用align_corners=False的情况比较好,但值域上,可能更多是需要使用align_corners=True的情况,要是能将二者区分就更好了。
代码里写了几个例子,TestVal1()和TestPos1都是测试align_corners=True的情况,分别测试的值域和空域,TestVal1里,空域只有1x1,就是没有划分,只需看值域,值域分成了5个节点,4段,且是均匀划分的,输出结果为:
tensor([[[[0.2500, 0.5000],
[0.7500, 1.0000]]]])
和输入是一样的,完全符合预期。TestPos1中,空域上划分为2x3,H方向两个节点,W方向3个节点,实际块是1x2,值域上定了3个节点,每个块不一样,因为要验证空域插值,而测试input所有的都是一样,这里设置的是1.0,输出结果为:
tensor([[[[1.0000, 0.8000, 0.6000, 0.7000, 0.8000],
[0.9000, 0.7500, 0.6000, 0.7500, 0.9000],
[0.8000, 0.7000, 0.6000, 0.8000, 1.0000]]]])
也是很符合预期的。
Test2()是验证align_corners=False的情况,且是验证值域,grid里的数值也是保持的直线,4段,input里4个测试只,不过有两个是小于1/8或者大于7/8,输出结果为:
tensor([[[[0.1250, 0.5000],
[0.7500, 0.8750]]]])
可以看到,0.5和0.75的输出也是0.5和0.75,维持不变,符合预期,但1/16的输入,输出却不是1/16,而是1/8,因为超出了边界,使用的边界的值,1的输入也是同理,输出的是7/8,而不是1。