Zuverlässigere Vorhersagen eines Bildklassifikators durch Nutzung von Domänenwissen

Lead Machine Learning Engineer @ Chrono24

Bei der Arbeit mit Klassifikationsmodellen stellt sich in der Praxis oft die Frage, inwieweit wir den zugeordneten Kategorien eines Modells vertrauen können. Wenn wir Datensätze für das Training eines Modells erzeugen, werden oftmals nicht vollständig annotierte Daten außen vor gelassen und strukturelle Informationen, die über einzelne Datensätze hinausgehen, nicht verwendet.

In diesem Artikel möchten wir Euch aufzeigen, wie man ein Klassifikationssystem mit Informationen aus der hierarchischen Struktur der Daten verbessern kann. Auch frei verfügbare Datensätze wie z. B. CIFAR-10 und Image-Net besitzen eine solche Struktur. Das Ziel ist es, ein zusätzliches Plausibilitätsmodell zu trainieren, dass diese Strukturinformationen verwertet, um eine robustere Klassifikation zu ermöglichen. Bei Chrono24 wird dies beispielsweise zur Verbesserung des Watch Scanners angewandt.

Hierarchie in den Daten

Die Erforschung von neuen Modellarchitekturen basiert fast ausschließlich auf Datensätzen, deren Struktur flach ist. Das heißt, es gibt nur eine Ebene der Klassen und deren Hierarchie findet keine Verwendung. Beispielsweise werden bei Benchmarks auf dem ImageNet Datensatz üblicherweise 1000 verschiedene Kategorien (Katze, Hund, Kino, Bett…) verwendet, die durch das Modell unterschieden werden sollen. ImageNet ist jedoch wesentlich detaillierter und bietet einen ganzen Kategoriebaum mit Oberkategorien (Tier, Gebäude, Möbel…) auf verschiedenen Ebenen. Es gibt einige Modelle, welche die Hierarchie der Kategorien ausnutzen, wie HD-CNNs und Tree-CNNs, aber dies ist nur ein sehr kleiner Bereich der Forschung.

In der Praxis haben wir es oft mit Anwendungen zu tun, bei denen die Anzahl an Trainingsdaten anfangs gering ist, aber mit der Zeit wächst. Die Anzahl der Bilder, die wir pro Klasse zum Training haben, variiert und die Genauigkeit für einige Kategorien ist schlechter als für andere. Im Produktivbetrieb eines Klassifikationsmodells kann es passieren, dass ganz neue und bisher unbekannte Klassen auftreten, die zwar im originalen Trainingsdatensatz nicht vorhanden waren, aber zumindest durch eine Überkategorie abgedeckt wurden. Durch Reduktion der Kategorien auf die unterste Ebene kann der Klassifikator diese Information aber nicht nur Verfügung stellen. Mit anderen Worten: Das Modell liefert falsche Vorhersagen für diese Daten.

Confidence Scores sind oft unzuverlässig

Neuronale Netze, die als Klassifikatoren dienen, liefern meist eine Wahrscheinlichkeit pro Klasse als Ergebnis. Dies geschieht durch die Anwendung einer Soft-Max Funktion. Diese Wahrscheinlichkeit wird oft als Sicherheit oder „Confidence“ interpretiert. Allerdings ist das nicht immer korrekt. Neuronale Netze tendieren dazu sich selbst zu überschätzen und hohe Scores vorherzusagen, insbesondere in Bereichen, in denen die Datenbasis klein ist (Mehr dazu hier und hier).

Wenn ein Modell jedoch mit neuen Daten konfrontiert wird, liefert es uns häufig hohe Wahrscheinlichkeitswerte, auch wenn die Vorhersage völlig daneben liegt. Dieses Problem ist noch schwieriger zu lösen, wenn es sich um eine „fine-grained“ Bildklassifikation handelt , bei der kleine Details in Bildern einen Unterschied machen können (z. B. die Klassifikation von Pflanzen). Auch der Watch Scanner in der Chrono24 App gehört in diese Kategorie. Im zugrundeliegenden Datensatz finden sich über 15.000 Klassen, die sich oftmals nur in marginalen Details wie z. B. Abweichungen im Ziffernblatt oder der Existenz einer Datumsanzeige unterscheiden. Diese Beispiele sind generell schwer zu differenzieren, auch Menschen benötigen oft einen zweiten Blick um Unterschiede zu erkennen.

Trainiert man nun ein neuronales Netz darauf, den Uhrenbildern Referenznummern zuzuordnen (also auch auf der untersten Ebene der Klassenstruktur), kann dieses erstmal nicht die Marke, den Uhrentyp (z. B. Taucheruhr, Fliegeruhr) oder das Modell (Daytona, Submariner) bestimmen. Diese Informationen ergeben sich lediglich dann, wenn die richtige Referenznummer erkannt wird, welche die exakte Uhr relativ eindeutig identifiziert. War die Uhr mit dieser Referenz allerdings nicht in den Trainingsdaten enthalten, wird das Modell die nächstähnliche Uhr vorschlagen, unabhängig davon ob Marke oder Modell dazu passen. Das kann für den Nutzer zu schwer nachvollziehbaren Falschzuordnungen führen.

Plausibilitätsmodell

Die Idee der Plausibilitätsprüfung ist es, ein zweites Modell zu trainieren, dass auf einer anderen Hierarchieebene klassifiziert (z. B. Erkennung des Uhrenmodells). Ziel ist es, Rückschlüsse auf die Güte der Erkennung zu treffen, indem Informationen aus der Struktur der Daten genutzt werden.

Die Grundannahme ist, dass ein Plausibilitätsmodell von der größeren Menge an Trainingsdaten pro Klasse profitiert und mehr generalisiert, wodurch ein Overfitting verhindert wird. Außerdem können Daten für das Training verwendet werden, die unvollständig annotiert sind. Das ist beispielweise der Fall, wenn zu einem Bild nur das Uhrenmodell, nicht aber die Referenznummer bekannt ist. Dieser Fall kommt vor allem dann vor, wenn Benutzerfeedback verwendet wird, um die Trainingsdaten zu verbessern und ein Nutzer bei einem Feedback nicht alle korrekten Daten zurückmeldet.

Trainieren eines Modells zur Plausibilitätsprüfung

Das eigentliche Modell wird auf den Ziel-Labels „Referenznummer“ trainiert, da diese auf der höchsten Granularität liegen (Blätter in der Hierarchie). Das Plausibilitätsmodell wird auf einem Label höherer Ebene trainiert (Elternknoten oder verwandte Knoten in der Hierarchie). Am Beispiel des Watch Scanners wird das Modell auf die Erkennung von Referenzen trainiert und das Plausibilitätsmodell auf die Erkennung von Uhrenmodellen.

Am Rande bemerkt: Man könnte meinen, dass es viel einfacher wäre, die Hierarchie direkt mit mehreren Modellen zu modellieren, z. B. eines für die Marke und nach der Erkennung der Marke ein weiteres für das Modell und zuletzt ein weiteres für die Referenznummer. Und ja, das ist möglich, wird so bei kleineren Datensätzen auch häufig getan. Allerdings schaffen wir damit ein großes Problem, denn wir müssen für jeden Knoten außer der Blattknoten im Graph ein Modell trainieren. Das schlicht nicht praktikabel, wenn man keine unbegrenzten Ressourcen hat.

Testen auf Plausibilität

Wenn wir diese beiden Modelle trainiert haben und die Klassifikation ausführen, bekommen wir Wahrscheinlichkeitswerte, die für jede im Modell enthaltene Klasse (Referenznummer oder Uhrenmodell) einen Wert zwischen 0 und 100% liefert, wobei die Summe immer 100% ergibt (Softmax Funktion). Die Klasse mit dem höchsten Wert ist in der Regel die korrekte Vorhersage.

Aufgrund der reduzierten Anzahl von Klassen und der höheren Anzahl von Bildern pro Klasse hat unser Plausibilitätsmodell in der Regel eine viel bessere Genauigkeit als das Referenznummernmodell. Betrachten wir die besten Vorhersagen für beide Modelle, bekommen wir jeweils einen Wahrscheinlichkeitswert für eine Referenznummer und für das Uhrenmodell:

  • Wenn beide Werte hoch sind und die Referenz (z. B. Ref. 116506) ein Label vorhersagt, das mit der Plausibilitätsvorhersage (z. B. Daytona) übereinstimmt, können wir davon ausgehen, dass die Vorhersage korrekt ist. Bei der Referenznummer handelt es sich um eine Rolex Daytona. Beispielvorhersage: 98% Ref. 116506 / 89% Daytona.
  • Wenn beide Werte niedrig sind, wissen wir, dass das Ergebnis nicht zuverlässig ist. Meist stimmen hier Referenz und Modell nicht überein. Beispielvorhersage: 5% Ref. 116506 / 7% Speedmaster
  • Wenn die Wahrscheinlichkeit für die Referenznummer niedrig ist, aber das Plausibilitätsmodell eine hohe Wahrscheinlichkeit vorhersagt, haben wir möglicherweise eine fehlende Referenz in unseren Trainingsdaten gefunden (z. B. Ref. 116520). Beispielvorhersage: 98% Daytona / 25% Ref. 116506
  • Wenn die Vorhersage des Referenznummermodells hoch ist und das Plausibilitätsmodell eine niedrige Wahrscheinlichkeit angibt, wird es schwieriger. Oftmals handelt es sich um Bilder, die mit der Domäne nichts zu tun haben (keine Uhr abgebildet), aber das Bild ähnliche Merkmale aufweist oder um Datenfehler in den Trainingsdaten. Diese Konstellation ist besonders wichtig, um zukünftig den Datensatz zu verbessern. Beispielvorhersage: 18% Daytona / 98% Ref. 116506

Und schließlich der interessanteste Fall: Nehmen wir an, wir haben eine sehr hohe Wahrscheinlichkeit für unser Plausibilitätsmodell und eine mittlere Wahrscheinlichkeit für unser Referenznummernmodell, aber die Vorhersagen sind nicht plausibel: Beispiel: 98% Daytona / 47% Ref. 311.30.42.30.01.006. Bei der zugeordneten Referenznummer handelt es sich um eine Uhr der Marke Omega. Das widerspricht der Vorhersage des Plausibilitätsmodells, welches mit großer Sicherheit angibt eine Rolex Daytona erkannt zu haben.

Schauen wir uns allerdings die zweitbeste erkannte Referenznummer an (46% für Ref.6239) sehen wir, dass es sich dabei tatsächlich um eine ältere Rolex Daytona handelt. Obwohl die Vorhersage nur die zweitbeste ist, ist es hier plausibel anzunehmen, dass diese korrekt ist, da diese mit dem Plausibilitätsmodell übereinstimmt. Im Allgemeinen können wir also eine Logik (oder gar ein ML-Modell) anwenden, welche die in den Top-N-Vorhersagen des Referenznummermodells sucht, die zu einem gegebenen Modell passen, ohne dass wir sehr viele CNNs trainieren müssen.

Fazit

Die Verwendung eines Plausibilitätsmodells, bei Domänen in denen Hierarchieinformationen in den Daten verfügbar sind, kann Vorhersagen zuverlässiger machen. Wenn das System widersprüchliche Scores liefert, können wir davon ausgehen, dass die Vorhersage falsch oder zumindest ungenau ist. Außerdem können wir mehr Einblicke in mögliche Fehler des Modells erhalten, fehlende Labels finden oder sogar die Vorhersagen korrigieren, indem wir nicht nur die Top-Vorhersage, sondern auch die Top-N-Vorhersagen berücksichtigen. Der Nachteil ist, dass wir ein zusätzliches Modell benötigen und mehr Aufwand betreiben müssen um die Validierung des Gesamtmodells zu gewährleisten.