FSGeek

Algorytm kmeans bez bibliotek. Zobacz jakie to proste!

By Aleksander Patschek on Oct 20, 2020

Algorytm k-means jest jednym z podstawowych algorytmów uczenia nienadzorowanego. Należy do algorytmów analizy skupień (inaczej grupowania, klasteryzacji) i pozwala na podzielenie elementów na określoną ilość klas ze względu na podobieństwo. Jest on na tyle prosty, że warto go własnoręcznie zaimplementować.

Społeczność na Facebooku

Jeśli jesteś zainteresowany tematem uczenia maszynowego, to zapraszam do grupy na fb. Aby ją odwiedzić kliknij tutaj

Centroid

W przypadku tego algorytmu będziemy się spotykali z określeniem centroid. Będzie on oznaczał punkt, który jest reprezentantem danej grupy, czyli będzie środkiem tej grupy. Oznacza to tyle, że będziemy mieć tyle centroidów ile grup oraz centroid będzie decydował, do jakiej grupy będzie należał punkt. Również ten punkt nie będzie stały, tylko będzie się zmieniał wraz z postępem prac algorytmu tak, aby jak najlepiej reprezentować elementy, które się w nim znajdują. W najprostszym podejściu będzie to średnia wszystkich elementów należących do danej grupy.

Kroki algorytmu

Zanim przejdziemy do tego, co lubimy najbardziej, czyli kodu, warto wiedzieć co się dzieje wewnątrz algorytmu. Możemy opisać to w następujących krokach:

  1. Inicjalizacja struktur danych
  2. Wybór początkowych centroidów
  3. Wyliczenie metryki dla każdego elementu
  4. Przydzielenie elementów do poszczególnych grup
  5. Wyliczenie nowych centroidów
  6. Jeśli zrobiono odpowiednią liczbę iteracji to koniec, inaczej przejdź do punktu 3

Jak widać, algorytm nie jest trudny, więc warto samemu to zaimplementować. Przejdę teraz każdy krok po kolei, by bardziej szczegółowo to omówić i pokazać wraz z kodem.

Inicjalizacja danych

Dla tego algorytmu nie trzeba tworzyć bardzo skomplikowanych struktur. Potrzebujemy obiekt, w którym będziemy przechowywać dane podzielone na grupy np.:

const group = {
	'0':[],
	'1':[]
} 

Najważniejsza rzecz to wyliczenie początkowych centroidów. Mamy tutaj kilka opcji:

  • dać użytkownikowi możliwość podania centroidów
  • jakoś szybko wyliczyć środki na podstawie danych
  • wybranie losowych danych jako nasze początkowe centroidy

Ja dla podstawowej wersji postanowiłem wybrać losowe punkty. Należy pamiętać, że taka inicjalizacja może dawać różne wartości końcowe, ze względu na wybrane punkty. Przykładowy kod

let centroids = new Array(numberOfCentroids).fill(0).map(()=>{
        return data[Math.floor(Math.random()*data.length)]
    })

Wyliczanie metryki

Następny ważny punkt to sposób wyliczania metryki. Będzie miała ona wpływ na przydzielanie elementów do konkretnych grup. Jest to funkcja, która zwróci nam miarę odległości pomiędzy dwoma elementami zbioru. Ja w tym przykładzie korzystam z metryki SED (squared Euclidean distance) czyli metryka euklidesowa podniesiona do kwadratu.

const calculateSquareError = (x, r)=>{
    let sum = 0;
    for(let i=0; i<x.length; i++){
        sum += Math.pow(x[i] - r[i],2);
    }
    return sum;
}

Aktualizacja centroidów

Za każdym razem, gdy rozdzielimy dane do poszczególnych grup, musimy zaktualizować nasze centroidy. Nowe centroidy będą średnią arytmetyczną wszystkich danych w danej grupie. Dzięki temu po każdej iteracji będziemy przesuwali się z losowych punktów do centroidów, które faktycznie reprezentują elementy z grupy. Przykład aktualizacji centroidów:

const recalculateCentroids = (centroids, groups)=> {
    return Object.values(groups).map((group, index)=>{
        const sum = group.reduce((value, point)=>{
            return {
                x: value.x + point[0],
                y: value.y + point[1],
            }
        },{x:0, y:0})

        return [sum.x/group.length, sum.y/group.length]
    })
}

Głowna pętla algorytmu

Została główna pętla algorytmu.

for(let i=0; i<numberOfIterations; i++) {
        groups = new Array(numberOfCentroids).fill(0).reduce((res, _, index) => {
            return {
                ...res,
                [index]: []
            }
        }, {});

        for (element of data) {
            const {group} = centroids.reduce((currGroup, value, index) => {
                const error = calculateSquareError(element, value);
                if (error < currGroup.error) {
                    return {
                        group: index,
                        error: error,
                    }
                }
                return currGroup;

            }, {group: 0, error: Infinity})

            groups[group] = [
                ...groups[group],
                element
            ]
        }

        centroids = recalculateCentroids(centroids, groups);
    }

Na początku każdej pętli resetujemy przydział elementów do grup - dzięki temu wiemy, że nie mamy duplikatów. Następnie w pętli for przechodzimy po każdym elemencie i sprawdzamy, do jakiej grupy powinien należeć. To, do jakiej grupy będzie należeć element, zależy od zdefiniowanej przez nas metryki. Im mniejsza odległość do centroidu, tym mniejszy błąd. Przyporządkowujemy element do najbliższego centroidu. Po przejściu przez wszystkie elementy musimy zaktualizować centroidy i możemy albo rozpocząć kolejną iterację, albo zakończyć działanie algorytmu.

Pełny kod

Jak ulepszyć?

Tak jak wspomniałem, wybór punktów początkowych algorytmu może mieć wpływ na wyniki. Można więc ulepszyć nasz algorytm wyboru punktów początkowych, dając bardziej zaawansowane algorytmy, które spróbują oszacować najlepsze punkty. Należy przy tym pamiętać, że nie możemy robić zbyt zaawansowanych wyliczeń, ponieważ spowolni to znacząco pracę głównej części algorytmu. Inna opcja to danie użytkownikowi możliwości podania własnych punktów i wyliczanie grup na ich podstawie. Kolejna rzecz, którą można zmienić to sposób wyliczania metryki pomiędzy dwoma punktami. Na razie mamy też określoną liczbę iteracji a algorytm może znaleźć odpowiednie punkty wcześniej. Można więc dołożyć bardziej zaawansowane wykrywanie momentu zakończenia algorytmu.

Polityka prywatności
© Copyright 2024 by Blog FSGeek
Ikony pochodzą z Icons8