0
点赞
收藏
分享

微信扫一扫

Python *args参数的作用

mjjackey 2022-02-19 阅读 49

用作函数占位符,可以增加扩展性

比如在深度学习中,网络和训练函数部分代码如下:

class SubNet(nn.Module):
	...
	def forward(self, X, *args):
		pred = do_something(X, *args)
		return pred

class Net(nn.Module):
	...
	def forward(self, X, *args):
		do_something()
		return self.subnet(X, *args)

def train_fn(net, loss, optimizer, train_iter, n_epochs):
	...
	for batch in train_iter:
		X, y = batch
		pred = net(X)
		l = loss(pred, y)
		optimizer.zero_grad()
		l.backward()
		optimizer.step()
	...

在这种情况下,模型有几个部分组成,使用如下的代码即可完成训练

subnet = SubNet()
net = Net(subnet)
train_fn(net, ...)

当子网络改进之后,模型要求的参数变多之后

class SubNet2(nn.Module):
	...
	def forward(self, X, W, *args):
		pred = do_something(X, W, *args)
		return pred
	
def train_fn(net, loss, optimizer, train_iter, n_epochs):
	...
	W = some_code()
	for batch in train_iter:
		X, y = batch
		pred = net(X, W)
		l = loss(pred, y)
		optimizer.zero_grad()
		l.backward()
		optimizer.step()
	...

训练的代码只需要将子网络改变成改进后的网络注入到模型中即可像之前一样进行训练(train_fn需要做一些适应性调整)

subnet = SubNet()
net = Net(subnet)
train_fn(net)

在这个时候,train_fn中net给的参数相比第一次的参数增加了,但是不需要对net的定义进行修改,新增加的参数会自动以*args的形式传入net,在调用self.subnet(X, *args)的时候自动进行传递,实现一定程度的解耦

举报

相关推荐

0 条评论