经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
Pytorch之Variable求导机制
来源:cnblogs  作者:Ruyi.Luo  时间:2018/12/29 9:37:09  对本文有异议

自动求导机制是pytorch中非常重要的性质,免去了手动计算导数,为构建模型节省了时间。下面介绍自动求导机制的基本用法。

#自动求导机制
import torch
from torch.autograd import Variable

# 1、简单的求导(求导对象是标量)
x = Variable(torch.Tensor([2]),requires_grad=True)
y = (x + 2) ** 2 + 3
print(y)
y.backward()
print(x.grad)

#对矩阵求导
x1 = Variable(torch.randn(10,20),requires_grad=True)
y1 = Variable(torch.randn(10,1),requires_grad=True)
W = Variable(torch.randn(20,1),requires_grad=True)

J = torch.mean(y1 - torch.matmul(x1,W)) #matmul表示做矩阵乘法
J.backward()
print(x1.grad)
print(y1.grad)
print(W.grad)

  1. tensor([19.], grad_fn=<AddBackward0>)
  2. tensor([8.])
  3. tensor([[-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  4. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  5. -0.1581, 0.1986, -0.0226, -0.0454],
  6. [-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  7. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  8. -0.1581, 0.1986, -0.0226, -0.0454],
  9. [-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  10. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  11. -0.1581, 0.1986, -0.0226, -0.0454],
  12. [-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  13. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  14. -0.1581, 0.1986, -0.0226, -0.0454],
  15. [-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  16. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  17. -0.1581, 0.1986, -0.0226, -0.0454],
  18. [-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  19. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  20. -0.1581, 0.1986, -0.0226, -0.0454],
  21. [-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  22. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  23. -0.1581, 0.1986, -0.0226, -0.0454],
  24. [-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  25. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  26. -0.1581, 0.1986, -0.0226, -0.0454],
  27. [-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  28. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  29. -0.1581, 0.1986, -0.0226, -0.0454],
  30. [-0.1636, 0.0904, 0.0446, -0.1052, -0.2323, 0.0129, -0.1532, 0.0544,
  31. 0.0231, -0.0993, -0.0387, -0.1762, 0.0477, 0.1552, 0.0493, 0.0144,
  32. -0.1581, 0.1986, -0.0226, -0.0454]])
  33. tensor([[0.1000],
  34. [0.1000],
  35. [0.1000],
  36. [0.1000],
  37. [0.1000],
  38. [0.1000],
  39. [0.1000],
  40. [0.1000],
  41. [0.1000],
  42. [0.1000]])
  43. tensor([[ 0.0224],
  44. [ 0.0187],
  45. [-0.2078],
  46. [ 0.5092],
  47. [ 0.0677],
  48. [ 0.3497],
  49. [-0.4575],
  50. [-0.5480],
  51. [ 0.4228],
  52. [-0.0869],
  53. [ 0.2876],
  54. [-0.1714],
  55. [ 0.0985],
  56. [-0.1364],
  57. [-0.1502],
  58. [-0.1372],
  59. [-0.0999],
  60. [-0.0006],
  61. [-0.0544],
  62. [-0.0678]])

#复杂情况的自动求导 多维数组自动求导机制
import torch
from torch.autograd import Variable

x = Variable(torch.FloatTensor([3]),requires_grad=True)
y = x ** 2 + x * 2 + 3
y.backward(retain_graph=True) #保留计算图
print(x.grad)
y.backward()#不保留计算图
print(x.grad) #得到的是第一次求导的值加上第二次求导的值 8 + 8

  1. tensor([8.])
  2. tensor([16.])

#小练习,向量对向量求导
import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([2,3]),requires_grad = True)
k = Variable(torch.zeros_like(x))

k[0] = x[0]**2 + 3 * x[1]
k[1] = 2*x[0] + x[1] ** 2

print(k)

j = torch.zeros(2,2)
k.backward(torch.FloatTensor([1,0]),retain_graph = True)
j[0] = x.grad.data

x.grad.zero_()
k.backward(torch.FloatTensor([0,1]),retain_graph = True)
j[1] = x.grad.data
print(j)

  1. tensor([13., 13.], grad_fn=<CopySlices>)
  2. tensor([[4., 3.],
  3. [2., 6.]])
 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号