经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
基于tensorflow使用全连接层函数实现多层神经网络并保存和读取模型
来源:cnblogs  作者:路漫漫其修远兮,(...)  时间:2019/4/1 8:52:08  对本文有异议

使用之前那个格式写法到后面层数多的话会很乱,所以编写了一个函数创建层,这样看起来可读性高点也更方便整理后期修改维护

  1. #全连接层函数
  2.  
  3. def fcn_layer(
  4. inputs, #输入数据
  5. input_dim, #输入层神经元数量
  6. output_dim,#输出层神经元数量
  7. activation =None): #激活函数
  8. W = tf.Variable(tf.truncated_normal([input_dim,output_dim],stddev = 0.1))
  9. #以截断正态分布的随机初始化W
  10. b = tf.Variable(tf.zeros([output_dim]))
  11. #以0初始化b
  12. XWb = tf.matmul(inputs,W)+b # Y=WX+B
  13. if(activation==None): #默认不使用激活函数
  14. outputs =XWb
  15. else:
  16. outputs = activation(XWb) #代入参数选择的激活函数
  17. return outputs #返回
  1. #各层神经元数量设置
  2. H1_NN = 256
  3. H2_NN = 64
  4. H3_NN = 32
  5.  
  6. #构建输入层
  7. x = tf.placeholder(tf.float32,[None,784],name='X')
  8. y = tf.placeholder(tf.float32,[None,10],name='Y')
  9. #构建隐藏层
  10. h1 = fcn_layer(x,784,H1_NN,tf.nn.relu)
  11. h2 = fcn_layer(h1,H1_NN,H2_NN,tf.nn.relu)
  12. h3 = fcn_layer(h2,H2_NN,H3_NN,tf.nn.relu)
  13. #构建输出层
  14. forward = fcn_layer(h3,H3_NN,10,None)
  15. pred = tf.nn.softmax(forward)#输出层分类应用使用softmax当作激活函数

这样写方便后期维护 不必对着一群 W1 W2..... Wn

接下来记录一下保存模型的方法

  1. #保存模型
  2. save_step = 5 #储存模型力度
  3. import os
  4. ckpt_dir = '.ckpt_dir/'
  5. if not os.path.exists(ckpt_dir):
  6. os.makedirs(ckpt_dir)

  5轮训练保存一次,以后大模型可以调高点,接下来需要在模型整合处修改一下

  1. saver = tf.train.Saver() #声明完所有变量以后,调用tf.train.Saver开始记录
  2. if(epochs+1) % save_step == 0:
      saver.save(sess, os.path.join(ckpt_dir,"mnist_h256_model_{:06d}.ckpt".format(epochs+1)))#储存模型
      print("mnist_h256_model_{:06d}.ckpt saved".format(epochs+1))#输出情况

至此储存模型结束

 

接下来是还原模型,要注意还原的模型层数和神经元数量大小需要和之前储存模型的大小一致。

第一步设置保存模型文件的路径

  1. #必须指定存储位置
  2. ckpt_dir = "/ckpt_dir/"

存盘只会保存最近的5次,恢复会恢复最新那一份

  1. #恢复模型,创建会话
  2.  
  3. saver = tf.train.Saver()
  4.  
  5. sess = tf.Session()
  6. init = tf.global_variables_initializer()
  7. sess.run(init)
  8.  
  9. ckpt = tf.train.get_checkpoint_state(ckpt_dir)#选择模型保存路径
  10. if ckpt and ckpt.model_checkpoint_path:
  11. saver.restore(sess ,ckpt.model_checkpoint_path)#从已保存模型中读取参数
  12. print("Restore model from"+ckpt.model_checkpoint_path)

 至此模型恢复完成 下面可以选择继续训练或者评估使用

最后附上完整代码

  1. import tensorflow as tf
  2. import tensorflow.examples.tutorials.mnist.input_data as input_data
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from time import time
  6. mnist = input_data.read_data_sets("data/",one_hot = True)
  7. #导入Tensorflwo和mnist数据集等 常用库
  8. #全连接层函数
  9.  
  10. def fcn_layer(
  11. inputs, #输入数据
  12. input_dim, #输入层神经元数量
  13. output_dim,#输出层神经元数量
  14. activation =None): #激活函数
  15. W = tf.Variable(tf.truncated_normal([input_dim,output_dim],stddev = 0.1))
  16. #以截断正态分布的随机初始化W
  17. b = tf.Variable(tf.zeros([output_dim]))
  18. #以0初始化b
  19. XWb = tf.matmul(inputs,W)+b # Y=WX+B
  20. if(activation==None): #默认不使用激活函数
  21. outputs =XWb
  22. else:
  23. outputs = activation(XWb) #代入参数选择的激活函数
  24. return outputs #返回
  25. #各层神经元数量设置
  26. H1_NN = 256
  27. H2_NN = 64
  28. H3_NN = 32
  29.  
  30. #构建输入层
  31. x = tf.placeholder(tf.float32,[None,784],name='X')
  32. y = tf.placeholder(tf.float32,[None,10],name='Y')
  33. #构建隐藏层
  34. h1 = fcn_layer(x,784,H1_NN,tf.nn.relu)
  35. h2 = fcn_layer(h1,H1_NN,H2_NN,tf.nn.relu)
  36. h3 = fcn_layer(h2,H2_NN,H3_NN,tf.nn.relu)
  37. #构建输出层
  38. forward = fcn_layer(h3,H3_NN,10,None)
  39. pred = tf.nn.softmax(forward)#输出层分类应用使用softmax当作激活函数
  40. #损失函数使用交叉熵
  41. loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = forward,labels = y))
  42. #设置训练参数
  43. train_epochs = 50
  44. batch_size = 50
  45. total_batch = int(mnist.train.num_examples/batch_size) #随机抽取样本
  46. learning_rate = 0.01
  47. display_step = 1
  48. #优化器
  49. opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)
  50. #定义准确率
  51. correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
  52. accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  53. #保存模型
  54. save_step = 5 #储存模型力度
  55. import os
  56. ckpt_dir = '.ckpt_dir/'
  57. if not os.path.exists(ckpt_dir):
  58. os.makedirs(ckpt_dir)
  59. #开始训练
  60. sess = tf.Session()
  61. init = tf.global_variables_initializer()
  62. saver = tf.train.Saver() #声明完所有变量以后,调用tf.train.Saver开始记录
  63. startTime = time()
  64. sess.run(init)
  65. for epochs in range(train_epochs):
  66. for batch in range(total_batch):
  67. xs,ys = mnist.train.next_batch(batch_size)#读取批次数据
  68. sess.run(opimizer,feed_dict={x:xs,y:ys})#执行批次数据训练
  69. #total_batch个批次训练完成后,使用验证数据计算误差与准确率
  70. loss,acc = sess.run([loss_function,accuracy],
  71. feed_dict={
  72. x:mnist.validation.images,
  73. y:mnist.validation.labels})
  74. #输出训练情况
  75. if(epochs+1) % display_step == 0:
  76. epochs += 1
  77. print("Train Epoch:",epochs,
  78. "Loss=",loss,"Accuracy=",acc)
  79. if(epochs+1) % save_step == 0:
  80. saver.save(sess, os.path.join(ckpt_dir,"mnist_h256_model_{:06d}.ckpt".format(epochs+1)))
  81. print("mnist_h256_model_{:06d}.ckpt saved".format(epochs+1))
  82. duration = time()-startTime
  83. print("Trian Finshed takes:","{:.2f}".format(duration))#显示预测耗时
  84. #评估模型
  85. accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
  86. print("model accuracy:",accu_test)
  87. #恢复模型,创建会话
  88. saver = tf.train.Saver()
  89. sess = tf.Session()
  90. init = tf.global_variables_initializer()
  91. sess.run(init)
  92. ckpt = tf.train.get_checkpoint_state(ckpt_dir)#选择模型保存路径
  93. if ckpt and ckpt.model_checkpoint_path:
  94. saver.restore(sess ,ckpt.model_checkpoint_path)#从已保存模型中读取参数
  95. print("Restore model from"+ckpt.model_checkpoint_path)
完整代码

 

  

原文链接:http://www.cnblogs.com/imae/p/10634234.html

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号