import math
import random
from . import basic_function
from tqdm import tqdm

class K_means:
    """Plain-Python k-means clustering over a list of points.

    Distances are computed with ``basic_function.o_dist``. Each point is a
    sequence of numbers. All points are assumed to share the dimension of
    ``x[0]``.
    """

    def __init__(self, x, center_num):
        """Pick ``center_num`` distinct points of ``x`` at random as initial centers.

        Args:
            x: Sequence of points (each a sequence of numbers).
            center_num: Number of clusters, between 1 and ``len(x)``.

        Raises:
            ValueError: If ``x`` is empty or ``center_num`` is out of range.
        """
        if len(x) == 0:
            raise ValueError("x must contain at least one point")
        if not 1 <= center_num <= len(x):
            raise ValueError(
                f"center_num must be between 1 and {len(x)}, got {center_num}"
            )
        indices = list(range(len(x)))
        random.shuffle(indices)
        # Copy each chosen point. train() updates centers in place, and
        # without a copy that would overwrite the caller's data in ``x``.
        self.center = [list(x[index]) for index in indices[:center_num]]
        self.center_num = center_num
        self.x = x
        self.center_dimension = len(x[0])

    def _assign(self):
        """Group the points of ``self.x`` by nearest center (lowest index wins ties)."""
        clusters = [[] for _ in range(self.center_num)]
        for point in self.x:
            dists = [basic_function.o_dist(point, center) for center in self.center]
            clusters[dists.index(min(dists))].append(point)
        return clusters

    def train(self, epochs=100):
        """Run up to ``epochs`` rounds of Lloyd's algorithm and return the centers.

        Each round assigns every point to its nearest center and moves each
        center to the mean of its assigned points. A center with no assigned
        points keeps its previous position. Training stops early once a round
        leaves every center unchanged, because further rounds would repeat it
        exactly.

        Args:
            epochs: Maximum number of assign/update rounds.

        Returns:
            ``self.center``: the list of ``center_num`` centers.
        """
        for _ in tqdm(range(epochs)):
            changed = False
            for center, cluster in zip(self.center, self._assign()):
                if not cluster:
                    continue  # empty cluster: keep the old center
                count = len(cluster)
                means = [sum(column) / count for column in zip(*cluster)]
                if means != center:
                    changed = True
                # Update in place so existing references to the centers stay valid.
                center[:] = means
            if not changed:
                break
        return self.center

    def predict(self):
        """Return the points of ``self.x`` grouped by their nearest current center.

        Returns:
            A list of ``center_num`` lists. List ``i`` holds the points closest
            to ``self.center[i]``; ties go to the lowest index.
        """
        return self._assign()


