1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

def to_float_image(img):
    """Return *img* as a float64 array with intensities in [0, 1].

    ``plt.imread`` returns PNG files as floats already in [0, 1] but other
    formats (e.g. JPEG) as integer arrays, so only integer input is rescaled,
    by the maximum value of its dtype (255 for uint8, 65535 for uint16).

    Args:
        img: Image array of shape (H, W) or (H, W, C).

    Returns:
        A new float64 array with the same shape as *img*.
    """
    if np.issubdtype(img.dtype, np.integer):
        return img.astype(np.float64) / np.iinfo(img.dtype).max
    return img.astype(np.float64)


def quantize_colors(img, k=8, random_state=0):
    """Reduce *img* to a palette of *k* colours with K-means clustering.

    Every pixel is treated as one sample (all channels, including alpha if
    present, are clustered together) and is replaced by the centroid of the
    cluster it is assigned to.

    Args:
        img: Float image of shape (H, W) or (H, W, C).
        k: Number of clusters, i.e. colours in the output palette.
        random_state: Seed passed to ``KMeans`` so results are reproducible.

    Returns:
        Array with the same shape as *img* holding the quantized pixels.
    """
    n_pixels = img.shape[0] * img.shape[1]
    # One row per pixel; -1 keeps whatever channel count the image has
    # (1 for grayscale) instead of hard-coding 4, which fails on RGB input.
    pixels = img.reshape(n_pixels, -1)
    # n_init is explicit because its default differs across scikit-learn versions.
    model = KMeans(n_clusters=k, n_init=10, random_state=random_state)
    labels = model.fit_predict(pixels)  # centroid index for each pixel
    recovered = model.cluster_centers_[labels]  # replace each pixel by its centroid
    return recovered.reshape(img.shape)


if __name__ == '__main__':
    original_img = plt.imread('color.png')
    print("Shape of original_img is:", original_img.shape)
    # Normalise without dividing already-normalised PNG data by 255 again;
    # an in-place `/=` on an integer array would also raise a casting error.
    img = to_float_image(original_img)
    # Model training: K is the number of colours in the output palette.
    K = 8
    X_recovered = quantize_colors(img, k=K)
    # Centroids are means of pixels in [0, 1]; clipping only absorbs
    # floating-point round-off so imshow does not warn about out-of-range data.
    X_recovered = np.clip(X_recovered, 0.0, 1.0)
    # Float RGB(A) data is displayed directly in [0, 1]; a 2-D grayscale
    # result needs an explicit gray colormap and fixed value range.
    if X_recovered.ndim == 2:
        plt.imshow(X_recovered, cmap='gray', vmin=0.0, vmax=1.0)
    else:
        plt.imshow(X_recovered)
    plt.axis('off')
    plt.show()