package edu.umd.cs.psl.evaluation.statistics;

/* loaded from: input_file:edu/umd/cs/psl/evaluation/statistics/MulticlassPredictionStatistics.class */
/**
 * Summary statistics (accuracy, per-class precision/recall/F1 and a
 * support-weighted F1) derived from a {@link ConfusionMatrix}.
 *
 * <p>All statistics are computed once, in the constructor, from the matrix
 * passed in. For an off-diagonal entry {@code cm.get(row, col)} the count is
 * charged as a false negative to class {@code row} and as a false positive to
 * class {@code col}; diagonal entries are true positives.
 *
 * <p>Conventions for undefined ratios:
 * <ul>
 *   <li>precision is 1.0 for a class with no true or false positives;</li>
 *   <li>recall is 1.0 for a class with no true positives or false negatives;</li>
 *   <li>per-class F1 is 0.0 when precision + recall is 0;</li>
 *   <li>accuracy and the weighted F1 are 0.0 when the matrix holds no examples.</li>
 * </ul>
 */
public class MulticlassPredictionStatistics implements PredictionStatistics {
    private final ConfusionMatrix cm;
    private int numEx;
    private int numCor;
    private int numErr;
    private int[] tp;
    private int[] fp;
    private int[] fn;
    private double acc;
    private double f1;
    private double[] precision;
    private double[] recall;
    private double[] f1class;

    /**
     * Builds the statistics for the given confusion matrix.
     *
     * @param confusionMatrix the matrix to summarize; it is read immediately
     *        and a reference to it is retained
     */
    public MulticlassPredictionStatistics(ConfusionMatrix confusionMatrix) {
        this.cm = confusionMatrix;
        computeStats();
    }

    /**
     * @return the fraction of examples on the diagonal of the confusion
     *         matrix, or 0.0 if the matrix contains no examples
     */
    public double getAccuracy() {
        return this.acc;
    }

    /**
     * @return the per-class F1 scores averaged with each class weighted by
     *         its share of examples (true positives + false negatives), or
     *         0.0 if the matrix contains no examples
     */
    public double getF1() {
        return this.f1;
    }

    /**
     * @param i class index, in {@code [0, numClasses)}
     * @return the F1 score of class {@code i}
     */
    public double getF1(int i) {
        return this.f1class[i];
    }

    /**
     * @param i class index, in {@code [0, numClasses)}
     * @return the precision of class {@code i}
     */
    public double getPrecision(int i) {
        return this.precision[i];
    }

    /**
     * @param i class index, in {@code [0, numClasses)}
     * @return the recall of class {@code i}
     */
    public double getRecall(int i) {
        return this.recall[i];
    }

    /**
     * @return a clone of the confusion matrix these statistics were computed from
     */
    public ConfusionMatrix getConfusionMatrix() {
        return this.cm.m41clone();
    }

    /**
     * @return the number of misclassified examples, i.e. the sum of all
     *         off-diagonal entries of the confusion matrix
     */
    @Override // edu.umd.cs.psl.evaluation.statistics.PredictionStatistics
    public double getError() {
        return this.numErr;
    }

    /**
     * @return the total number of examples counted in the confusion matrix
     */
    @Override // edu.umd.cs.psl.evaluation.statistics.PredictionStatistics
    public int getNumAtoms() {
        return this.numEx;
    }

    /**
     * Populates every statistic field from the confusion matrix.
     */
    private void computeStats() {
        final int numClasses = this.cm.getNumClasses();
        this.numCor = 0;
        this.numErr = 0;
        this.tp = new int[numClasses];
        this.fp = new int[numClasses];
        this.fn = new int[numClasses];

        // Tally correct/incorrect totals and per-class TP/FP/FN counts.
        for (int row = 0; row < numClasses; row++) {
            for (int col = 0; col < numClasses; col++) {
                final int count = this.cm.get(row, col);
                if (row == col) {
                    this.numCor += count;
                    this.tp[row] = count;
                } else {
                    this.numErr += count;
                    this.fp[col] += count;
                    this.fn[row] += count;
                }
            }
        }

        this.numEx = this.numCor + this.numErr;
        // Floating-point division: dividing the two ints directly would
        // truncate the accuracy to 0 or 1 (and throw on an empty matrix).
        this.acc = ratio(this.numCor, this.numEx, 0.0d);

        this.f1 = 0.0d;
        this.precision = new double[numClasses];
        this.recall = new double[numClasses];
        this.f1class = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            final int predictedAsClass = this.tp[c] + this.fp[c];
            final int support = this.tp[c] + this.fn[c];

            this.precision[c] = ratio(this.tp[c], predictedAsClass, 1.0d);
            this.recall[c] = ratio(this.tp[c], support, 1.0d);

            final double prSum = this.precision[c] + this.recall[c];
            this.f1class[c] = prSum == 0.0d
                    ? 0.0d
                    : ((2.0d * this.precision[c]) * this.recall[c]) / prSum;

            // Weight each class by its share of the examples; with no
            // examples at all the weighted F1 stays 0 instead of becoming NaN.
            if (this.numEx != 0) {
                this.f1 += (this.f1class[c] * support) / this.numEx;
            }
        }
    }

    /**
     * Divides two counts in floating point.
     *
     * @param numerator the count on top
     * @param denominator the count on the bottom
     * @param whenUndefined value to return when {@code denominator} is zero
     * @return {@code numerator / denominator} as a double, or
     *         {@code whenUndefined} if the denominator is zero
     */
    private static double ratio(int numerator, int denominator, double whenUndefined) {
        return denominator == 0 ? whenUndefined : (double) numerator / denominator;
    }
}
