经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 程序设计 » Python » 查看文章
python类参数定义及数据扩展方式unsqueeze/expand
来源:jb51  时间:2022/8/23 11:26:50  对本文有异议

类的参数定义

将conda环境设置为ai,conda activate ai

这个文件的由来:

由于在yolov1的pytorch实现的损失函数中,看到继承了nn.Module,并且其中两个参数不像c++那里指定类型,那么他们的类型是哪里来的

这里就是在探索这样一件事

操作逻辑:

  • 先在类中定义了构造函数以及一个自定义函数;
  • 构造函数定义了属性S、B,自定义函数引入两个参数,对两个参数进行调用
    • 这里就说明参数的结构是怎么样的,取决于参数被调用了什么东西,比如这里调用了N = box1.size(0) M = box2.size(0)说明了它是类似一个矩阵的东西,对应的box1的定义就是`torch.rand(10,4)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.autograd import Variable
  5.  
  6. #探究属性S,B是如何产生的,以及box1、box2是如何产生的、如何调用
  7. class yoloLoss(nn.Module):
  8. def __init__(self,S,B):
  9. self.S=S
  10. self.B=B
  11. def compute_iot(self,box1,box2):
  12. N = box1.size(0) #调用方式就表示了变量是什么类型,这里是一个张量,其中每个元素是一个tensor,所以是N*4的张量
  13. M = box2.size(0)
  14. print(M,N)
  15.  
  16. yoloLoss1 =yoloLoss(10, 11)
  17. yoloLoss1.compute_iot(torch.rand(10,4),torch.rand(11,4))

数据扩展

探究unsqueeze以及expand的使用方法,unsqueeze可以增加一个纬度,但是维度的siz只是1而已,而expand就可以将数据进行复制,将数据变为n

  1. # 获得一开始的初始化数值:tensor([[a1,a2,a3]])
  2. nn1=torch.rand(1,3)
  3. print(nn1)
  4. # unsqueeze是解压的意思,在第i个维度上进行扩展,将其扩展为tensor([[[a1,a2,a3]]])
  5. nn1=nn1.unsqueeze(0)
  6. print("*"*100)
  7. print(nn1)
  8. #利用expand对数据进行扩展
  9. nn1=nn1.expand(1,3,3)
  10. print("*"*100)
  11. print(nn1)

到此这篇关于python类参数定义及数据扩展方式unsqueeze/expand的文章就介绍到这了,更多相关python unsqueeze/expand内容请搜索w3xue以前的文章或继续浏览下面的相关文章希望大家以后多多支持w3xue!

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号