import Utils from "@/computation/utils.js";

/**
 * k-means (Lloyd's algorithm) clustering of 2-D points.
 *
 * Points are objects with numeric `x` and `y` properties; centroids are
 * computed as the per-coordinate mean of their cluster. Point-to-centroid
 * distances are delegated to `Utils.calculateDistance` (presumably Euclidean —
 * confirm in utils.js).
 */
export class ClusterMaker {
    /**
     * @param {number} k - Number of clusters. Must be a positive integer that
     *   does not exceed the number of points later passed to `fit`.
     * @param {number} [maxIterations=100] - Upper bound on assign/update rounds.
     * @param {number} [tolerance=1e-6] - Iteration stops once every centroid
     *   moves less than this distance between two rounds.
     */
    constructor(k, maxIterations = 100, tolerance = 1e-6) {
        this.k = k;
        this.maxIterations = maxIterations;
        this.tolerance = tolerance;
    }

    /**
     * Clusters `data` into `this.k` groups.
     *
     * @param {Array<{x: number, y: number}>} data - Points to cluster.
     * @returns {{centroids: Array<{x: number, y: number}>,
     *            clusters: Array<Array<{x: number, y: number}>>}}
     *   Final centroids and, for each centroid index, the points assigned to it.
     * @throws {RangeError} If `k` is not a positive integer or exceeds `data.length`.
     */
    fit(data) {
        let centroids = this.initializeCentroids(data);

        for (let iteration = 0; iteration < this.maxIterations; iteration++) {
            const clusters = this.assignToClusters(data, centroids);
            // Pass the current centroids so an empty cluster keeps its position
            // instead of collapsing to NaN (0 / 0), which would poison every
            // subsequent distance comparison.
            const newCentroids = this.calculateCentroids(data, clusters, centroids);

            if (this.hasConverged(centroids, newCentroids)) {
                break;
            }

            centroids = newCentroids;
        }

        return {
            centroids,
            clusters: this.assignToClusters(data, centroids),
        };
    }

    /**
     * Picks `this.k` distinct points from `data` uniformly at random to serve
     * as initial centroids. `data` itself is not modified.
     *
     * @param {Array<{x: number, y: number}>} data
     * @returns {Array<{x: number, y: number}>} The chosen points (same object references).
     * @throws {RangeError} If `k` is not a positive integer or exceeds `data.length`.
     */
    initializeCentroids(data) {
        if (!Number.isInteger(this.k) || this.k < 1) {
            throw new RangeError(`k must be a positive integer, got ${this.k}`);
        }
        if (this.k > data.length) {
            throw new RangeError(
                `k (${this.k}) cannot exceed the number of data points (${data.length})`
            );
        }

        // Work on a copy; removing each pick guarantees distinct selections.
        const pool = [...data];
        const centroids = [];
        for (let i = 0; i < this.k; i++) {
            const randomIndex = Math.floor(Math.random() * pool.length);
            centroids.push(pool.splice(randomIndex, 1)[0]);
        }
        return centroids;
    }

    /**
     * Assigns each point to its nearest centroid.
     *
     * Ties go to the lowest centroid index.
     *
     * @param {Array<{x: number, y: number}>} data
     * @param {Array<{x: number, y: number}>} centroids - Expected to have `this.k` entries.
     * @returns {Array<Array<{x: number, y: number}>>} `this.k` buckets of points,
     *   indexed like `centroids`.
     */
    assignToClusters(data, centroids) {
        const clusters = Array.from({ length: this.k }, () => []);

        for (const point of data) {
            // Single-pass argmin. Unlike indexOf(Math.min(...)), this can never
            // yield -1 when a distance is NaN.
            let nearestIndex = 0;
            let nearestDistance = Infinity;
            for (let i = 0; i < centroids.length; i++) {
                const distance = Utils.calculateDistance(point, centroids[i]);
                if (distance < nearestDistance) {
                    nearestDistance = distance;
                    nearestIndex = i;
                }
            }
            clusters[nearestIndex].push(point);
        }

        return clusters;
    }

    /**
     * Computes each cluster's mean point.
     *
     * @param {Array<{x: number, y: number}>} data - Unused; kept for interface compatibility.
     * @param {Array<Array<{x: number, y: number}>>} clusters
     * @param {Array<{x: number, y: number}>} [previousCentroids=[]] - Centroids from
     *   the previous round. An empty cluster keeps (a copy of) its previous centroid.
     *   Without one, the result is `{x: NaN, y: NaN}`, as it was before this
     *   parameter existed.
     * @returns {Array<{x: number, y: number}>} One centroid per cluster.
     */
    calculateCentroids(data, clusters, previousCentroids = []) {
        return clusters.map((cluster, index) => {
            const previous = previousCentroids[index];
            if (cluster.length === 0 && previous !== undefined) {
                return { x: previous.x, y: previous.y };
            }

            let sumX = 0;
            let sumY = 0;
            for (const point of cluster) {
                sumX += point.x;
                sumY += point.y;
            }
            return {
                x: sumX / cluster.length,
                y: sumY / cluster.length,
            };
        });
    }

    /**
     * @param {Array<{x: number, y: number}>} centroids - Centroids before the update.
     * @param {Array<{x: number, y: number}>} newCentroids - Centroids after the update.
     * @returns {boolean} True when every centroid moved less than `this.tolerance`.
     */
    hasConverged(centroids, newCentroids) {
        return centroids.every(
            (centroid, i) => Utils.calculateDistance(centroid, newCentroids[i]) < this.tolerance
        );
    }
}