Purpose: non-linear activations introduce non-linearity into the network. Without them, a stack of linear layers collapses into a single linear map; with them, the model can fit much more complex functions, which is what the common claim that they "improve generalization" refers to.
First, let's look at the official documentation (using ReLU as the example).
Note the inplace parameter: it controls whether the output overwrites the input tensor. Normally inplace=False is used, which is the default; inplace=True writes the result back into the input and saves a little memory, at the cost of destroying the original values.
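A minimal sketch of the difference (the tensor values here are just an illustration):

import torch
from torch.nn import ReLU

x = torch.tensor([1.0, -2.0, 3.0])
y = ReLU(inplace=False)(x)   # returns a new tensor; x is untouched
print(x)                     # tensor([ 1., -2.,  3.])
ReLU(inplace=True)(x)        # result is written back into x itself
print(x)                     # tensor([1., 0., 3.])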
Code
import torch
import time
from torch import nn
from torch.nn import ReLU

start = time.time()

# A tiny 2x2 input with two negative entries to exercise ReLU
input = torch.tensor([[1, -0.5],
                      [-1, 3]])
# Reshape to (batch, channel, height, width) = (1, 1, 2, 2)
input = torch.reshape(input, (-1, 1, 2, 2))
print(input.shape)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.relu1 = ReLU()

    def forward(self, input):
        output = self.relu1(input)
        return output

model = Model()
output = model(input)
print(output)

end = time.time()
print('Running time: %s Seconds' % (end - start))
The output is:
D:\Anaconda3\envs\pytorch\python.exe D:/研究生/代码尝试/nn_relu.py
torch.Size([1, 1, 2, 2])
tensor([[[[1., 0.],
          [0., 3.]]]])
Running time: 0.05785083770751953 Seconds
Process finished with exit code 0
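The result matches ReLU(x) = max(0, x): the negative entries -0.5 and -1 are clamped to 0 and the positives pass through. The same computation can be checked with the functional API; a minimal sketch:

import torch
import torch.nn.functional as F

x = torch.tensor([[1, -0.5],
                  [-1, 3]])
print(F.relu(x))   # negatives clamp to 0, positives pass through: max(0, x)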
Now take the Sigmoid function as an example to illustrate the effect of a non-linear activation. Sigmoid(x) = 1 / (1 + exp(-x)); it squashes any real input into the range (0, 1).
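Before running it on images, here is a quick numeric check of that formula on plain tensors (the sample values are just an illustration):

import torch

x = torch.tensor([-4.0, 0.0, 4.0])
print(torch.sigmoid(x))           # tensor([0.0180, 0.5000, 0.9820])
print(1 / (1 + torch.exp(-x)))    # same values, computed from the definition

The full example on the CIFAR10 dataset: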
import torch
import time
import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

start = time.time()

# CIFAR10 test split; set download=True on the first run if the data is not already present
dataset = torchvision.datasets.CIFAR10("./dataset", train=False, download=False,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.relu1 = ReLU()        # defined but unused in this example
        self.sigmoid1 = Sigmoid()

    def forward(self, input):
        output = self.sigmoid1(input)
        return output

model = Model()
step = 0
writer = SummaryWriter("./logs_relu")

for data in dataloader:
    imgs, targets = data
    writer.add_images("input", imgs, global_step=step)
    output = model(imgs)          # apply Sigmoid to every pixel
    writer.add_images("output", output, step)
    step += 1

writer.close()

end = time.time()
print('Running time: %s Seconds' % (end - start))
Open TensorBoard (e.g. run tensorboard --logdir=logs_relu and open the URL it prints).
You can see that the output images have all turned grayish: Sigmoid maps the [0, 1] pixel values produced by ToTensor into the narrow band (0.5, 0.73), so most of the contrast is squeezed out.
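This gray look is exactly what the sigmoid curve predicts for pixel inputs; a quick check:

import torch

pixels = torch.tensor([0.0, 0.5, 1.0])   # ToTensor() scales images into [0, 1]
print(torch.sigmoid(pixels))             # tensor([0.5000, 0.6225, 0.7311])
# everything in [0, 1] lands inside (0.5, 0.73), hence the low-contrast, gray images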