
Jednym z głównych problemów pojawiających się w uczeniu maszynowym jest overfitting, czyli przeuczenie (nazywane też nadmiernym dopasowaniem). Przeciwnym zjawiskiem jest underfitting – niedouczenie (albo niedostateczne dopasowanie). Zobrazujmy sobie to na przykładzie podejścia do wykonywania jakiegoś zadania – możemy wtedy popełnić dwa błędy: albo zbytnio uprościmy problem i przedstawimy rozwiązanie zbyt proste, albo zbytnio skomplikujemy sobie zadanie i nasze rozwiązanie będzie za nadto złożone.
Przeszkody te wynikają z faktu, iż uczenie maszynowe jest ciągłym balansowaniem między optymalizacją a generalizacją. Poprzez optymalizację rozumiemy tu proces regulowania modelu w celu uzyskania możliwie najlepszej efektywności na danych treningowych. Generalizacja zaś odpowiada temu, jak dobrze model działa na danych, których nigdy wcześniej nie widział. Celem jest oczywiście uzyskanie jak najlepszej generalizacji, lecz tego nie możemy bezpośrednio kontrolować, gdyż możemy jedynie uczyć model na danych treningowych. Jeśli na tym etapie „przesadzimy”, czyli z jakiegoś powodu doprowadzimy do przetrenowania modelu, nie będzie on zdolny do generalizowania, a jedynie zapamięta, to co zobaczył podczas treningu. Często obserwuje się zależność: proste modele mają tendencję do underfittingu, a modele skomplikowane, czyli takie, które definiowane są przez dużą liczbę parametrów, do overfittingu. Celem uczenia maszynowego jest znalezienie złotego środka (zaznaczony pionową linią na poniższym wykresie), czyli takiego modelu, który nie będzie ani zbyt prosty, ani zbyt złożony, a który właściwie odda istotę dostarczonych mu danych.
Na początku treningu (lub – patrząc od strony stopnia skomplikowania modelu – w miarę zwiększania złożoności modelu) optymalizacja i generalizacja są skorelowane: im mniejsza strata na danych treningowych, tym mniejsza strata na danych testowych. Mamy wtedy do czynienia z niedouczeniem, a więc model ma pole do poprawy. Lecz po przejściu pewnej liczby iteracji na danych treningowych generalizacja przestaje się poprawiać, a metryki walidacyjne się pogarszają – model się przeucza. Zaczyna uczyć się wzorców, które są specyficzne dla danych treningowych, a w odniesieniu do niewidzianych danych niestety mogą wprowadzać w błąd lub być bez znaczenia.
Jeśli wykorzystamy do treningu wszystkie dostępne dane, to nie będziemy w stanie określić czy i kiedy pojawia się overfitting. Jednym ze sposobów określenia czy model jest przeuczony, jest wydzielenie z dostępnych danych mniejszego zbioru – zbioru testowego – danych, których model pod żadnym pozorem nie może zobaczyć w trakcie treningu (pozostałą część danych zwiemy rzecz jasna zbiorem treningowym). Taki zbiór testowy może dzięki temu wystąpić jako bezstronny „sędzia” efektywności modelu wytrenowanego na zbiorze treningowym, a w razie pozytywnej oceny możemy być pewni, że model dobrze generalizuje na niewidzianych wcześniej danych, a nie tylko zapamiętuje zbiór treningowy.
Spójrzmy jeszcze na ten problem na przykładzie regresji. Na poniższym rysunku czerwone kropki reprezentują dane treningowe, czarne linie to dopasowane krzywe regresji, a niebieskie kwadraty to dane testowe. Dla najprostszego modelu (po lewej) błąd uzyskany podczas treningu (umownie jako suma odległości od krzywej, czyli długości czerwonych linii) jest największy, dla modelu najbardziej złożonego (po prawej) jest bardzo niski, a dla modelu o średniej złożoności (w środku) niski. Jeśli jednak spojrzymy na błąd uzyskany na danych testowych (niebieskie linie) widać, że dla modelu o najmniejszym stopniu złożoności jest on nadal wysoki, dla najbardziej skomplikowanego modelu błąd jest tym razem wysoki, a dla modelu pośredniego wciąż jest niski. Zatem model po lewej jest niedouczony, a model po prawej przeuczony – jest on przysłowiową armatą na wróbla. Optymalny model zarówno na danych treningowych, jak i testowych uzyskał niskie wartości błędów.
Tak jak zbiór testowy nie może być wykorzystany do treningu, tak też nie może służyć do podejmowania decyzji odnośnie hiperparametrów modelu. W tym celu wydzielimy z dostępnych danych trzeci zbiór – zbiór walidacyjny. Wydzielenie dodatkowych zbiorów ze zbioru treningowego może sprawi, że model otrzyma za mało danych, na których może się uczyć. Taki model jest niezdolny do generalizowania na nowych danych, co jest jedną z przyczyn pojawiania się nadmiernego dopasowania. Gdy model ma zaś dostęp do nieograniczonej ilości danych, będzie przygotowany na każdy scenariusz i nigdy się nie przeuczy. Trzeba zatem rozważnie dzielić dostępne dane, tak by nasze działania mające ograniczyć overfitting, nie przyczyniały się do jego powstawania.
Kiedy pozyskanie większej ilości danych nie jest możliwe, najlepszym rozwiązaniem jest wtedy ograniczenie złożoności modelu. W takim wypadku będzie on zmuszony do skupienia się na najbardziej widocznych wzorach, które będą miały największe szanse na dobre generalizowanie. Taka praktyka zwana jest regularyzacją. Do oceny modelu brana będzie wtedy nie tylko jego efektywność, ale i złożoność, tak by finalnie model dawał jak najlepsze rezultaty i był możliwie najprostszy.
Omówmy pokrótce ten sposób ograniczania overfittingu. W skrócie chodzi o mierzenie wydajności i złożoności dwiema różnymi funkcjami błędu i dodanie ich do siebie w celu otrzymania jednej funkcji błędu, której minimalizacja w czasie treningu zapewni dobrą wydajność, jak i odpowiednią złożoność modelu. Złożoność modelu możemy określić opierając się na zasadzie, że im więcej i im wyższe wartości współczynników (wag), tym model może być bardziej złożony. I tu – bez wchodzenia w matematyczne szczegóły – pojawiają się dwie miary określające stopień skomplikowania modelu, zwane normami L1 i L2, czyli odpowiednio:
Gdy trenujemy model regresji z użyciem regularyzacji z normą L1, można się spotkać z nazwą: regresja lasso. W przypadku normy L2 mówi się o regresji ridge. Funkcje błędu wyglądają następująco:
Jako że proces trenowania modelu obejmuje minimalizację funkcji błędu najbardziej jak to możliwe, model trenowany z wykorzystaniem regularyzacji powinien mieć wysoką wydajność i niską złożoność. Jednak pojawia się tu swoiste przeciąganie liny – chcąc sprawić, żeby model miał lepszą wydajność, możemy bardziej go skomplikować, i odwrotnie – chcąc zmniejszyć złożoność modelu możemy doprowadzić do sytuacji, że będzie on miał gorszą efektywność. Na szczęście większość algorytmów posiada hiperparametr związany z regularyzacją – lambda, którego zadaniem jest określenie czy model w czasie treningu powinien bardziej skupić się na efektywności, czy niskim stopniu skomplikowania. Im wyższa wartość tego parametru, tym większy wpływ regularyzacji na trening modelu. Jeśli zaś ustawimy go na zero, model zostanie wytrenowany bez regularyzacji.
A zatem, prowadząc trening z regularyzacją, finalny model będzie mniej złożony: w przypadku regularyzacji L1 (lasso) model będzie miał mniej współczynników – niektóre zostaną wyzerowane; w przypadku regularyzacji L2 (ridge) wartości współczynników zostaną zmniejszone, lecz nigdy wyzerowane. Podsumowując: jeśli mamy wiele atrybutów (cech) i chcemy pozbyć się części z nich, regularyzacja L1 będzie idealna. Jeśli mamy tylko kilka atrybutów i wszystkie wydają się konieczne, wtedy regularyzacja L2 jest tym, czego potrzebujemy.
Źródła
Luis Serrano – Grokking Machine Learning, Manning Publications
François Chollet – Deep Learning with Python
https://www.v7labs.com/blog/overfitting
Digital Fingerprints S.A. ul. Gliwicka 2, 40-079 Katowice. KRS: 0000543443, Sąd Rejonowy Katowice-Wschód, VIII Wydział Gospodarczy, Kapitał zakładowy: 4 528 828,76 zł – opłacony w całości, NIP: 525-260-93-29
Biuro Informacji Kredytowej S.A., ul. Zygmunta Modzelewskiego 77a, 02-679 Warszawa. Numer KRS: 0000110015, Sąd Rejonowy m.st. Warszawy, XIII Wydział Gospodarczy, kapitał zakładowy 15.550.000 zł opłacony w całości, NIP: 951-177-86-33, REGON: 012845863.
Biuro Informacji Gospodarczej InfoMonitor S.A., ul. Zygmunta Modzelewskiego 77a, 02-679 Warszawa. Numer KRS: 0000201192, Sąd Rejonowy m.st. Warszawy, XIII Wydział Gospodarczy, kapitał zakładowy 7.105.000 zł opłacony w całości, NIP: 526-274-43-07, REGON: 015625240.