package edu.umd.cs.psl.application.learning.weight.maxmargin;

import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.optimizer.conic.ConicProgramSolver;
import edu.umd.cs.psl.optimizer.conic.ConicProgramSolverFactory;
import edu.umd.cs.psl.optimizer.conic.ipm.HomogeneousIPMFactory;
import edu.umd.cs.psl.optimizer.conic.program.ConicProgram;
import edu.umd.cs.psl.optimizer.conic.program.LinearConstraint;
import edu.umd.cs.psl.optimizer.conic.program.SecondOrderCone;
import edu.umd.cs.psl.optimizer.conic.program.Variable;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/application/learning/weight/maxmargin/MinNormProgram.class */
/**
 * Builds and solves a conic program over {@code size} decision variables plus one auxiliary
 * {@code squaredNorm} variable.
 *
 * <p>The objective is the sum of the per-variable linear coefficients (see
 * {@link #setLinearCoefficients(double[])}) plus a coefficient of 0.5 on {@code squaredNorm}.
 * {@link #setQuadraticTerm(double[], double[])} ties {@code squaredNorm} to the decision
 * variables through a second-order cone, and {@link #addInequalityConstraint(double[], double)}
 * adds linear constraints with a slack variable.
 *
 * <p>The solver is obtained from the {@link ConfigBundle} under {@link #CPS_KEY}, falling back to
 * {@link #CPS_DEFAULT}.
 *
 * <p>After {@link #close()} the instance drops its references to the program and solver and must
 * not be used again.
 */
public class MinNormProgram {
    /** Prefix of the configuration keys used by this class. */
    public static final String CONFIG_PREFIX = "minnormprog";
    /** Configuration key under which the {@link ConicProgramSolverFactory} is looked up. */
    public static final String CPS_KEY = "minnormprog.conicprogramsolver";
    /** Factory used when {@link #CPS_KEY} is not configured. */
    public static final ConicProgramSolverFactory CPS_DEFAULT = new HomogeneousIPMFactory();
    /** Number of decision variables. */
    private int size;
    private ConicProgramSolver solver;
    /** The decision variables, in the index order used by every {@code double[]} argument. */
    private Variable[] variables;
    /** Cone created by the most recent {@link #setQuadraticTerm} call, or {@code null} if none. */
    private SecondOrderCone quadraticCone;
    /** Auxiliary variable carrying the quadratic part of the objective. */
    private Variable squaredNorm;
    Logger log = LoggerFactory.getLogger(MinNormProgram.class);
    private ConicProgram program = new ConicProgram();

    /**
     * Creates the program with {@code i} decision variables and the {@code squaredNorm} variable.
     *
     * @param i number of decision variables
     * @param z if {@code true}, each decision variable is taken from its own non-negative orthant
     *     cone; otherwise each is the n-th variable of its own 2-dimensional second-order cone
     * @param configBundle configuration used to look up and build the conic program solver
     * @throws ClassNotFoundException if raised while obtaining the solver from {@code configBundle}
     * @throws IllegalAccessException if raised while obtaining the solver from {@code configBundle}
     * @throws InstantiationException if raised while obtaining the solver from {@code configBundle}
     */
    public MinNormProgram(int i, boolean z, ConfigBundle configBundle) throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        this.size = i;
        ConicProgramSolverFactory factory = (ConicProgramSolverFactory) configBundle.getFactory(CPS_KEY, CPS_DEFAULT);
        this.solver = factory.getConicProgramSolver(configBundle);
        this.variables = new Variable[i];
        for (int idx = 0; idx < i; idx++) {
            if (z) {
                this.variables[idx] = this.program.createNonNegativeOrthantCone().getVariable();
            } else {
                this.variables[idx] = this.program.createSecondOrderCone(2).getNthVariable();
            }
        }
        this.squaredNorm = this.program.createNonNegativeOrthantCone().getVariable();
        this.squaredNorm.setObjectiveCoefficient(Double.valueOf(0.5d));
    }

    /**
     * Adds the linear constraint {@code sum(dArr[k] * variable[k]) + slack = d}, where
     * {@code slack} is a fresh variable taken from a non-negative orthant cone. With a
     * non-negative slack this amounts to {@code sum(dArr[k] * variable[k]) <= d}.
     *
     * @param dArr coefficients of the decision variables; must not be longer than the number of
     *     decision variables
     * @param d constrained value (right-hand side)
     */
    public void addInequalityConstraint(double[] dArr, double d) {
        Variable slack = this.program.createNonNegativeOrthantCone().getVariable();
        LinearConstraint constraint = this.program.createConstraint();
        for (int idx = 0; idx < dArr.length; idx++) {
            constraint.setVariable(this.variables[idx], Double.valueOf(dArr[idx]));
        }
        constraint.setConstrainedValue(Double.valueOf(d));
        constraint.setVariable(slack, Double.valueOf(1.0d));
    }

    /**
     * Rescales the objective coefficients, hands the program to the solver and solves it.
     */
    public void solve() {
        normalizeCoefficients();
        this.solver.setConicProgram(this.program);
        this.solver.solve();
    }

    /**
     * Returns the current value of every decision variable, in index order.
     *
     * @return a new array of length {@code size}
     */
    public double[] getSolution() {
        double[] values = new double[this.size];
        for (int idx = 0; idx < this.size; idx++) {
            values[idx] = this.variables[idx].getValue().doubleValue();
        }
        return values;
    }

    /**
     * Sets the objective coefficient of every decision variable.
     *
     * @param dArr one coefficient per decision variable; must hold at least {@code size} entries
     */
    public void setLinearCoefficients(double[] dArr) {
        for (int idx = 0; idx < this.size; idx++) {
            this.variables[idx].setObjectiveCoefficient(Double.valueOf(dArr[idx]));
        }
    }

    /**
     * Replaces the quadratic term of the program.
     *
     * <p>A second-order cone with one inner variable per strictly positive weight (plus one extra
     * inner variable and the n-th variable) is created and constrained so that
     * <ul>
     *   <li>first inner variable {@code = 0.5 - 0.5 * squaredNorm},</li>
     *   <li>inner variable for weight k {@code = dArr[k] * (variable[k] - dArr2[k])},</li>
     *   <li>n-th variable {@code = 0.5 + 0.5 * squaredNorm}.</li>
     * </ul>
     * Assuming the cone enforces {@code ||inner|| <= nth}, this is equivalent to
     * {@code squaredNorm >= sum((dArr[k] * (variable[k] - dArr2[k]))^2)}.
     *
     * <p>All weights are validated before the existing cone is touched, so an invalid argument
     * leaves the program unchanged.
     *
     * @param dArr per-variable weights; each must be non-negative and not NaN, zero entries are
     *     skipped
     * @param dArr2 per-variable offsets subtracted from the variables; read only where the
     *     corresponding weight is positive
     * @throws IllegalArgumentException if any weight is negative or NaN
     */
    public void setQuadraticTerm(double[] dArr, double[] dArr2) {
        int positiveCount = 0;
        for (double weight : dArr) {
            // Written as a negated >= so that NaN is rejected together with negatives; a NaN
            // would otherwise be skipped here but still consume a cone variable below.
            if (!(weight >= 0.0d)) {
                throw new IllegalArgumentException("Weights must be non-negative.");
            }
            if (weight > 0.0d) {
                positiveCount++;
            }
        }

        removeQuadraticCone();

        // One inner variable per positive weight, one for the (1 - squaredNorm)/2 term, plus the
        // n-th variable.
        this.quadraticCone = this.program.createSecondOrderCone(positiveCount + 2);
        Iterator<Variable> innerVariables = this.quadraticCone.getInnerVariables().iterator();

        Variable lowerHalf = innerVariables.next();
        LinearConstraint lowerHalfConstraint = this.program.createConstraint();
        lowerHalfConstraint.setVariable(lowerHalf, Double.valueOf(1.0d));
        lowerHalfConstraint.setVariable(this.squaredNorm, Double.valueOf(0.5d));
        lowerHalfConstraint.setConstrainedValue(Double.valueOf(0.5d));

        for (int idx = 0; idx < dArr.length; idx++) {
            // Same predicate as the counting pass above, so the cone size always matches.
            if (dArr[idx] > 0.0d) {
                Variable inner = innerVariables.next();
                LinearConstraint termConstraint = this.program.createConstraint();
                termConstraint.setVariable(inner, Double.valueOf(-1.0d));
                termConstraint.setVariable(this.variables[idx], Double.valueOf(dArr[idx]));
                termConstraint.setConstrainedValue(Double.valueOf(dArr[idx] * dArr2[idx]));
            }
        }

        Variable upperHalf = this.quadraticCone.getNthVariable();
        LinearConstraint upperHalfConstraint = this.program.createConstraint();
        upperHalfConstraint.setVariable(upperHalf, Double.valueOf(1.0d));
        upperHalfConstraint.setVariable(this.squaredNorm, Double.valueOf(-0.5d));
        upperHalfConstraint.setConstrainedValue(Double.valueOf(0.5d));
    }

    /**
     * Drops the references to the program, the solver, the decision variables and the quadratic
     * cone. The instance must not be used afterwards.
     */
    public void close() {
        this.program = null;
        this.solver = null;
        for (int idx = 0; idx < this.size; idx++) {
            this.variables[idx] = null;
        }
        this.quadraticCone = null;
    }

    /**
     * Deletes the cone installed by a previous {@link #setQuadraticTerm} call, together with every
     * linear constraint attached to any of its variables. Does nothing if no cone is installed.
     */
    private void removeQuadraticCone() {
        if (this.quadraticCone == null) {
            return;
        }
        Iterator<Variable> coneVariables = this.quadraticCone.getVariables().iterator();
        while (coneVariables.hasNext()) {
            Iterator<LinearConstraint> constraints = coneVariables.next().getLinearConstraints().iterator();
            while (constraints.hasNext()) {
                constraints.next().delete();
            }
        }
        this.quadraticCone.delete();
    }

    /**
     * Divides every objective coefficient (decision variables and {@code squaredNorm}) by the
     * largest of those coefficients.
     *
     * <p>The maximum is taken over the signed values starting from 0, so negative coefficients
     * never determine the divisor. The divisor is positive as long as the {@code squaredNorm}
     * coefficient is positive, which is how it is initialised (0.5).
     */
    private void normalizeCoefficients() {
        double largest = 0.0d;
        for (int idx = 0; idx < this.size; idx++) {
            largest = Math.max(largest, this.variables[idx].getObjectiveCoefficient().doubleValue());
        }
        double divisor = Math.max(largest, this.squaredNorm.getObjectiveCoefficient().doubleValue());
        this.squaredNorm.setObjectiveCoefficient(Double.valueOf(this.squaredNorm.getObjectiveCoefficient().doubleValue() / divisor));
        for (int idx = 0; idx < this.size; idx++) {
            this.variables[idx].setObjectiveCoefficient(Double.valueOf(this.variables[idx].getObjectiveCoefficient().doubleValue() / divisor));
        }
    }
}
