经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
使用tensorflow实现mnist手写识别(单层神经网络实现)
来源:cnblogs  作者:路漫漫其修远兮,(...)  时间:2019/4/2 8:40:41  对本文有异议
  1. import tensorflow as tf
  2. import tensorflow.examples.tutorials.mnist.input_data as input_data
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. mnist = input_data.read_data_sets("data/",one_hot = True)
  6. #导入Tensorflwo和mnist数据集
  7.  
  8. #构建输入层
  9. x = tf.placeholder(tf.float32,[None,784],name='X')
  10. y = tf.placeholder(tf.float32,[None,10],name='Y')
  11. #隐藏层神经元数量
  12. H1_NN = 256 #第一层神经元数量
  13. W1 = tf.Variable(tf.random_normal([784,H1_NN])) #权重
  14. b1 = tf.Variable(tf.zeros([H1_NN])) #偏置项
  15. Y1 = tf.nn.relu(tf.matmul(x,W1)+b1) #第一层输出
  16. W2 = tf.Variable(tf.random_normal([H1_NN,10]))#权重
  17. b2 = tf.Variable(tf.zeros(10))#偏置项
  18. forward = tf.matmul(Y1,W2)+b2 #定义前向传播
  19. pred = tf.nn.softmax(forward) #激活函数输出
  20.  
  21. #损失函数
  22. #loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),
  23. # reduction_indices=1))
  24. #(log(0))超出范围报错
  25. loss_function = tf.reduce_mean(
  26. tf.nn.softmax_cross_entropy_with_logits(logits=forward,labels=y))
  27. #训练参数
  28. train_epochs = 50 #训练次数
  29. batch_size = 50 #每次训练多少个样本
  30. total_batch = int(mnist.train.num_examples/batch_size) #随机抽取样本
  31. display_step = 1 #训练情况输出
  32. learning_rate = 0.01 #学习率
  33.  
  34. #优化器
  35. opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)
  36. #准确率函数
  37. correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
  38. accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  39. #记录开始训练时间
  40. from time import time
  41. startTime = time()
  42. #初始化变量
  43. sess =tf.Session()
  44. init = tf.global_variables_initializer()
  45. sess.run(init)
  46. #训练
  47. for epoch in range(train_epochs):
  48. for batch in range(total_batch):
  49. xs,ys = mnist.train.next_batch(batch_size)#读取批次数据
  50. sess.run(opimizer,feed_dict={x:xs,y:ys})#执行批次数据训练
  51. #total_batch个批次训练完成后,使用验证数据计算误差与准确率
  52. loss,acc=sess.run([loss_function,accuracy],
  53. feed_dict={x:mnist.validation.images,
  54. y:mnist.validation.labels})
  55. #输出训练情况
  56. if(epoch+1) % display_step == 0:
  57. print("Train Epoch:",'%02d' % (epoch + 1),
  58. "Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
  59. duration = time()-startTime
  60. print("Trian Finshed takes:","{:.2f}".format(duration))#显示预测耗时
  61.  
  62. #由于pred预测结果是one_hot编码格式,所以需要转换0~9数字
  63. prediction_resul = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})
  64. prediction_resul[0:10]
  65. #模型评估
  66. accu_test = sess.run(accuracy,
  67. feed_dict={x:mnist.test.images,y:mnist.test.labels})
  68. print("Accuray:",accu_test)
  69. compare_lists = prediction_resul == np.argmax(mnist.test.labels,1)
  70. print(compare_lists)
  71. err_lists = [i for i in range(len(mnist.test.labels)) if compare_lists[i] == False]
  72. print(err_lists,len(err_lists))
  73. index_list = []
  74. def print_predct_errs(labels,#标签列表
  75. perdiction):#预测值列表
  76. count = 0
  77. compare_lists = (perdiction == np.argmax(labels,1))
  78. err_lists = [i for i in range(len(labels)) if compare_lists[i] == False]
  79. for x in err_lists:
  80. index_list.append(x)
  81. print("index="+str(x)+
  82. "标签值=",np.argmax(labels[x]),
  83. "预测值=",perdiction[x])
  84. count = count+1
  85. print("总计:",count)
  86. return index_list
  87. print_predct_errs(mnist.test.labels,prediction_resul)
  88. def plot_images_labels_prediction(images,labels,prediction,index,num=25):
  89. fig = plt.gcf() # 获取当前图片
  90. fig.set_size_inches(10,12)
  91. if num>=25:
  92. num=25 #最多显示25张图片
  93. for i in range(0,num):
  94. ax = plt.subplot(5,5, i+1) #获取当前要处理的子图
  95. ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')#显示第index个图像
  96. title = 'label=' + str(np.argmax(labels[index]))#构建该图上要显示的title
  97. if len(prediction)>0:
  98. title += 'predict= '+str(prediction[index])
  99. ax.set_title(title,fontsize=10)
  100. ax.set_xticks([])
  101. ax.set_yticks([])
  102. index += 1
  103. plt.show()
  104. plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_resul,index=index_list[100])

单纯记录一下个人代码,很基础的一个MNIST手写识别使用Tensorflwo实现,算是入门的Hello world 了,有些奇怪的问题暂时没有解决 训练次数调成40 在训练到第35次左右发生了梯度爆炸,原因未知,损失函数要使用带softmax那个,不然也会发生梯度爆炸

原文链接:http://www.cnblogs.com/imae/p/10629890.html

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号