package edu.umd.cs.psl.evaluation.statistics;

import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.evaluation.statistics.filter.AtomFilter;
import edu.umd.cs.psl.evaluation.statistics.filter.MaxValueFilter;
import edu.umd.cs.psl.model.argument.GroundTerm;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.util.database.Queries;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:edu/umd/cs/psl/evaluation/statistics/MulticlassPredictionComparator.class */
/**
 * Compares the atoms of a prediction database against a baseline (ground
 * truth) database for a single predicate and tallies the outcome in a
 * confusion matrix.
 *
 * <p>One argument position of the predicate is treated as the class label;
 * the remaining arguments identify the example being classified.
 */
public class MulticlassPredictionComparator implements ResultComparator {
    private final Database predDB;
    private Database truthDB = null;

    /**
     * Identifies one classified example: the arguments of a ground atom with
     * the label argument removed. Usable as a hash-map key; two examples are
     * equal when their term arrays are element-wise equal.
     */
    public class Example {
        private final GroundTerm[] terms;

        /**
         * @param groundTermArr the non-label arguments of the atom; the array
         *                      is stored as-is, not copied
         */
        public Example(GroundTerm[] groundTermArr) {
            this.terms = groundTermArr;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Example)) {
                return false;
            }
            return Arrays.equals(this.terms, ((Example) obj).terms);
        }

        /**
         * Hashes every term so that examples sharing only their first term do
         * not collide, and so that an empty term array (a predicate whose only
         * argument is the label) does not throw.
         */
        @Override
        public int hashCode() {
            return Arrays.hashCode(this.terms);
        }

        /**
         * Returns the terms joined by {@code ", "}. For a single-term example
         * this is exactly that term's string form.
         */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            for (int idx = 0; idx < this.terms.length; idx++) {
                if (idx > 0) {
                    sb.append(", ");
                }
                sb.append(this.terms[idx]);
            }
            return sb.toString();
        }
    }

    /**
     * @param database the database holding the predicted atoms
     */
    public MulticlassPredictionComparator(Database database) {
        this.predDB = database;
    }

    /**
     * Sets the database that holds the ground truth. Must be called before
     * {@link #compare(Predicate, Map, int)}.
     */
    @Override
    public void setBaseline(Database database) {
        this.truthDB = database;
    }

    /**
     * Intentionally a no-op: the supplied filter is ignored.
     * {@link #compare(Predicate, Map, int)} always builds its own
     * {@link MaxValueFilter}.
     */
    @Override
    public void setResultFilter(AtomFilter atomFilter) {
    }

    /**
     * Builds a confusion matrix of predicted versus true labels for the given
     * predicate. Rows are indexed by the true label's index and columns by the
     * predicted label's index.
     *
     * @param predicate the predicate whose atoms are compared
     * @param map       maps each label term to its row/column index in the
     *                  matrix; the matrix is {@code map.size()} square
     * @param i         position of the label argument within the predicate
     * @return statistics wrapping the resulting confusion matrix
     * @throws IllegalStateException if no baseline database has been set
     * @throws RuntimeException      if a predicted example has no ground truth,
     *                               or one of its labels has no index in
     *                               {@code map}
     */
    public PredictionStatistics compare(Predicate predicate, Map<GroundTerm, Integer> map, int i) {
        if (this.truthDB == null) {
            throw new IllegalStateException("No baseline database set; call setBaseline before compare.");
        }
        int size = map.size();
        int[][] confusion = new int[size][size];
        Map<Example, Integer> predicted = getAllMaxScoreAtoms(this.predDB, predicate, map, i);
        Map<Example, Integer> truth = getAllMaxScoreAtoms(this.truthDB, predicate, map, i);
        for (Map.Entry<Example, Integer> entry : predicted.entrySet()) {
            Example example = entry.getKey();
            if (!truth.containsKey(example)) {
                throw new RuntimeException("Missing ground truth for example " + example.toString());
            }
            Integer predictedLabel = entry.getValue();
            Integer trueLabel = truth.get(example);
            // A null index means the atom's label term was absent from the label map;
            // report it explicitly rather than failing on unboxing.
            if (predictedLabel == null || trueLabel == null) {
                throw new RuntimeException("No class index defined for a label of example " + example.toString());
            }
            confusion[trueLabel.intValue()][predictedLabel.intValue()]++;
        }
        return new MulticlassPredictionStatistics(new ConfusionMatrix(confusion));
    }

    /**
     * Collects, for every atom that passes a {@link MaxValueFilter} on the
     * label argument, the example it belongs to and the index of its label.
     * Presumably the filter keeps the highest-valued atom per example --
     * confirm against MaxValueFilter.
     *
     * @param database  database to read atoms from
     * @param predicate predicate whose atoms are read
     * @param map       label term to index mapping; labels missing from it are
     *                  stored with a {@code null} index
     * @param i         position of the label argument
     * @return map from example to label index
     * @throws RuntimeException if a filtered atom has value exactly 0
     */
    private Map<Example, Integer> getAllMaxScoreAtoms(Database database, Predicate predicate, Map<GroundTerm, Integer> map, int i) {
        Map<Example, Integer> labelsByExample = new HashMap<Example, Integer>();
        Iterator<GroundAtom> maxAtoms = new MaxValueFilter(predicate, i).filter(Queries.getAllAtoms(database, predicate).iterator());
        while (maxAtoms.hasNext()) {
            GroundAtom atom = maxAtoms.next();
            if (atom.getValue() == 0.0d) {
                throw new RuntimeException("Max value does not exist.");
            }
            GroundTerm[] arguments = atom.getArguments();
            // Copy every argument except the label into the example key.
            GroundTerm[] exampleTerms = new GroundTerm[arguments.length - 1];
            int next = 0;
            for (int pos = 0; pos < arguments.length; pos++) {
                if (pos != i) {
                    exampleTerms[next++] = arguments[pos];
                }
            }
            labelsByExample.put(new Example(exampleTerms), map.get(arguments[i]));
        }
        return labelsByExample;
    }
}
