본문 바로가기

컴퓨터공학/mini project

[python] K means clustering을 구현해보자 [작성중]

이번 글에서는 K means clustering을 python을 이용하여 구현해 보도록 하겠습니다.

index

 

 

#K means clustering 
#1. K를 정한다 2. k개의 초기 중심을 임의로 정한다.
#3. 각 데이터 포인트~ k개의 중심거리를 계산하여 가장 가까운 중심점이 속한 클러스터로 이동한다.
#4. 2,3단계를 반복하고 더이상 클러스터에 변화가 없으면 종료한다.
import numpy as np
import matplotlib.pyplot as plt

plt.style.use('seaborn')

class DatasetGnerator:
    def __init__(self, n_cluster, n_cluster_data):
        self.n_cluster = n_cluster
        self.n_cluster_data = n_cluster_data
        #
        self.dataset = list()
        self._make_dataset()
    def _make_dataset(self):
        for n_cluster_idx in range(self.n_cluster):
            center = np.random.uniform(-10,10,(2,))
            cluster_data = np.random.normal(loc = center, scale =1,
                                                size=(self.n_cluster_data,2))
            self.dataset.append(cluster_data)
        #1.1 print(cluster_data.shape)
        self.dataset = np.vstack(self.dataset)
        #1.2 print(self.dataset.shape)

    def vis_ds(self):
        fig, ax = plt.subplots(figsize =(10,10))
        ax.scatter(self.dataset[:,0], self.dataset[:,1])
        
        ax.set_xlim([-12,12])
        ax.set_ylim([-12,12])
        return ax

    def get_ds(self): return self.dataset


n_cluster, n_cluster_data =5, 100

dataset_generator = DatasetGnerator(n_cluster,n_cluster_data)
dataset = dataset_generator.get_ds()
ax = dataset_generator.vis_ds()

#print(dataset.shape) 500,2 중에 랜덤한 k개 고르기

K = 8

## 아래 방법은 사용하면 안됨 why? 중복될 수 있기 때문 
# indices = np.random.randint(0,n_cluster_data* n_cluster,(K,))  

# centroids = dataset[indices]
# print(indices)

# ax.scatter(centroids[:,0], centroids[:,1],
#             color ='red',s=300)
# plt.show()
indices = np.arange(n_cluster*n_cluster_data) #인덱스 499까지 깔고 
np.random.shuffle(indices) #셔플링한다음
indices = indices[:K]  # 뽑아줌 

centroids = dataset[indices]

ax.scatter(centroids[:,0], centroids[:,1],
             color ='red',s=300)
plt.show()