package edu.umd.cs.psl.evaluation.statistics;

import com.google.common.base.Predicate;
import com.google.common.collect.Iterables;
import edu.umd.cs.psl.model.atom.GroundAtom;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/umd/cs/psl/evaluation/statistics/DiscretePredictionStatistics.class */
/**
 * Confusion-matrix statistics for a discretized (binary) prediction.
 *
 * <p>Stores the true/false positive/negative counts, a per-atom error map and the set of atoms
 * that were predicted correctly, and derives precision, recall, F1 and accuracy from the counts.
 */
public class DiscretePredictionStatistics implements PredictionStatistics {
    /** Number of true positives. */
    private final int tp;
    /** Number of false positives. */
    private final int fp;
    /** Number of false negatives. */
    private final int fn;
    /** Number of true negatives. */
    private final int tn;
    /** Threshold supplied at construction (presumably the discretization cutoff — confirm with callers). */
    private final double threshold;
    /**
     * Per-atom error values. Entries with a value {@code > 0} are reported by
     * {@link #getFalsePositives()}, entries with a value {@code < 0} by {@link #getFalseNegatives()}.
     */
    private final Map<GroundAtom, Double> errors;
    /** Atoms the caller reported as correctly predicted. */
    private final Set<GroundAtom> correctAtoms;

    /** The class with respect to which precision, recall and F1 are computed. */
    public enum BinaryClass {
        NEGATIVE,
        POSITIVE;

        /**
         * Returns a new array containing every constant of this enum, in declaration order.
         *
         * <p>Kept for backward compatibility; it is equivalent to {@link #values()}, which already
         * returns a fresh copy on every call.
         *
         * @return a newly allocated array of all {@code BinaryClass} constants
         */
        public static BinaryClass[] valuesCustom() {
            return values();
        }
    }

    /**
     * Creates statistics from precomputed confusion-matrix counts.
     *
     * <p>Note the argument order is {@code tp, fp, tn, fn}. The map and set are stored by
     * reference (not copied); the getters expose read-only views of them.
     *
     * @param tp number of true positives
     * @param fp number of false positives
     * @param tn number of true negatives
     * @param fn number of false negatives
     * @param threshold value returned by {@link #getThreshold()}
     * @param errors per-atom error values (positive = false positive, negative = false negative)
     * @param correctAtoms atoms that were predicted correctly
     */
    public DiscretePredictionStatistics(int tp, int fp, int tn, int fn, double threshold,
            Map<GroundAtom, Double> errors, Set<GroundAtom> correctAtoms) {
        this.tp = tp;
        this.fp = fp;
        this.tn = tn;
        this.fn = fn;
        this.threshold = threshold;
        this.errors = errors;
        this.correctAtoms = correctAtoms;
    }

    /**
     * Returns the threshold supplied at construction.
     *
     * @return the threshold
     */
    public double getThreshold() {
        return this.threshold;
    }

    /**
     * Returns a read-only view of the per-atom error map.
     *
     * @return an unmodifiable view of the errors supplied at construction
     */
    public Map<GroundAtom, Double> getErrors() {
        return Collections.unmodifiableMap(this.errors);
    }

    /**
     * Returns a read-only view of the atoms that were predicted correctly.
     *
     * @return an unmodifiable view of the set supplied at construction
     */
    public Set<GroundAtom> getCorrectAtoms() {
        return Collections.unmodifiableSet(this.correctAtoms);
    }

    /**
     * Returns the error entries whose value is strictly positive.
     *
     * @return a filtered view over the entries of the error map with value {@code > 0}
     */
    public Iterable<Map.Entry<GroundAtom, Double>> getFalsePositives() {
        return Iterables.filter(this.errors.entrySet(), new Predicate<Map.Entry<GroundAtom, Double>>() {
            @Override
            public boolean apply(Map.Entry<GroundAtom, Double> entry) {
                return entry.getValue().doubleValue() > 0.0d;
            }
        });
    }

    /**
     * Returns the error entries whose value is strictly negative.
     *
     * @return a filtered view over the entries of the error map with value {@code < 0}
     */
    public Iterable<Map.Entry<GroundAtom, Double>> getFalseNegatives() {
        return Iterables.filter(this.errors.entrySet(), new Predicate<Map.Entry<GroundAtom, Double>>() {
            @Override
            public boolean apply(Map.Entry<GroundAtom, Double> entry) {
                return entry.getValue().doubleValue() < 0.0d;
            }
        });
    }

    /**
     * Computes precision for the given class.
     *
     * <p>NEGATIVE: {@code tn / (tn + fn)}; any other value (POSITIVE): {@code tp / (tp + fp)}.
     *
     * @param binaryClass the class to evaluate
     * @return the precision, or {@code 1.0} if no atoms were predicted as that class
     */
    public double getPrecision(BinaryClass binaryClass) {
        if (binaryClass == BinaryClass.NEGATIVE) {
            return ratio(this.tn, this.tn + this.fn, 1.0d);
        }
        return ratio(this.tp, this.tp + this.fp, 1.0d);
    }

    /**
     * Computes recall for the given class.
     *
     * <p>NEGATIVE: {@code tn / (tn + fp)}; any other value (POSITIVE): {@code tp / (tp + fn)}.
     *
     * @param binaryClass the class to evaluate
     * @return the recall, or {@code 1.0} if there are no atoms actually in that class
     */
    public double getRecall(BinaryClass binaryClass) {
        if (binaryClass == BinaryClass.NEGATIVE) {
            // Previously computed with int division, which truncated the result to 0 or 1.
            return ratio(this.tn, this.tn + this.fp, 1.0d);
        }
        return ratio(this.tp, this.tp + this.fn, 1.0d);
    }

    /**
     * Computes the F1 score (harmonic mean of precision and recall) for the given class.
     *
     * @param binaryClass the class to evaluate
     * @return the F1 score, or {@code 0.0} if precision and recall are both zero
     */
    public double getF1(BinaryClass binaryClass) {
        double precision = getPrecision(binaryClass);
        double recall = getRecall(binaryClass);
        double sum = precision + recall;
        if (sum == 0.0d) {
            return 0.0d;
        }
        return (2.0d * (precision * recall)) / sum;
    }

    /**
     * Computes accuracy as {@code (tp + tn) / (tp + fp + tn + fn)}.
     *
     * @return the fraction of correctly classified atoms, or {@code 0.0} if there are no atoms
     */
    public double getAccuracy() {
        // Previously computed with int division, which truncated to 0 unless every atom was correct.
        return ratio(this.tp + this.tn, getNumAtoms(), 0.0d);
    }

    /**
     * Returns the number of misclassified atoms ({@code fp + fn}).
     *
     * @return the error count as a double
     */
    @Override // edu.umd.cs.psl.evaluation.statistics.PredictionStatistics
    public double getError() {
        return this.fp + this.fn;
    }

    /**
     * Returns the total number of atoms ({@code tp + fp + tn + fn}).
     *
     * @return the total atom count
     */
    @Override // edu.umd.cs.psl.evaluation.statistics.PredictionStatistics
    public int getNumAtoms() {
        return this.tp + this.fp + this.tn + this.fn;
    }

    /**
     * Divides two counts in floating point, substituting a default for a zero denominator.
     *
     * @param numerator the count to divide
     * @param denominator the count to divide by
     * @param whenEmpty the value returned when {@code denominator == 0}
     * @return {@code numerator / denominator} as a double, or {@code whenEmpty}
     */
    private static double ratio(int numerator, int denominator, double whenEmpty) {
        if (denominator == 0) {
            return whenEmpty;
        }
        // Widen before dividing: int / int would truncate toward zero.
        return (double) numerator / denominator;
    }
}
