package edu.umd.cs.psl.evaluation.statistics;

import edu.umd.cs.psl.application.learning.weight.random.SliceRandOM;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.evaluation.statistics.filter.AtomFilter;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.util.database.Queries;
import java.util.Iterator;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/evaluation/statistics/ContinuousPredictionComparator.class */
public class ContinuousPredictionComparator implements ResultComparator {
    private static final Logger log = LoggerFactory.getLogger(ContinuousPredictionComparator.class);

    /** Database whose atoms are scored against {@link #baseline}. */
    private final Database result;

    /** Filter applied to the atoms of {@link #result} before scoring; defaults to no filtering. */
    private AtomFilter resultFilter = AtomFilter.NoFilter;

    /** Reference database; must be set via {@link #setBaseline(Database)} before {@link #compare}. */
    private Database baseline = null;

    /** Error metric used by {@link #compare}; defaults to {@link Metric#MSE}. */
    private Metric metric = Metric.MSE;

    /** Per-atom error metrics supported by this comparator. */
    public enum Metric {
        /** Mean squared error: per-atom error is (baseline - result)^2. */
        MSE,
        /** Mean absolute error: per-atom error is |baseline - result|. */
        MAE;

        /**
         * Returns a fresh array of all constants in declaration order.
         *
         * <p>Retained for backward compatibility with existing callers; equivalent to
         * {@link #values()}, which already returns a new copy on every call.
         *
         * @return a new array containing every {@code Metric} constant
         */
        public static Metric[] valuesCustom() {
            return values();
        }
    }

    /**
     * Creates a comparator that scores the atoms of {@code database}.
     *
     * @param database the database whose atoms are compared against the baseline
     */
    public ContinuousPredictionComparator(Database database) {
        this.result = database;
    }

    @Override // edu.umd.cs.psl.evaluation.statistics.ResultComparator
    public void setBaseline(Database database) {
        this.baseline = database;
    }

    @Override // edu.umd.cs.psl.evaluation.statistics.ResultComparator
    public void setResultFilter(AtomFilter atomFilter) {
        this.resultFilter = atomFilter;
    }

    /**
     * Selects the error metric used by {@link #compare(Predicate)}.
     *
     * @param metric the metric to use; must not be {@code null}
     * @throws NullPointerException if {@code metric} is {@code null}
     */
    public void setMetric(Metric metric) {
        this.metric = Objects.requireNonNull(metric, "metric");
    }

    /**
     * Computes the average per-atom error between the result and baseline databases.
     *
     * <p>For every atom of {@code predicate} in the result database that passes the result
     * filter, the same-argument atom is fetched from the baseline. The difference
     * {@code baseline - result} is scored with the configured {@link Metric}, and the mean
     * over all scored atoms is returned.
     *
     * @param predicate the predicate whose atoms are compared
     * @return the mean error, or {@code NaN} if no atoms passed the filter
     * @throws NullPointerException if no baseline has been set
     */
    public double compare(Predicate predicate) {
        Objects.requireNonNull(baseline, "Baseline database has not been set; call setBaseline first");

        double total = 0.0d;
        int count = 0;
        Iterator<GroundAtom> atoms = resultFilter.filter(Queries.getAllAtoms(result, predicate).iterator());
        while (atoms.hasNext()) {
            GroundAtom scored = atoms.next();
            // NOTE(review): assumes baseline.getAtom never returns null for these arguments — confirm.
            double reference = baseline.getAtom(predicate, scored.getArguments()).getValue();
            total += accumulate(reference - scored.getValue());
            count++;
        }

        if (count == 0) {
            // Keep the historical NaN result (0.0 / 0) but make the empty case visible.
            log.warn("No atoms of predicate {} passed the result filter; returning NaN", predicate);
        }
        return total / count;
    }

    /**
     * Converts a signed difference into a non-negative per-atom error under the current metric.
     *
     * @param difference {@code baseline value - result value} for a single atom
     * @return the squared difference for MSE, or the absolute difference for MAE
     */
    private double accumulate(double difference) {
        switch (metric) {
            case MSE:
                return difference * difference;
            case MAE:
                return Math.abs(difference);
            default:
                // Only reachable if a new Metric constant is added without updating this switch.
                throw new IllegalStateException("Unsupported metric: " + metric);
        }
    }
}
