package edu.umd.cs.psl.optimizer.conic.ipm.solver;

import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
import cern.colt.matrix.tdouble.algo.solver.DefaultDoubleIterationMonitor;
import cern.colt.matrix.tdouble.algo.solver.DoubleCG;
import cern.colt.matrix.tdouble.algo.solver.DoubleIterationMonitor;
import cern.colt.matrix.tdouble.algo.solver.DoubleIterationReporter;
import cern.colt.matrix.tdouble.algo.solver.IterativeSolverDoubleNotConvergedException;
import cern.colt.matrix.tdouble.algo.solver.preconditioner.DoublePreconditioner;
import cern.colt.matrix.tdouble.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.tdouble.impl.SparseCCDoubleMatrix2D;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.optimizer.conic.ipm.solver.preconditioner.IdentityPreconditionerFactory;
import edu.umd.cs.psl.optimizer.conic.ipm.solver.preconditioner.PreconditionerFactory;
import edu.umd.cs.psl.optimizer.conic.program.ConicProgram;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/optimizer/conic/ipm/solver/ConjugateGradient.class */
/**
 * {@link NormalSystemSolver} backed by Parallel Colt's (optionally preconditioned)
 * conjugate gradient solver, {@link DoubleCG}.
 * <p>
 * Expected call sequence:
 * <ol>
 *   <li>{@link #setConicProgram(ConicProgram)} once per program,</li>
 *   <li>{@link #setA(SparseCCDoubleMatrix2D)} to supply the system matrix,</li>
 *   <li>{@link #solve(DoubleMatrix1D)} one or more times.</li>
 * </ol>
 * Configuration is read from the {@link ConfigBundle} given to the constructor using
 * the {@code cgsolver.*} keys declared below; each falls back to its {@code *_DEFAULT}.
 * <p>
 * Instances hold mutable per-program state and perform no synchronization.
 */
public class ConjugateGradient implements NormalSystemSolver {
    /** Prefix shared by all configuration keys of this solver. */
    public static final String CONFIG_PREFIX = "cgsolver";

    /** Key for the maximum number of CG iterations. */
    public static final String CG_MAX_ITER_KEY = CONFIG_PREFIX + ".maxcgiter";
    /** Default for {@link #CG_MAX_ITER_KEY}. */
    public static final int CG_MAX_ITER_DEFAULT = 1000000;

    /** Key for the relative tolerance passed to the iteration monitor. */
    public static final String CG_REL_TOL_KEY = CONFIG_PREFIX + ".cgreltol";
    /** Default for {@link #CG_REL_TOL_KEY}. */
    public static final double CG_REL_TOL_DEFAULT = 1.0E-9d;

    /** Key for the absolute tolerance passed to the iteration monitor. */
    public static final String CG_ABS_TOL_KEY = CONFIG_PREFIX + ".cgabstol";
    /** Default for {@link #CG_ABS_TOL_KEY}. */
    public static final double CG_ABS_TOL_DEFAULT = 1.0E-49d;

    /** Key for the divergence tolerance passed to the iteration monitor. */
    public static final String CG_DIV_TOL_KEY = CONFIG_PREFIX + ".cgdivtol";
    /** Default for {@link #CG_DIV_TOL_KEY}. */
    public static final double CG_DIV_TOL_DEFAULT = 1000000.0d;

    /** Key for the {@link PreconditionerFactory} used to build the CG preconditioner. */
    public static final String PRECONDITIONER_KEY = CONFIG_PREFIX + ".preconditioner";
    /** Default for {@link #PRECONDITIONER_KEY}. */
    public static final PreconditionerFactory PRECONDITIONER_DEFAULT = new IdentityPreconditionerFactory();

    private static final Logger log = LoggerFactory.getLogger(ConjugateGradient.class);

    /** The residual is trace-logged once every this many CG iterations. */
    private static final int TRACE_INTERVAL = 50;

    private final int maxIter;
    private final double relTol;
    private final double absTol;
    private final double divTol;
    private final PreconditionerFactory preconditionerFactory;
    /** Shared by every {@link DoubleCG} this instance creates. */
    private final DoubleIterationMonitor monitor;

    // Per-program state, (re)initialized by setConicProgram() and setA().
    private DoubleCG cg;
    private DoublePreconditioner preconditioner;
    private DoubleMatrix2D A;
    private DoubleMatrix1D x;

    /**
     * Creates a solver configured from {@code configBundle}.
     *
     * @param configBundle source of the {@code cgsolver.*} settings
     * @throws ClassNotFoundException if the configured preconditioner factory cannot be loaded
     * @throws IllegalAccessException if the configured preconditioner factory cannot be accessed
     * @throws InstantiationException if the configured preconditioner factory cannot be instantiated
     */
    public ConjugateGradient(ConfigBundle configBundle)
            throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        maxIter = configBundle.getInt(CG_MAX_ITER_KEY, CG_MAX_ITER_DEFAULT);
        relTol = configBundle.getDouble(CG_REL_TOL_KEY, CG_REL_TOL_DEFAULT);
        absTol = configBundle.getDouble(CG_ABS_TOL_KEY, CG_ABS_TOL_DEFAULT);
        divTol = configBundle.getDouble(CG_DIV_TOL_KEY, CG_DIV_TOL_DEFAULT);
        preconditionerFactory =
                (PreconditionerFactory) configBundle.getFactory(PRECONDITIONER_KEY, PRECONDITIONER_DEFAULT);
        monitor = new DefaultDoubleIterationMonitor(maxIter, relTol, absTol, divTol);
        monitor.setIterationReporter(new TraceReporter());
    }

    /**
     * Prepares the solver for {@code conicProgram}: allocates the solution vector
     * (sized by the number of rows of the program's constraint matrix), creates a new
     * CG solver, and obtains a preconditioner from the configured factory.
     * <p>
     * Any matrix set by an earlier {@link #setA} is discarded, so {@code setA} must be
     * called again before {@link #solve}.
     *
     * @param conicProgram the program whose normal systems will be solved
     */
    @Override // edu.umd.cs.psl.optimizer.conic.ipm.solver.NormalSystemSolver
    public void setConicProgram(ConicProgram conicProgram) {
        x = new DenseDoubleMatrix1D(conicProgram.getA().rows());
        cg = new DoubleCG(x);
        cg.setIterationMonitor(monitor);
        preconditioner = preconditionerFactory.getPreconditioner(conicProgram);
        cg.setPreconditioner(preconditioner);
        // A matrix from a previous program may have the wrong dimensions, and the new
        // preconditioner has never been given it; force a fresh setA() before solving.
        A = null;
    }

    /**
     * Sets the system matrix for subsequent {@link #solve} calls and hands it to the
     * preconditioner.
     *
     * @param sparseCCDoubleMatrix2D the normal-system matrix
     * @throws IllegalStateException if {@link #setConicProgram} has not been called
     */
    @Override // edu.umd.cs.psl.optimizer.conic.ipm.solver.NormalSystemSolver
    public void setA(SparseCCDoubleMatrix2D sparseCCDoubleMatrix2D) {
        requireConicProgram();
        A = sparseCCDoubleMatrix2D;
        preconditioner.setMatrix(sparseCCDoubleMatrix2D);
    }

    /**
     * Solves {@code A x = b} by conjugate gradient, starting from {@code x = 0}, and
     * overwrites {@code b} with the solution.
     *
     * @param doubleMatrix1D the right-hand side {@code b}; replaced in place by the solution
     * @throws IllegalStateException if {@link #setConicProgram} or {@link #setA} has not been called
     * @throws IllegalArgumentException if CG does not converge; the solver's
     *         {@link IterativeSolverDoubleNotConvergedException} is attached as the cause
     */
    @Override // edu.umd.cs.psl.optimizer.conic.ipm.solver.NormalSystemSolver
    public void solve(DoubleMatrix1D doubleMatrix1D) {
        requireConicProgram();
        if (A == null) {
            throw new IllegalStateException("setA() must be called before solve().");
        }
        x.assign(0.0d);
        try {
            cg.solve(A, doubleMatrix1D, x);
        } catch (IterativeSolverDoubleNotConvergedException e) {
            throw new IllegalArgumentException(
                    "Conjugate gradient did not converge after " + monitor.iterations() + " iterations.", e);
        }
        log.debug("Solved in {} iterations.", monitor.iterations());
        doubleMatrix1D.assign(x);
    }

    /** Throws {@link IllegalStateException} unless {@link #setConicProgram} has been called. */
    private void requireConicProgram() {
        if (cg == null) {
            throw new IllegalStateException("setConicProgram() must be called first.");
        }
    }

    /** Iteration reporter that trace-logs the residual every {@link #TRACE_INTERVAL} iterations. */
    private static final class TraceReporter implements DoubleIterationReporter {
        @Override
        public void monitor(double residual, DoubleMatrix1D currentIterate, int iteration) {
            monitor(residual, iteration);
        }

        @Override
        public void monitor(double residual, int iteration) {
            if (iteration % TRACE_INTERVAL == 0) {
                log.trace("Res. at itr {}: {}", iteration, residual);
            }
        }
    }
}
