"""Reduce an image to K representative colours with K-Means and display it."""
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans


def main():
    """Load ``color.png``, quantise its colours with K-Means, and show it.

    Every pixel is treated as one sample whose features are its channel
    values. After clustering, each pixel is replaced by the centroid of the
    cluster it was assigned to, so the displayed image contains at most
    ``K`` distinct colours.
    """
    original_img = plt.imread('color.png')
    print("Shape of original_img is:", original_img.shape)

    # plt.imread yields floats in [0, 1] for PNG files and integer arrays for
    # other formats. Only integer data needs rescaling; an in-place `/=`
    # would fail on an integer array, so build a new float array instead.
    if np.issubdtype(original_img.dtype, np.integer):
        original_img = original_img / np.iinfo(original_img.dtype).max

    # Flatten to (n_pixels, n_channels). The channel count is read from the
    # image rather than hard-coded, so RGB, RGBA and grayscale inputs all work.
    n_channels = original_img.shape[2] if original_img.ndim == 3 else 1
    X_img = original_img.reshape(-1, n_channels)

    # Model training
    K = 8
    model = KMeans(n_clusters=K)
    model.fit(X_img)
    centroids = model.cluster_centers_

    # labels holds, for every pixel, the index of its assigned centroid
    labels = model.predict(X_img)

    # Replace every pixel by its centroid and restore the image layout
    X_recovered = centroids[labels].reshape(original_img.shape)

    # Values are already in [0, 1], the range imshow expects for float data.
    plt.imshow(X_recovered)
    plt.axis('off')
    plt.show()


if __name__ == '__main__':
    main()