经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
pytorch中的add_module函数
来源:cnblogs  作者:蒙面的普罗米修斯  时间:2021/6/21 9:38:42  对本文有异议

现只讲在自定义网络中add_module的作用。

总结:

在自定义网络的时候,由于自定义变量不是Module类型(例如,我们用List封装了几个网络),所以pytorch不会自动注册网络模块add_module函数用来为网络添加模块的,所以我们可以使用这个函数手动添加自定义的网络模块。当然,这种情况,我们也可以使用ModuleList来封装自定义模块,pytorch就会自动注册了。

 

Let't start!

add_module函数是在自定义网络添加子模块,例如,当我们自定义一个网络肤过程中,我们既可以

(1)通过self.module=xxx_module的方式(如下面第3行代码),添加网络模块;

(2)通过add_module函数对网络中添加模块。

(3)通过用nn.Sequential对模块进行封装等等。

  1. 1 class NeuralNetwork(nn.Module):
  2. 2 def __init__(self):
  3. 3 super(NeuralNetwork, self).__init__()
  4. 4 self.layers = nn.Linear(28*28,28*28)
  5. 5 # self.add_module('layers',nn.Linear(28*28,28*28)) # 跟上面的方式等价
  6. 6 self.linear_relu_stack = nn.Sequential(
  7. 7 nn.Linear(28*28, 512),
  8. 8 nn.ReLU()
  9. 9 )
  10. 10
  11. 11 def forward(self, x):
  12. 12 for layer in layers:
  13. 13 x = layer(x)
  14. 14 logits = self.linear_relu_stack(x)
  15. 15 return logits

我们实例化类,然后输出网络的模块看一下:

  1. 1 0 Linear(in_features=784, out_features=784, bias=True)
  2. 2 1 Sequential(
  3. 3 (0): Linear(in_features=784, out_features=512, bias=True)
  4. 4 (1): ReLU()
  5. 5 )

会发现,上面定义的网络子模块都有:Linear和Sequential。

 

但是,有时候pytorch不会自动给我们注册模块,我们需要根据传进来的参数对网络进行初始化,例如:

 

  1. 1 class NeuralNetwork(nn.Module):
  2. 2 def __init__(self, layer_num):
  3. 3 super(NeuralNetwork, self).__init__()
  4. 4 self.layers = [nn.Linear(28*28,28*28) for _ in range(layer_num)]
  5. 5 self.linear_relu_stack = nn.Sequential(
  6. 6 nn.Linear(28*28, 512),
  7. 7 nn.ReLU()
  8. 8 )
  9. 9
  10. 10 def forward(self, x):
  11. 11 for layer in layers:
  12. 12 x = layer(x)
  13. 13 logits = self.linear_relu_stack(x)
  14. 14 return logits

对此我们再初始化一个实例,然后看下网络中的模块:

  1. 1 model = NeuralNetwork(2)
  2. 2 for index,item in enumerate(model.children()):
  3. 3 print(index,item)

输出结果就是:

  1. 0 Sequential(
  2. (0): Linear(in_features=784, out_features=512, bias=True)
  3. (1): ReLU()
  4. ) 

 

你会发现定义的Linear模块都不见了,而上面定义的时候,明明都制订了。这是因为pytorch在注册模块的时候,会查看成员的类型,如果成员变量类型是Module的子类,那么pytorch就会注册这个模块,否则就不会。

这里的self.layers是python中的List类型,所以不会自动注册,那么就需要我们再定义后,手动注册(下图黄色标注部分):

  1. 1 class NeuralNetwork(nn.Module):
  2. 2 def __init__(self, layer_num):
  3. 3 super(NeuralNetwork, self).__init__()
  4. 4 self.layers = [nn.Linear(28*28,28*28) for _ in range(layer_num)]
  5. 5 for i,layer in enumerate(self.layers):
  6. 6 self.add_module('layer_{}'.format(i),layer)
  7. 7 self.linear_relu_stack = nn.Sequential(
  8. 8 nn.Linear(28*28, 512),
  9. 9 nn.ReLU()
  10. 10 )
  11. 11
  12. 12 def forward(self, x):
  13. 13 for layer in layers:
  14. 14 x = layer(x)
  15. 15 logits = self.linear_relu_stack(x)
  16. 16 return logits

这样我们再输出模型的子模块的时候,就会得到:

  1. model = NeuralNetwork(4)
  2. for index,item in enumerate(model.children()):
  3. print(index,item)
  4. # output
  5. #0 Linear(in_features=784, out_features=784, bias=True)
  6. #1 Linear(in_features=784, out_features=784, bias=True)
  7. #2 Linear(in_features=784, out_features=784, bias=True)
  8. #3 Linear(in_features=784, out_features=784, bias=True)
  9. #4 Sequential(
  10. # (0): Linear(in_features=784, out_features=512, bias=True)
  11. # (1): ReLU()
  12. #)

就会看到,已经有了自己注册的模块。

 

当然,也可能觉得这种方式比较麻烦,每次都要自己注册下,那能不能有一个类似List的类,在定义的时候就封装一下呢? 

可以,使用nn.ModuleList封装一下即可达到相同的效果。

  1. class NeuralNetwork(nn.Module):
  2. def __init__(self, layer_num):
  3. super(NeuralNetwork, self).__init__()
  4. self.layers = nn.ModuleList([nn.Linear(28*28,28*28) for _ in range(layer_num)])
  5. self.linear_relu_stack = nn.Sequential(
  6. nn.Linear(28*28, 512),
  7. nn.ReLU()
  8. )
  9. def forward(self, x):
  10. for layer in layers:
  11. x = layer(x)
  12. logits = self.linear_relu_stack(x)
  13. return logits

 

参考:
1. 博客THE PYTORCH ADD_MODULE() FUNCTION link
2. pytorch 官方文档 中文链接 English version

原文链接:http://www.cnblogs.com/datasnail/p/14903643.html

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号