经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
kNN--近邻算法
来源:cnblogs  作者:孩纸有点硬  时间:2018/12/10 9:22:27  对本文有异议

kNN--近邻算法

   kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。

在机器学习中常用于分类。

数学内容:

  欧氏距离公式,矩阵运算,归一化数值

python模块:

  numpy,operator(用其中的itemgetter做排序),listdir(列出目录中的文件),matplotlib.pyplot(可视化数据分析数据),

  PIL(对图片进行处理)

  1. from numpy import *
  2. import operator
  3. from os import listdir
  4. def createDataSet():
  5. groups=array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
  6. lables=['A','A','B','B']
  7. return groups,lables
  8. #k-近邻算法
  9. def classify0(inX, dataset,labels,k):
  10. #获取样本集中有几组数据
  11. datasetSize=dataset.shape[0]
  12. #欧氏距离公式 计算距离
  13. diffMat=tile(inX, (datasetSize, 1)) - dataset
  14. sqDiffMat=diffMat**2
  15. sqDistances=sqDiffMat.sum(axis=1)
  16. distances=sqDistances**0.5
  17. #按距离递增排列,返回样本集中的index
  18. sortedDistances=distances.argsort()
  19. classCount={}
  20. for i in range(k):
  21. #根据距离递增的顺序,获取与其对应的类别(即目标变量)
  22. voteIlabel=labels[sortedDistances[i]]
  23. #为k个元素所在的分类计数
  24. classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
  25. #通过对比每个类别出现的次数(即classCount value),以递减的顺序排序
  26. sortedCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
  27. #返回计数最大的那个类别的值
  28. return sortedCount[0][0]
  29. #准备数据
  30. def file2matrix(filename):
  31. fr=open(filename)
  32. arrayOLines=fr.readlines()
  33. #获取文件行数
  34. numberOflines=len(arrayOLines)
  35. #创建一个以文件行数为行,3列的矩阵
  36. returnMatrix=zeros((numberOflines,3))
  37. #定义一个存放目标变量(类别)的数组
  38. classLabelVector=[]
  39. index=0
  40. #遍历文件
  41. for line in arrayOLines:
  42. line=line.strip()
  43. listFromLine=line.split('\t')
  44. #把文件前三列添加到返回的矩阵中
  45. returnMatrix[index:]=listFromLine[0:3]
  46. #文件最后一列(对应的类别)添加到类别数组中
  47. classLabelVector.append(int(listFromLine[-1]))
  48. index+=1
  49. #返回数据特征矩阵和类别数组
  50. return returnMatrix,classLabelVector
  51. #通过公式 "newValue=(oldValue-min)/(max-min)" 将任意取值范围的特征值转化为0到1区间内的值
  52. def autoNorm(dataset):
  53. #返回每列的最小值
  54. minVals=dataset.min(0)
  55. #返回每列的最大值
  56. maxVals=dataset.max(0)
  57. #返回最大值与最小值的差
  58. ranges=maxVals-minVals
  59. #创建与dataset同行同列的0矩阵
  60. normDataset=zeros(shape(dataset))
  61. #返回dataset的行数
  62. m=dataset.shape[0]
  63. #创建一个重复m次的minVals矩阵,并与dataset相减
  64. normDataset=dataset-tile(minVals,(m,1))
  65. #newValue=(oldValue-min)/(max-min)
  66. normDataset=normDataset/tile(ranges,(m,1))
  67. return normDataset,ranges,minVals
  68. #测试算法
  69. def datingClassTest():
  70. #设定测试数据比例
  71. hoRatio=0.10
  72. #返回格式化后的数据和其标签
  73. datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')
  74. #归一化数据值
  75. normMat,ranges,minVals=autoNorm(datingDataMat)
  76. #数据的行数
  77. m=normMat.shape[0]
  78. #测试数据的行数
  79. numTestVecs=int(m*hoRatio)
  80. #设置错误预测计数器
  81. errorCount=0.0
  82. #向k-近邻算法中传numTestVecs个测试数据,并把返回的预测数据与真实数据比较返回,若错误,计数器加1
  83. for i in range(numTestVecs):
  84. """
  85. 调用k-近邻算法,为其传入参数,
  86. normMat[i]:第i个测试数据,
  87. normMat[numTestVecs:m,:]:从numTestVecs到m个样本数据,(m可以不写,相当于从numTestVecs索引开始,取剩下所有的normMat数据)
  88. datingLabels[numTestVecs:m]:从numTestVecs到m个样本数据对应的标签
  89. 3:k的值
  90. """
  91. classifierResult=classify0(normMat[i],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
  92. #判断预测数据与真实数据,如果是错误的,则以红字体输出,并错误预测计数器加1
  93. if (classifierResult!=datingLabels[i]):
  94. print("\033[0;31mthe classifier came back with: %d, the real answer is: %d\033[0m" % (classifierResult, datingLabels[i]))
  95. errorCount+=1.0
  96. else:
  97. print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
  98. print("the total error rate is:%f" %(errorCount/float(numTestVecs)))
  99. #约会系统
  100. def classifiyPerson():
  101. #设定分类(标签)列表
  102. resultList=["not at all", "in small doses", "in large doses"]
  103. #提示用户输入相应内容
  104. percentTats=float(input("percentage of time spent playing video games?"))
  105. ffMiles=float(input("frequent filer miles earned per year?"))
  106. iceCream=float(input("liters of ice cream consumed per year?"))
  107. #把用户输入的三个特征值格式化成numpy.array数据类型
  108. inArr=array([ffMiles,percentTats,iceCream])
  109. #准备样本数据及对应标签
  110. datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')
  111. #归一化样本数据并返回ranges和minVals,以便归一化用户输入的数据
  112. normMat,ranges,minVals=autoNorm(datingDataMat)
  113. #调用k-近邻算法,并把传入的预测数据特征做归一化
  114. classifierResult=classify0((inArr-minVals)/ranges,normMat,datingLabels,3)
  115. #打印出预测出的类别,因为样本数据中的类别(标签)为1,2,3,其index是0,1,2,所以要用预测出的分类(1,2,3)减1
  116. print("You will probably like this person: %s" %(resultList[classifierResult-1]))
  117. #将32x32的二进制图像文件转换成1x1024的向量
  118. def img2vector(filename):
  119. #创建一个1x1024的0矩阵
  120. returnVect=zeros((1,1024))
  121. fr=open(filename)
  122. """
  123. 因为已知文件是32x32,即有文件中32行内容,通过readline()方法遍历文件,得到文件的每行内容lineStr
  124. 再遍历每行内容lineStr,并把遍历出的内容添加到returnVect矩阵里
  125. """
  126. for i in range(32):
  127. lineStr=fr.readline()
  128. for j in range(32):
  129. returnVect[0,32*i+j]=int(lineStr[j])
  130. return returnVect
  131. #手写数字识别系统
  132. def handwritingClassTest():
  133. #创建数据标签集合
  134. hwLabels=[]
  135. #列出目录冲所有文件
  136. trainingFileList=listdir('digits/trainingDigits')
  137. #得到文件个数,也就是训练数据的行数
  138. m=len(trainingFileList)
  139. #创建一个m行,1024列的0矩阵
  140. trainingMat=zeros((m,1024))
  141. """
  142. 通过遍历所有训练文件,得到文件名,其对应的数字(eg:0_7.txt),并把数字添加到hwLabels集合,
  143. 通过上面的img2vector函数,得到一个与该文件对应的1x1024矩阵,并添加到trainingMat矩阵中
  144. """
  145. for i in range(m):
  146. fileNameStr=trainingFileList[i]
  147. fileStr=fileNameStr.split('.')[0]
  148. classNumStr=int(fileStr.split('_')[0])
  149. hwLabels.append(classNumStr)
  150. trainingMat[i,:]=img2vector('digits/trainingDigits/%s' % fileNameStr)
  151. #对测试数据做同样的操作
  152. testFileList=listdir('digits/testDigits')
  153. mTest=len(testFileList)
  154. errorCount=0.0
  155. for i in range(mTest):
  156. fileNameStr=testFileList[i]
  157. fileStr=fileNameStr.split('.')[0]
  158. classNumStr=int(fileStr.split('_')[0])
  159. vectorUnderTest=img2vector('digits/testDigits/%s' % fileNameStr)
  160. classifierResult=classify0(vectorUnderTest,trainingMat,hwLabels,3)
  161. if (classifierResult!=classNumStr):
  162. print("\033[0;31mthe classifier came back with: %d, the real answer is: %d\033[0m" % (classifierResult,classNumStr))
  163. errorCount+=1
  164. else:
  165. print("the classifier came back with: %d, the real answer is: %d" %(classifierResult,classNumStr))
  166. print("\nthe total number of errors is: %d" % errorCount)
  167. print("\nthe total error rate is: %f" %(errorCount/float(mTest)))
  168. #在网上找数字图片做测试
  169. def imgNumClassTest(filename):
  170. hwLabels=[]
  171. trainingFileList=listdir('digits/trainingDigits')
  172. m=len(trainingFileList)
  173. trainingMat=zeros((m,1024))
  174. for i in range(m):
  175. fileNameStr=trainingFileList[i]
  176. fileStr=fileNameStr.split('.')[0]
  177. classNumStr=int(fileStr.split('_')[0])
  178. hwLabels.append(classNumStr)
  179. trainingMat[i,:]=img2vector('digits/trainingDigits/%s' % fileNameStr)
  180. vectorUnderTest=img2vector(filename)
  181. classifierResult=classify0(vectorUnderTest,trainingMat,hwLabels,3)
  182. print(classifierResult)

 

 

约会网站案列数据分析代码:

  1. """
  2. 分析数据
  3. """
  4.  
  5. import kNN
  6. from numpy import *
  7. import matplotlib
  8. import matplotlib.pyplot as plt
  9. datingDataMat,datingLabels=kNN.file2matrix('datingTestSet2.txt')
  10. #创建一个图片窗口,默认是1(figure1)
  11. fig=plt.figure()
  12. #在图片窗口创建两行一列的子图,并使用第一行第一列,即211的含义
  13. ax1=fig.add_subplot(211)
  14. """
  15. 创建散点图,x轴是datingDataMat第一列的数据,y轴是datinDataMat第二列的数据,
  16. 后面两个参数一个代表颜色,一个代表点的大小,两个参数同时放大15倍,然后这个时候就是同一个label用一种颜色和大小表示出来,
  17. 不同的label的点的大小和颜色会不一样。
  18. """
  19. ax1.scatter(datingDataMat[:,1],datingDataMat[:,2],15*array(datingLabels),15*array(datingLabels))
  20. #设置x轴标签
  21. plt.xlabel('Play game takes time')
  22. #设置y轴标签
  23. plt.ylabel('Eat ice-cream')
  24. #在图片窗口中使用第一行第二列
  25. ax2=fig.add_subplot(212)
  26. #把datingLabels转成numpy.array类型
  27. datingLabels=array(datingLabels)
  28. #取datingLabels中值等于1的index
  29. idx_1=where(datingLabels==1)
  30. #idx_1即datingTestSet2.txt文件中第四列值为1的行数,则获取idx_1行,第一二列的数据创建散点图,为这些点设置颜色,大小,label
  31. p1=ax2.scatter(datingDataMat[idx_1,0],datingDataMat[idx_1,1],color = 'm', label='Hate', s = 50)
  32. idx_2=where(datingLabels==2)
  33. p2=ax2.scatter(datingDataMat[idx_2,0],datingDataMat[idx_2,1],color = 'c', label='General', s = 30)
  34. idx_3=where(datingLabels==3)
  35. p3=ax2.scatter(datingDataMat[idx_3,0],datingDataMat[idx_3,1],color = 'r', label='Like', s = 10)
  36. plt.xlabel('Flying')
  37. plt.ylabel('Play game takes time')
  38. #创建图示放置在左上角
  39. plt.legend(loc='upper left')
  40. #显示图片
  41. plt.show()

手写数字识别系统图片转文本文件代码:

  1. from PIL import Image
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. def img2txt(img_path, txt_name):
  5. """
  6. 将图像数据转换为txt文件
  7. :param img_path: 图像文件路径
  8. :type txt_name: 输出txt文件路径
  9. """
  10.  
  11. #把图片转成二值图像,并设长宽均为32
  12. im = Image.open(img_path).convert('1').resize((32, 32)) # type:Image.Image
  13. #plt.imshow(im)
  14. #plt.show()
  15. #将上面得到的图像转成array数组
  16. data = np.asarray(im)
  17. #将上面得到的数组保存在到文本文件中,指定存储数据类型为整型,分隔符
  18. np.savetxt(txt_name, data, fmt='%d', delimiter='')

 

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号