package edu.umd.cs.psl.evaluation.statistics;

import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.evaluation.statistics.filter.AtomFilter;
import edu.umd.cs.psl.model.argument.GroundTerm;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.atom.ObservedAtom;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.util.database.Queries;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/**
 * Compares the discretized truth values of a predicate's atoms in a result {@link Database}
 * against those in a baseline {@link Database}.
 *
 * <p>Truth values are discretized with a threshold ({@link #DEFAULT_THRESHOLD} unless changed via
 * {@link #setThreshold(double)}). A value greater than or equal to the threshold counts as true;
 * anything below it counts as false. Each comparison does three things:
 * <ul>
 *   <li>tallies true/false positives and negatives;</li>
 *   <li>collects the atoms that were classified correctly;</li>
 *   <li>records the misclassified ones in an error map.</li>
 * </ul>
 *
 * <p>Error-map values follow one sign convention: {@code +1.0} marks a false positive (predicted
 * true, baseline false) and {@code -1.0} marks a false negative (predicted false or absent, baseline
 * true).
 *
 * <p>Every call to a {@code compare} method resets and then mutates the instance's counters and
 * collections without synchronization, so one instance should not be shared by concurrent
 * comparisons.
 */
public class DiscretePredictionComparator implements PredictionComparator {
    /** Threshold used when none is set explicitly; values at or above it are treated as true. */
    public static final double DEFAULT_THRESHOLD = 0.5d;

    /** Error-map value recorded for a false positive. */
    private static final double FALSE_POSITIVE_ERROR = 1.0d;
    /** Error-map value recorded for a false negative. */
    private static final double FALSE_NEGATIVE_ERROR = -1.0d;

    private final Database result;
    private Database baseline;
    private AtomFilter resultFilter = AtomFilter.NoFilter;
    private double threshold = DEFAULT_THRESHOLD;

    // Confusion-matrix counters populated by the most recent compare call.
    int tp;
    int fn;
    int tn;
    int fp;
    // Misclassified atoms mapped to FALSE_POSITIVE_ERROR or FALSE_NEGATIVE_ERROR.
    Map<GroundAtom, Double> errors;
    // Result atoms whose discretized value matched the baseline.
    Set<GroundAtom> correctAtoms;

    /**
     * Creates a comparator over the atoms stored in the given result database.
     *
     * @param database the database holding the predicted (result) atoms
     */
    public DiscretePredictionComparator(Database database) {
        this.result = database;
    }

    /**
     * Sets the discretization threshold: values greater than or equal to {@code d} are treated as
     * true, smaller values as false.
     *
     * @param d the new threshold
     */
    public void setThreshold(double d) {
        this.threshold = d;
    }

    /**
     * Sets the database holding the ground-truth (baseline) atoms. Must be called before any
     * {@code compare} method.
     *
     * @param database the baseline database
     */
    @Override
    public void setBaseline(Database database) {
        this.baseline = database;
    }

    /**
     * Sets the filter applied to atom iterators before comparison. It is applied both to atoms read
     * from the result database and, in {@link #compare(Predicate, int)}, to atoms read from the
     * baseline database.
     *
     * @param atomFilter the filter to apply
     */
    @Override
    public void setResultFilter(AtomFilter atomFilter) {
        this.resultFilter = atomFilter;
    }

    /**
     * Compares the atoms of {@code predicate} found in the result database against the baseline.
     *
     * <p>A result atom is counted only when its baseline counterpart is an {@link ObservedAtom}. As
     * a result, true negatives are limited to atoms present in the result database.
     *
     * @param predicate the predicate whose atoms are compared
     * @return statistics for this comparison
     * @throws IllegalStateException if no baseline database has been set
     */
    @Override
    public DiscretePredictionStatistics compare(Predicate predicate) {
        countResultDBStats(predicate);
        return buildStatistics();
    }

    /**
     * Compares the atoms of {@code predicate} as {@link #compare(Predicate)} does, then scans the
     * baseline for atoms that were not matched by any result atom.
     *
     * <p>Each such baseline atom whose value is at or above the threshold is treated as predicted
     * false. It is counted as a false negative and recorded in the error map under the atom returned
     * by the result database for the same predicate and arguments. True negatives are then derived
     * as {@code maxBaseAtoms - tp - fp - fn}. The derived count is negative if {@code maxBaseAtoms}
     * is smaller than the other three counts combined.
     *
     * @param predicate the predicate whose atoms are compared
     * @param maxBaseAtoms total number of atoms the statistics should cover (presumably every
     *     possible grounding of {@code predicate} — confirm with callers); used to derive the
     *     true-negative count
     * @return statistics for this comparison
     * @throws IllegalStateException if no baseline database has been set
     */
    @Override
    public DiscretePredictionStatistics compare(Predicate predicate, int maxBaseAtoms) {
        countResultDBStats(predicate);

        Iterator<GroundAtom> baselineAtoms =
                resultFilter.filter(Queries.getAllAtoms(baseline, predicate).iterator());
        while (baselineAtoms.hasNext()) {
            GroundAtom baselineAtom = baselineAtoms.next();
            boolean alreadyCounted =
                    errors.containsKey(baselineAtom) || correctAtoms.contains(baselineAtom);
            // An unmatched baseline atom is treated as predicted false, so it is an error only
            // when the baseline value discretizes to true.
            if (!alreadyCounted && baselineAtom.getValue() >= threshold) {
                GroundAtom resultAtom =
                        result.getAtom(baselineAtom.getPredicate(), baselineAtom.getArguments());
                // Same sign convention as the false negatives found in countResultDBStats.
                errors.put(resultAtom, FALSE_NEGATIVE_ERROR);
                fn++;
            }
        }

        tn = maxBaseAtoms - tp - fp - fn;
        return buildStatistics();
    }

    /**
     * Resets all counters and collections, then classifies every filtered atom of
     * {@code predicate} in the result database against its baseline counterpart. Result atoms
     * whose baseline counterpart is not an {@link ObservedAtom} are skipped.
     *
     * @param predicate the predicate whose atoms are classified
     * @throws IllegalStateException if no baseline database has been set
     */
    private void countResultDBStats(Predicate predicate) {
        if (baseline == null) {
            throw new IllegalStateException(
                    "Baseline database has not been set; call setBaseline before compare.");
        }

        tp = 0;
        fn = 0;
        tn = 0;
        fp = 0;
        errors = new HashMap<>();
        correctAtoms = new HashSet<>();

        Iterator<GroundAtom> resultAtoms =
                resultFilter.filter(Queries.getAllAtoms(result, predicate).iterator());
        while (resultAtoms.hasNext()) {
            GroundAtom resultAtom = resultAtoms.next();
            GroundAtom baselineAtom =
                    baseline.getAtom(resultAtom.getPredicate(), copyArguments(resultAtom));
            if (!(baselineAtom instanceof ObservedAtom)) {
                continue;
            }

            boolean predicted = resultAtom.getValue() >= threshold;
            boolean actual = baselineAtom.getValue() >= threshold;
            if (predicted == actual) {
                if (predicted) {
                    tp++;
                } else {
                    tn++;
                }
                correctAtoms.add(resultAtom);
            } else if (predicted) {
                fp++;
                errors.put(resultAtom, FALSE_POSITIVE_ERROR);
            } else {
                fn++;
                errors.put(resultAtom, FALSE_NEGATIVE_ERROR);
            }
        }
    }

    /**
     * Returns a new array holding the first {@code getArity()} arguments of {@code atom}, so the
     * baseline lookup never receives the atom's own argument array.
     */
    private static GroundTerm[] copyArguments(GroundAtom atom) {
        GroundTerm[] copy = new GroundTerm[atom.getArity()];
        System.arraycopy(atom.getArguments(), 0, copy, 0, copy.length);
        return copy;
    }

    /** Packages the current counters and atom collections into a statistics object. */
    private DiscretePredictionStatistics buildStatistics() {
        return new DiscretePredictionStatistics(tp, fp, tn, fn, threshold, errors, correctAtoms);
    }
}
