package edu.umd.cs.psl.application.learning.weight.em;

import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.model.parameters.PositiveWeight;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.model.predicate.PredicateFactory;
import edu.umd.cs.psl.reasoner.admm.ADMMReasoner;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/em/PairedDualLearner.class */
/**
 * Expectation-maximization weight learner that caps both reasoners at a small number of ADMM
 * iterations per {@code optimize()} call and measures incompatibilities through ADMM dual values.
 *
 * <p>Both {@code reasoner} (the M-step / expected-incompatibility reasoner) and
 * {@code latentVariableReasoner} (the E-step / observed-incompatibility reasoner) are cast to
 * {@link ADMMReasoner}. Incompatibilities are read via {@code getDualIncompatibility}, and the
 * reasoners' Lagrangian penalties are added to the learning objective.
 *
 * <p>Configuration keys:
 * <ul>
 *   <li>{@value #WARMUP_ROUNDS_KEY} (default {@value #WARMUP_ROUNDS_DEFAULT}): number of rounds
 *       in which both reasoners are optimized before weight learning starts; must be &gt;= 0.</li>
 *   <li>{@value #ADMM_STEPS_KEY} (default {@value #ADMM_STEPS_DEFAULT}): ADMM iteration limit
 *       applied to both reasoners during learning; must be &gt;= 1.</li>
 * </ul>
 */
public class PairedDualLearner extends ExpectationMaximization {
    private static final Logger log = LoggerFactory.getLogger(PairedDualLearner.class);

    /** Prefix shared by this learner's configuration keys. */
    public static final String CONFIG_PREFIX = "pairedduallearner";

    /** Key for the number of warm-up rounds run before learning. */
    public static final String WARMUP_ROUNDS_KEY = CONFIG_PREFIX + ".warmuprounds";
    /** Default number of warm-up rounds. */
    public static final int WARMUP_ROUNDS_DEFAULT = 0;

    /** Key for the ADMM iteration limit used by both reasoners during learning. */
    public static final String ADMM_STEPS_KEY = CONFIG_PREFIX + ".admmsteps";
    /** Default ADMM iteration limit during learning. */
    public static final int ADMM_STEPS_DEFAULT = 1;

    // Allocated in the constructor; not read by this class (kept for package-level access).
    double[] scalingFactor;
    // Per-kernel sums of dual incompatibility from latentVariableReasoner: mutable kernels
    // first, then immutable kernels at offset kernels.size().
    double[] dualObservedIncompatibility;
    // Same layout as dualObservedIncompatibility, but summed over `reasoner`.
    double[] dualExpectedIncompatibility;
    private final int warmupRounds;
    private final int admmIterations;
    // Optional model/prefix used only for writing intermediate model files; set via setModel().
    Model model;
    String outputPrefix;
    // Allocated in the constructor; not read by this class.
    Random random;

    /**
     * Creates the learner and validates its configuration.
     *
     * <p>All four arguments are passed through to {@code ExpectationMaximization}. The
     * {@code model} argument is not stored in this class's {@code model} field; use
     * {@link #setModel(Model, String)} to enable model-file output.
     *
     * @throws IllegalArgumentException if the warm-up round count is negative or the ADMM step
     *     count is less than one
     */
    public PairedDualLearner(Model model, Database database, Database database2, ConfigBundle configBundle) {
        super(model, database, database2, configBundle);
        this.random = new Random();
        this.scalingFactor = new double[this.kernels.size()];
        this.warmupRounds = configBundle.getInt(WARMUP_ROUNDS_KEY, WARMUP_ROUNDS_DEFAULT);
        if (this.warmupRounds < 0) {
            throw new IllegalArgumentException(WARMUP_ROUNDS_KEY + " must be a nonnegative integer.");
        }
        this.admmIterations = configBundle.getInt(ADMM_STEPS_KEY, ADMM_STEPS_DEFAULT);
        if (this.admmIterations < 1) {
            throw new IllegalArgumentException(ADMM_STEPS_KEY + " must be a positive integer.");
        }
    }

    /**
     * Initializes the ground model via the superclass, then verifies that the main reasoner is
     * an {@link ADMMReasoner}, which the rest of this class requires.
     *
     * @throws IllegalArgumentException if {@code reasoner} is not an {@link ADMMReasoner}
     */
    @Override
    public void initGroundModel() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        super.initGroundModel();
        if (!(this.reasoner instanceof ADMMReasoner)) {
            throw new IllegalArgumentException("PairedDualLearning can only be used with ADMMReasoner.");
        }
    }

    /**
     * Enables writing model snapshots during learning.
     *
     * @param model model whose {@code toString()} is written to each snapshot file
     * @param str path prefix for snapshot files; the file name is {@code str + "model" + i + ".txt"}
     */
    public void setModel(Model model, String str) {
        this.model = model;
        this.outputPrefix = str;
    }

    /** E-step: delegates directly to {@code inferLatentVariables()}. */
    @Override
    protected void minimizeKLDivergence() {
        inferLatentVariables();
    }

    /**
     * Runs {@code reasoner.optimize()} (bounded by the current ADMM iteration limit) and sums
     * each kernel's dual incompatibility over its ground kernels.
     *
     * @return the sums for the mutable kernels only; the immutable-kernel sums stay in
     *     {@link #dualExpectedIncompatibility}
     */
    @Override
    protected double[] computeExpectedIncomp() {
        final int numKernels = this.kernels.size();
        this.dualExpectedIncompatibility = new double[numKernels + this.immutableKernels.size()];
        this.reasoner.optimize();
        ADMMReasoner admm = (ADMMReasoner) this.reasoner;
        for (int i = 0; i < numKernels; i++) {
            for (GroundKernel groundKernel : this.reasoner.getGroundKernels(this.kernels.get(i))) {
                this.dualExpectedIncompatibility[i] += admm.getDualIncompatibility(groundKernel);
            }
        }
        for (int i = 0; i < this.immutableKernels.size(); i++) {
            for (GroundKernel groundKernel : this.reasoner.getGroundKernels(this.immutableKernels.get(i))) {
                this.dualExpectedIncompatibility[numKernels + i] += admm.getDualIncompatibility(groundKernel);
            }
        }
        return Arrays.copyOf(this.dualExpectedIncompatibility, numKernels);
    }

    /**
     * Calls {@code setLabeledRandomVariables()} and then sums each kernel's dual incompatibility
     * from {@code latentVariableReasoner}. It also counts each mutable kernel's groundings into
     * {@code numGroundings}.
     *
     * <p>This method does not call {@code optimize()} itself. NOTE(review): presumably the
     * latent reasoner is run by {@code minimizeKLDivergence()} beforehand; confirm.
     *
     * @return the sums for the mutable kernels only; the immutable-kernel sums stay in
     *     {@link #dualObservedIncompatibility}
     */
    @Override
    protected double[] computeObservedIncomp() {
        final int numKernels = this.kernels.size();
        this.numGroundings = new double[numKernels];
        this.dualObservedIncompatibility = new double[numKernels + this.immutableKernels.size()];
        setLabeledRandomVariables();
        ADMMReasoner admm = (ADMMReasoner) this.latentVariableReasoner;
        for (int i = 0; i < numKernels; i++) {
            for (GroundKernel groundKernel : this.latentVariableReasoner.getGroundKernels(this.kernels.get(i))) {
                this.dualObservedIncompatibility[i] += admm.getDualIncompatibility(groundKernel);
                this.numGroundings[i] += 1.0d;
            }
        }
        for (int i = 0; i < this.immutableKernels.size(); i++) {
            for (GroundKernel groundKernel : this.latentVariableReasoner.getGroundKernels(this.immutableKernels.get(i))) {
                this.dualObservedIncompatibility[numKernels + i] += admm.getDualIncompatibility(groundKernel);
            }
        }
        return Arrays.copyOf(this.dualObservedIncompatibility, numKernels);
    }

    /**
     * Computes the sum over all kernels (mutable and immutable) of the kernel's current weight
     * times its (observed - expected) dual incompatibility.
     *
     * <p>Reads the arrays filled by the most recent compute*Incomp() calls.
     */
    @Override
    protected double computeLoss() {
        double loss = 0.0d;
        for (int i = 0; i < this.kernels.size(); i++) {
            loss += this.kernels.get(i).getWeight().getWeight()
                    * (this.dualObservedIncompatibility[i] - this.dualExpectedIncompatibility[i]);
        }
        return addImmutableDualLoss(loss);
    }

    /**
     * Adds each immutable kernel's weight times its (observed - expected) dual incompatibility
     * to {@code total}, one term at a time in kernel order.
     */
    private double addImmutableDualLoss(double total) {
        final int offset = this.kernels.size();
        for (int i = 0; i < this.immutableKernels.size(); i++) {
            total += this.immutableKernels.get(i).getWeight().getWeight()
                    * (this.dualObservedIncompatibility[offset + i] - this.dualExpectedIncompatibility[offset + i]);
        }
        return total;
    }

    /**
     * Projected subgradient descent on the mutable kernel weights, keeping weights nonnegative.
     *
     * <p>Runs at most {@code iterations} steps and stops early when the L2 norm of the applied
     * weight change falls below {@code tolerance}. At the end, the final weights are installed
     * into the kernels: the running average if {@code averageSteps} is set, otherwise the last
     * iterate.
     */
    private void subgrad() {
        log.info("Starting optimization");
        final int numKernels = this.kernels.size();

        double[] weights = new double[numKernels];
        for (int i = 0; i < numKernels; i++) {
            weights[i] = this.kernels.get(i).getWeight().getWeight();
        }
        double[] avgWeights = new double[numKernels];
        // After getValueAndGradient() this holds the gradient. It is then overwritten with the
        // applied weight change, and a nonzero entry tells the next getValueAndGradient() call
        // to push that weight into its kernel. Starting at 1.0 pushes every weight on step 0.
        double[] gradient = new double[numKernels];
        Arrays.fill(gradient, 1.0d);
        double[] stepScale = new double[numKernels];
        DecimalFormat fmt = new DecimalFormat("0.0000E00");

        double objective = 0.0d;
        for (int step = 0; step < this.iterations; step++) {
            objective = getValueAndGradient(gradient, weights);
            double gradNormSq = 0.0d;
            double changeSq = 0.0d;
            double avgRate = 1.0d / (step + 1.0d);
            for (int i = 0; i < numKernels; i++) {
                // With scheduling, the effective step size is stepSize / (step + 1).
                stepScale[i] = this.scheduleStepSize ? Math.pow(step + 1, 2.0d) : 1.0d;
                // Norm of the gradient projected onto the nonnegative orthant.
                gradNormSq += Math.pow(weights[i] - Math.max(0.0d, weights[i] - gradient[i]), 2.0d);
                if (stepScale[i] > 0.0d) {
                    // Projected step: a weight never moves below zero.
                    double delta = Math.max(-weights[i], (-(this.stepSize / Math.sqrt(stepScale[i]))) * gradient[i]);
                    weights[i] += delta;
                    gradient[i] = delta;
                    changeSq += Math.pow(delta, 2.0d);
                }
                avgWeights[i] = ((1.0d - avgRate) * avgWeights[i]) + (avgRate * weights[i]);
            }

            if (this.storeWeights) {
                // Raw map: the element type of storedWeights is declared in a superclass.
                HashMap snapshot = new HashMap();
                for (int i = 0; i < numKernels; i++) {
                    double w = this.averageSteps ? avgWeights[i] : weights[i];
                    if (w != 0.0d) {
                        snapshot.put(this.kernels.get(i), Double.valueOf(w));
                    }
                }
                this.storedWeights.add(snapshot);
            }

            double gradNorm = Math.sqrt(gradNormSq);
            double change = Math.sqrt(changeSq);
            log.info("Iter {}, obj: {}, norm grad: {}, change: {}", new Object[] {
                    Integer.valueOf(step), fmt.format(objective), fmt.format(gradNorm), fmt.format(change)});
            if (step % 50 == 0) {
                outputModel(step);
            }
            if (change < this.tolerance) {
                log.info("Change in w ({}) is less than tolerance. Finishing subgrad.", Double.valueOf(change));
                break;
            }
        }

        log.info("Learning finished with final objective value {}", Double.valueOf(objective));
        for (int i = 0; i < numKernels; i++) {
            double finalWeight = this.averageSteps ? avgWeights[i] : weights[i];
            this.kernels.get(i).setWeight(new PositiveWeight(finalWeight));
        }
        // Written after the final weights are installed so the file holds the learned model
        // rather than the weights from the start of the last iteration.
        outputModel(this.iterations);
    }

    /**
     * Writes all registered predicates followed by {@code model.toString()} to
     * {@code outputPrefix + "model" + i + ".txt"}, creating parent directories if needed.
     *
     * <p>Does nothing when no model has been set via {@link #setModel(Model, String)}. I/O
     * failures are logged and do not abort learning.
     */
    private void outputModel(int i) {
        if (this.model == null) {
            return;
        }
        File file = new File(String.valueOf(this.outputPrefix) + "model" + i + ".txt");
        BufferedWriter writer = null;
        try {
            File parent = file.getParentFile();
            if (parent != null) {
                parent.mkdirs();
            }
            writer = new BufferedWriter(new FileWriter(file));
            Iterator<Predicate> predicates = PredicateFactory.getFactory().getPredicates().iterator();
            while (predicates.hasNext()) {
                writer.write(String.valueOf(predicates.next().toString()) + "\n");
            }
            writer.write(this.model.toString());
            // Close on the success path so flush/close errors are reported below.
            writer.close();
            writer = null;
        } catch (IOException e) {
            log.error("Failed to write model file " + file, e);
        } finally {
            if (writer != null) {
                try {
                    writer.close();
                } catch (IOException ignored) {
                    // The original failure has already been logged; nothing more to do.
                }
            }
        }
    }

    /**
     * Runs paired-dual learning.
     *
     * <p>Steps:
     * <ol>
     *   <li>Caps both reasoners at {@code admmIterations} ADMM iterations.</li>
     *   <li>Adds loss-augmented kernels if {@code augmentLoss} is set.</li>
     *   <li>Runs {@code warmupRounds} rounds of {@code optimize()} on both reasoners.</li>
     *   <li>Optimizes the weights with {@link #subgrad()}.</li>
     * </ol>
     * Loss-augmented kernels are removed, and each reasoner's original iteration limit is
     * restored, even if learning throws.
     */
    @Override
    public void doLearn() {
        ADMMReasoner mReasoner = (ADMMReasoner) this.reasoner;
        ADMMReasoner eReasoner = (ADMMReasoner) this.latentVariableReasoner;
        // Save each reasoner's own limit so that both are restored exactly.
        int mMaxIter = mReasoner.getMaxIter();
        int eMaxIter = eReasoner.getMaxIter();
        mReasoner.setMaxIter(this.admmIterations);
        eReasoner.setMaxIter(this.admmIterations);
        try {
            if (this.augmentLoss) {
                addLossAugmentedKernels();
            }
            try {
                if (this.warmupRounds > 0) {
                    log.info("Warming up optimizers with {} iterations each.", Integer.valueOf(this.warmupRounds * this.admmIterations));
                    for (int round = 0; round < this.warmupRounds; round++) {
                        this.reasoner.optimize();
                        this.latentVariableReasoner.optimize();
                    }
                }
                subgrad();
            } finally {
                if (this.augmentLoss) {
                    removeLossAugmentedKernels();
                }
            }
        } finally {
            mReasoner.setMaxIter(mMaxIter);
            eReasoner.setMaxIter(eMaxIter);
        }
    }

    /**
     * Pushes weights into the kernels, runs the E-step and both incompatibility computations,
     * and evaluates the objective.
     *
     * @param gradient on input, marks which weights to push: a nonzero entry sets that
     *     kernel's weight, and {@code null} pushes every weight. On output (when non-null),
     *     receives (observed - expected) dual incompatibility per mutable kernel. The value is
     *     divided by the kernel's grounding count when {@code scaleGradient} is set, then
     *     {@code l2Regularization * weight + l1Regularization} is added.
     * @param weights current weights of the mutable kernels
     * @return the objective, made of three parts:
     *     <ul>
     *       <li>the sum of weight * (observed - expected) dual incompatibility over mutable and
     *           immutable kernels;</li>
     *       <li>plus the E reasoner's Lagrangian and augmented Lagrangian penalties, minus the
     *           M reasoner's;</li>
     *       <li>plus {@code computeRegularizer()}.</li>
     *     </ul>
     */
    private double getValueAndGradient(double[] gradient, double[] weights) {
        final int numKernels = this.kernels.size();
        for (int i = 0; i < numKernels; i++) {
            if (gradient == null || gradient[i] != 0.0d) {
                this.kernels.get(i).setWeight(new PositiveWeight(weights[i]));
            }
        }
        minimizeKLDivergence();
        computeObservedIncomp();
        this.reasoner.changedGroundKernelWeights();
        computeExpectedIncomp();

        double objective = 0.0d;
        for (int i = 0; i < numKernels; i++) {
            objective += weights[i] * (this.dualObservedIncompatibility[i] - this.dualExpectedIncompatibility[i]);
        }
        objective = addImmutableDualLoss(objective);

        ADMMReasoner eReasoner = (ADMMReasoner) this.latentVariableReasoner;
        ADMMReasoner mReasoner = (ADMMReasoner) this.reasoner;
        double ePenalty = eReasoner.getLagrangianPenalty();
        double eAugPenalty = eReasoner.getAugmentedLagrangianPenalty();
        double mPenalty = mReasoner.getLagrangianPenalty();
        double mAugPenalty = mReasoner.getAugmentedLagrangianPenalty();
        objective += ((ePenalty + eAugPenalty) - mPenalty) - mAugPenalty;

        if (log.isDebugEnabled()) {
            for (int i = 0; i < numKernels; i++) {
                log.debug("Incompatibility for kernel {}", this.kernels.get(i));
                log.debug("Truth incompatbility {}, expected incompatibility {}",
                        Double.valueOf(this.dualObservedIncompatibility[i]), Double.valueOf(this.dualExpectedIncompatibility[i]));
            }
            log.debug("E Penalty: {}, E Aug Penalty: {}, M Penalty: {}, M Aug Penalty: {}", new Object[] {
                    Double.valueOf(ePenalty), Double.valueOf(eAugPenalty), Double.valueOf(mPenalty), Double.valueOf(mAugPenalty)});
        }

        double regularizer = computeRegularizer();
        if (gradient != null) {
            for (int i = 0; i < numKernels; i++) {
                double g = this.dualObservedIncompatibility[i] - this.dualExpectedIncompatibility[i];
                if (this.scaleGradient && this.numGroundings[i] > 0.0d) {
                    g /= this.numGroundings[i];
                }
                gradient[i] = g + (this.l2Regularization * weights[i]) + this.l1Regularization;
            }
        }
        return objective + regularizer;
    }
}
