# -*- coding: utf-8 -*-
"""
Colour quantisation of an image with K-means clustering.

Created on Fri Sep 21 15:37:26 2018

@author: zhen
"""
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on Matplotlib < 3.2
from PIL import Image
from sklearn.cluster import KMeans

def restore_image(cb, cluster, shape):
    """Rebuild an image from a codebook and per-pixel cluster labels.

    Parameters
    ----------
    cb : array-like
        Codebook; entry ``cb[k]`` is the colour assigned to cluster ``k``
        (a row of ``channels`` values, or a scalar broadcast to all channels).
    cluster : array-like of int
        Cluster label of each pixel in row-major order. Only the first
        ``row * col`` labels are used.
    shape : tuple of 3 ints
        ``(row, col, channels)`` of the image to rebuild.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(row, col, channels)``.
    """
    row, col, channels = shape
    image = np.empty((row, col, channels))
    n_pixels = row * col
    if n_pixels == 0:
        return image
    # One vectorised fancy-indexing lookup replaces the per-pixel Python loop.
    pixels = np.asarray(cb)[np.asarray(cluster)[:n_pixels]]
    if pixels.ndim == 1:
        # Scalar codebook entries: broadcast each one across all channels.
        pixels = pixels[:, np.newaxis]
    # Assigning into the float64 buffer converts dtype and broadcasts exactly
    # like the element-wise assignment image[r, c] = cb[label] would.
    image[...] = pixels.reshape(row, col, pixels.shape[1])
    return image

def show_scatter(a, bins=10):
    """Plot the colour-frequency histogram of a set of pixel colours.

    Opens two figures and then blocks in ``plt.show()``:

    * figure 1: a 3-D scatter over the ``bins x bins x bins`` colour
      histogram, with marker area proportional to each bin's count
      normalised by the largest count;
    * figure 2: the non-zero normalised counts sorted in descending order.

    Parameters
    ----------
    a : array-like, shape (n_samples, 3)
        Colours to histogram. Only values in ``[0, 1]`` are counted; the
        three columns are labelled red, green and blue on the plot axes.
    bins : int, optional
        Number of histogram bins per channel (default 10).
    """
    density, _ = np.histogramdd(a, bins=[bins] * 3, range=[(0, 1)] * 3)
    peak = density.max()
    if peak > 0:
        # Skip the division when nothing was counted, avoiding 0/0 -> NaN.
        density /= peak

    # indexing='ij' makes grid[k] vary along histogram axis k. The default
    # 'xy' indexing swaps the first two axes, which drew the red-channel
    # bins along the green axis and vice versa.
    grid = np.meshgrid(*([np.arange(bins)] * 3), indexing='ij')

    fig = plt.figure(1, facecolor='w')
    ax = fig.add_subplot(111, projection='3d')
    # No cmap= here: without c= a colormap has no effect on the markers.
    ax.scatter(grid[0].ravel(), grid[1].ravel(), grid[2].ravel(),
               s=100 * density.ravel(), marker='o', depthshade=True)
    ax.set_xlabel(u'红')
    ax.set_ylabel(u'绿')
    ax.set_zlabel(u'蓝')
    plt.title(u'图像颜色三维频数分布', fontsize=20)

    plt.figure(2, facecolor='w')
    # Non-empty bins only, largest first.
    den = np.sort(density[density > 0])[::-1]
    t = np.arange(len(den))
    plt.plot(t, den, 'r-', t, den, 'go', lw=2)
    plt.title(u'图像颜色频数分布', fontsize=18)
    plt.grid(True)

    plt.show()

if __name__ == '__main__':
    # Use a font with CJK glyphs for the Chinese titles/labels, and draw the
    # minus sign as ASCII '-' (presumably because SimHei lacks U+2212).
    matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
    matplotlib.rcParams['axes.unicode_minus'] = False
    # Number of clusters (colours); e.g. 2, 6 or 30.
    num_vq = 2
    # Image path may be given as the first command-line argument.
    path = sys.argv[1] if len(sys.argv) > 1 else 'C:/Users/zhen/.spyder-py3/images/Lena.png'
    # convert('RGB') turns palette/greyscale/RGBA/CMYK input into 8-bit RGB,
    # so the /255 scaling and the 3-column reshape are always valid. The
    # with-block closes the file. (np.float no longer exists in NumPy >= 1.24.)
    with Image.open(path) as im:
        image = np.asarray(im.convert('RGB'), dtype=np.float64) / 255
    image_v = image.reshape((-1, 3))
    # n_init pinned to the historical default; newer scikit-learn changed it.
    kmeans = KMeans(n_clusters=num_vq, init='k-means++', n_init=10)
    show_scatter(image_v)

    N = image_v.shape[0]  # total number of pixels
    # Fit the cluster centres on a random 70% sample (drawn with replacement)
    # to save time, then assign every pixel to its nearest centre.
    idx = np.random.randint(0, N, size=int(N * 0.7))
    image_sample = image_v[idx]
    kmeans.fit(image_sample)
    result = kmeans.predict(image_v)  # cluster label of each pixel
    print('聚类结果:\n', result)
    print('聚类中心:\n', kmeans.cluster_centers_)

    plt.figure(figsize=(15, 8), facecolor='w')
    plt.subplot(211)
    plt.axis('off')
    plt.title(u'原始图片', fontsize=18)
    plt.imshow(image)
    # Optionally save this panel with plt.savefig(...).

    plt.subplot(212)
    vq_image = restore_image(kmeans.cluster_centers_, result, image.shape)
    plt.axis('off')
    plt.title(u'聚类个数:%d' % num_vq, fontsize=20)
    plt.imshow(vq_image)
    # Optionally save the quantised image with plt.savefig(...).

    # pad must be passed by keyword; positional use is rejected by current Matplotlib.
    plt.tight_layout(pad=1.2)
    plt.show()