package com.googlecode.rockit.app.learner;

import com.googlecode.rockit.app.Parameters;
import com.googlecode.rockit.app.grounder.StandardGrounder;
import com.googlecode.rockit.app.solver.StandardSolver;
import com.googlecode.rockit.conn.sql.MySQLConnector;
import com.googlecode.rockit.conn.sql.SQLQueryGenerator;
import com.googlecode.rockit.exception.ParseException;
import com.googlecode.rockit.exception.SolveException;
import com.googlecode.rockit.javaAPI.Model;
import com.googlecode.rockit.javaAPI.formulas.FormulaAbstract;
import com.googlecode.rockit.javaAPI.formulas.FormulaSoft;
import com.googlecode.rockit.parser.SyntaxReader;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Iterator;
import org.antlr.runtime.RecognitionException;

/* loaded from: input_file:com/googlecode/rockit/app/learner/VotedPerceptronLearner.class */
/**
 * Learns weights for the soft formulas of a {@link Model} using perceptron-style updates.
 *
 * <p>For every learning iteration and every training input, the learner:
 * <ol>
 *   <li>loads ground values via {@link SyntaxReader#getGroundValuesForLearning};</li>
 *   <li>grounds the model, including hidden predicates;</li>
 *   <li>runs cutting-plane inference;</li>
 *   <li>moves each soft formula's weight by
 *       {@code Parameters.LEARNING_RATE * (expected - inferred)}.</li>
 * </ol>
 * Both terms of the update are counts of true groundings obtained through SQL. Every
 * intermediate weight is recorded. After all iterations, each formula's weight is replaced by
 * the value of {@code FormulaForLearning.returnAverage(...)}, presumably the mean of the
 * recorded weights.
 */
public class VotedPerceptronLearner {
    /** Number of learning iterations used by {@link #learn} when a negative count is passed. */
    private static final int DEFAULT_ITERATIONS = 10;

    private Model model;
    // Soft formulas being learned; populated lazily on the first learning step.
    private ArrayList<FormulaForLearning> formulas = new ArrayList<>();
    // Output of the most recent inference run; not read anywhere in this class.
    private ArrayList<String> results = new ArrayList<>();

    /**
     * Creates a learner for the given model.
     *
     * @param model the model whose soft-formula weights will be learned
     * @throws SQLException declared for source compatibility; this constructor does not throw it
     */
    public VotedPerceptronLearner(Model model) throws SQLException {
        this.model = model;
    }

    /**
     * Performs one learning step on a single training input.
     *
     * <p>The step reloads ground values into the model and clears the database. It then grounds
     * the model (hidden predicates included) and runs inference. Finally it applies one weight
     * update to every learned soft formula.
     *
     * @param str training input passed to {@link SyntaxReader#getGroundValuesForLearning}
     *     (NOTE(review): presumably a file path — confirm)
     * @param mySQLConnector open database connection used for grounding and counting
     * @return the (possibly replaced) model
     */
    private Model learnIteration(String str, MySQLConnector mySQLConnector) throws SQLException, ParseException, SolveException, IOException, RecognitionException {
        this.model = new SyntaxReader().getGroundValuesForLearning(str, this.model);
        mySQLConnector.deleteAll();
        StandardGrounder standardGrounder = new StandardGrounder(this.model, mySQLConnector);
        standardGrounder.setGroundHiddenPredicates(true);
        standardGrounder.ground();
        if (this.formulas.isEmpty()) {
            initializeFormulas(mySQLConnector);
        }
        StandardSolver standardSolver = new StandardSolver(this.model, mySQLConnector);
        try {
            this.results = standardSolver.runCuttingPlaneInference();
        } finally {
            // Release the ILP connector even when inference fails.
            standardSolver.closeILPConnector();
        }
        updateWeights(mySQLConnector);
        return this.model;
    }

    /**
     * Wraps every soft formula of the model in a {@link FormulaForLearning}.
     *
     * <p>Each wrapper stores the formula's current true-grounding count (taken after grounding,
     * before inference) as its expected count.
     *
     * <p>NOTE(review): this runs only while {@code formulas} is empty, i.e. on the first training
     * input. With several training inputs, later inputs are compared against these same expected
     * counts. Confirm this is intended.
     */
    private void initializeFormulas(MySQLConnector mySQLConnector) throws SolveException, SQLException {
        for (FormulaAbstract formula : this.model.getFormulas()) {
            if (formula instanceof FormulaSoft) {
                FormulaSoft softFormula = (FormulaSoft) formula;
                this.formulas.add(new FormulaForLearning(softFormula, numberOfTrueGroundings(softFormula, mySQLConnector)));
            }
        }
    }

    /**
     * Applies one perceptron update to every learned formula.
     *
     * <p>The update uses the true-grounding count currently in the database, which presumably
     * reflects the inference result. Each new weight is also recorded for later averaging.
     */
    private void updateWeights(MySQLConnector mySQLConnector) throws SolveException, SQLException {
        if (Parameters.DEBUG_OUTPUT) {
            System.out.println("-- new weights --");
        }
        for (FormulaForLearning learned : this.formulas) {
            FormulaSoft formula = learned.getFormula();
            long inferredTrueGroundings = numberOfTrueGroundings(formula, mySQLConnector);
            double newWeight = nextWeightVP(formula.getWeight().doubleValue(), Parameters.LEARNING_RATE, inferredTrueGroundings, learned.getExpectedNumberOfTrueGroundings());
            formula.setWeight(Double.valueOf(newWeight));
            learned.addWeightForAverage(newWeight);
            if (Parameters.DEBUG_OUTPUT) {
                System.out.println(formula);
            }
        }
    }

    /**
     * Computes the standard perceptron update:
     * {@code currentWeight + learningRate * (expected - inferred)}.
     *
     * @param currentWeight the formula's current weight
     * @param learningRate step size
     * @param inferredTrueGroundings true-grounding count after inference
     * @param expectedTrueGroundings expected true-grounding count
     * @return the updated weight
     */
    private double nextWeightVP(double currentWeight, double learningRate, long inferredTrueGroundings, long expectedTrueGroundings) {
        double newWeight = currentWeight + (learningRate * (expectedTrueGroundings - inferredTrueGroundings));
        if (Parameters.DEBUG_OUTPUT) {
            System.out.println(String.valueOf(newWeight) + " = " + currentWeight + " + " + learningRate + " * (" + expectedTrueGroundings + " - " + inferredTrueGroundings + ")");
        }
        return newWeight;
    }

    /**
     * Variant of {@link #nextWeightVP} that applies a per-weight learning rate.
     *
     * <p>The learning rate is divided by the inferred true-grounding count when that count is
     * non-zero. Not referenced within this class.
     */
    private double nextWeightVPPW(double currentWeight, double learningRate, long inferredTrueGroundings, long expectedTrueGroundings) {
        double rate = learningRate;
        if (inferredTrueGroundings != 0) {
            if (Parameters.DEBUG_OUTPUT) {
                System.out.println("previous learn rate " + rate);
            }
            rate /= inferredTrueGroundings;
            if (Parameters.DEBUG_OUTPUT) {
                System.out.println("adapted learn rate =  " + rate);
            }
        }
        double newWeight = currentWeight + (rate * (expectedTrueGroundings - inferredTrueGroundings));
        if (Parameters.DEBUG_OUTPUT) {
            System.out.println(String.valueOf(newWeight) + " = " + currentWeight + " + " + rate + " * (" + expectedTrueGroundings + " - " + inferredTrueGroundings + ")");
        }
        return newWeight;
    }

    /**
     * Learns soft-formula weights over the given training inputs.
     *
     * <p>The database connection is always closed, even if a learning step fails.
     *
     * @param arrayList training inputs, each processed once per iteration, in list order
     * @param i number of passes over {@code arrayList}; a negative value selects
     *     {@value #DEFAULT_ITERATIONS}
     * @return the model with each learned formula's weight set to its averaged value
     */
    public Model learn(ArrayList<String> arrayList, int i) throws ParseException, IOException, RecognitionException, SolveException, SQLException {
        int iterations = i < 0 ? DEFAULT_ITERATIONS : i;
        MySQLConnector mySQLConnector = new MySQLConnector();
        try {
            for (int iteration = 1; iteration <= iterations; iteration++) {
                System.out.println("======== START LEARNING ITERATION " + iteration + " ==============");
                for (String trainingInput : arrayList) {
                    this.model = learnIteration(trainingInput, mySQLConnector);
                }
            }
        } finally {
            // Previously the connection leaked whenever a learning step threw.
            mySQLConnector.close();
        }
        averageWeights(iterations * arrayList.size());
        return this.model;
    }

    /**
     * Replaces each learned formula's weight with {@code returnAverage(numberOfUpdates)}.
     *
     * <p>When no updates were performed (zero iterations or no training inputs), the current
     * weights are kept. Averaging over zero would presumably divide by zero.
     */
    private void averageWeights(int numberOfUpdates) {
        if (numberOfUpdates <= 0) {
            return;
        }
        for (FormulaForLearning learned : this.formulas) {
            learned.getFormula().setWeight(Double.valueOf(learned.returnAverage(numberOfUpdates)));
        }
    }

    /** Counts the formula's groundings produced by its standard (non-negated) SQL statement. */
    private long numberOfTrueGroundings(FormulaSoft formulaSoft, MySQLConnector mySQLConnector) throws SolveException, SQLException {
        return mySQLConnector.executeLongQuery("SELECT count(*) " + eliminateFrom(SQLQueryGenerator.getSQLStatementWithoutSelect(formulaSoft, false, false)));
    }

    /**
     * Counts groundings using the generator's first flag set to {@code true}.
     *
     * <p>Presumably this counts the false groundings. Not referenced within this class.
     */
    private long numberOfFalseGroundings(FormulaSoft formulaSoft, MySQLConnector mySQLConnector) throws SolveException, SQLException {
        return mySQLConnector.executeLongQuery("SELECT count(*) " + eliminateFrom(SQLQueryGenerator.getSQLStatementWithoutSelect(formulaSoft, true, false)));
    }

    /**
     * Counts groundings using the statement without cutting-plane restrictions.
     *
     * <p>Presumably this counts all groundings. Not referenced within this class.
     */
    private long numberOfAllGroundings(FormulaSoft formulaSoft, MySQLConnector mySQLConnector) throws SolveException, SQLException, ParseException {
        return mySQLConnector.executeLongQuery("SELECT count(*) " + eliminateFrom(SQLQueryGenerator.getSQLStatementWithoutSelectWithoutCPI(formulaSoft, false, false)));
    }

    /**
     * Extracts the substring from the last {@code "FROM"} up to, but excluding, the first
     * {@code ")"}.
     *
     * <p>NOTE(review): {@code indexOf(")")} searches from the start of the string. If the first
     * {@code ")"} precedes the last {@code "FROM"}, or the string contains no {@code ")"} at all,
     * {@code substring} throws {@link StringIndexOutOfBoundsException}. Confirm the generated SQL
     * always has a {@code ")"} after its last {@code FROM}.
     */
    private String eliminateFrom(String str) {
        return str.substring(str.lastIndexOf("FROM"), str.indexOf(")"));
    }
}
