package edu.umd.cs.psl.parser;

import edu.umd.cs.psl.database.DataStore;
import edu.umd.cs.psl.model.Model;
import edu.umd.cs.psl.model.argument.ArgumentType;
import edu.umd.cs.psl.model.argument.IntegerAttribute;
import edu.umd.cs.psl.model.argument.StringAttribute;
import edu.umd.cs.psl.model.argument.Term;
import edu.umd.cs.psl.model.argument.Variable;
import edu.umd.cs.psl.model.atom.QueryAtom;
import edu.umd.cs.psl.model.formula.Conjunction;
import edu.umd.cs.psl.model.formula.Disjunction;
import edu.umd.cs.psl.model.formula.Formula;
import edu.umd.cs.psl.model.formula.Negation;
import edu.umd.cs.psl.model.formula.Rule;
import edu.umd.cs.psl.model.kernel.predicateconstraint.DomainRangeConstraintKernel;
import edu.umd.cs.psl.model.kernel.predicateconstraint.DomainRangeConstraintType;
import edu.umd.cs.psl.model.kernel.predicateconstraint.SymmetryConstraintKernel;
import edu.umd.cs.psl.model.kernel.rule.CompatibilityRuleKernel;
import edu.umd.cs.psl.model.kernel.rule.ConstraintRuleKernel;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.model.predicate.PredicateFactory;
import edu.umd.cs.psl.model.predicate.SpecialPredicate;
import edu.umd.cs.psl.model.predicate.StandardPredicate;
import edu.umd.cs.psl.parser.PSLParser;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.antlr.v4.runtime.ANTLRFileStream;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/parser/PSLModelLoader.class */
/**
 * ANTLR parse-tree visitor that turns a parsed PSL program into a {@link Model}.
 *
 * <p>Predicate definitions are created through the shared {@link PredicateFactory} and registered
 * with the supplied {@link DataStore}; kernel and constraint declarations are added to the model.
 * Visiting a definition node returns {@code null}; visiting an expression node returns the
 * {@link Formula} it denotes.
 */
public class PSLModelLoader extends PSLBaseVisitor<Formula> {
    private static final Logger log = LoggerFactory.getLogger(PSLModelLoader.class);

    private final Model model;
    private final DataStore ds;
    /**
     * Variables seen so far, keyed by name. The map is never cleared, so every occurrence of a
     * name anywhere in the program visited by this loader resolves to the same {@link Variable}.
     */
    private final Map<String, Variable> varMap = new HashMap<String, Variable>();
    private final PredicateFactory pf = PredicateFactory.getFactory();

    /**
     * @param model model that kernels are added to
     * @param dataStore store that new predicates are registered with and that resolves
     *     {@code UniqueID} constants
     */
    public PSLModelLoader(Model model, DataStore dataStore) {
        this.model = model;
        this.ds = dataStore;
    }

    /** Returns the model this loader adds kernels to. */
    public Model getModel() {
        return this.model;
    }

    /**
     * Creates a standard predicate from a predicate definition and registers it with the data
     * store.
     *
     * @return always {@code null}
     * @throws UnsupportedOperationException if an argument type is not one of
     *     {@code UniqueID}, {@code String}, {@code Double} or {@code Integer}
     */
    @Override
    public Formula visitPredicateDefinition(PSLParser.PredicateDefinitionContext ctx) {
        String name = ctx.predicate().getText();
        ArgumentType[] types = new ArgumentType[ctx.argumentType().size()];
        int i = 0;
        for (PSLParser.ArgumentTypeContext typeContext : ctx.argumentType()) {
            types[i++] = parseArgumentType(typeContext.getText());
        }
        pf.createStandardPredicate(name, types);
        StandardPredicate predicate = pf.getPredicate(name);
        ds.registerPredicate(predicate);
        log.debug("Created predicate {}", predicate);
        return null;
    }

    /** Maps the textual name of an argument type to its {@link ArgumentType}. */
    private static ArgumentType parseArgumentType(String typeName) {
        if ("UniqueID".equals(typeName)) {
            return ArgumentType.UniqueID;
        }
        if ("String".equals(typeName)) {
            return ArgumentType.String;
        }
        if ("Double".equals(typeName)) {
            return ArgumentType.Double;
        }
        if ("Integer".equals(typeName)) {
            return ArgumentType.Integer;
        }
        throw new UnsupportedOperationException("Unknown argument type " + typeName);
    }

    /**
     * Builds the formula for an expression node.
     *
     * <p>An atom becomes a {@link QueryAtom}; otherwise the node is a compound built from its
     * sub-expressions. {@code THEN} yields {@code Rule(lhs, rhs)} while {@code IMPLIEDBY} swaps
     * the operands. The {@code SYMMETRIC} and {@code NOTEQUAL} forms produce atoms over the
     * {@link SpecialPredicate#NonSymmetric} and {@link SpecialPredicate#NotEqual} special
     * predicates with both arguments treated as variables. Any other form (e.g. parentheses)
     * returns the formula of its single sub-expression.
     *
     * @throws IllegalArgumentException if an atom names a predicate the factory does not know
     */
    @Override
    public Formula visitExpression(PSLParser.ExpressionContext ctx) {
        if (ctx.atom() != null) {
            String predicateName = ctx.atom().predicate().getText();
            Predicate predicate = pf.getPredicate(predicateName);
            if (predicate == null) {
                // Fail with a clear message instead of an NPE (or a QueryAtom with no predicate).
                throw new IllegalArgumentException("Unknown predicate " + predicateName);
            }
            Term[] terms = new Term[ctx.atom().argument().size()];
            int i = 0;
            for (PSLParser.ArgumentContext argument : ctx.atom().argument()) {
                terms[i] = toTerm(argument, predicate, i);
                i++;
            }
            return new QueryAtom(predicate, terms);
        }
        if (ctx.AND() != null) {
            return new Conjunction(new Formula[] {visit(ctx.expression(0)), visit(ctx.expression(1))});
        }
        if (ctx.OR() != null) {
            return new Disjunction(new Formula[] {visit(ctx.expression(0)), visit(ctx.expression(1))});
        }
        if (ctx.THEN() != null) {
            return new Rule(visit(ctx.expression(0)), visit(ctx.expression(1)));
        }
        if (ctx.IMPLIEDBY() != null) {
            // "a <- b" is the rule b -> a.
            return new Rule(visit(ctx.expression(1)), visit(ctx.expression(0)));
        }
        if (ctx.NOT() != null) {
            return new Negation(visit(ctx.expression(0)));
        }
        if (ctx.SYMMETRIC() != null) {
            return new QueryAtom(SpecialPredicate.NonSymmetric, new Term[] {
                    getVariable(ctx.argument(0).getText()), getVariable(ctx.argument(1).getText())});
        }
        if (ctx.NOTEQUAL() != null) {
            return new QueryAtom(SpecialPredicate.NotEqual, new Term[] {
                    getVariable(ctx.argument(0).getText()), getVariable(ctx.argument(1).getText())});
        }
        return visit(ctx.expression(0));
    }

    /**
     * Converts one atom argument into a term.
     *
     * <p>Variables are resolved through {@link #getVariable}. A constant in a {@code UniqueID}
     * position is resolved by the data store. Otherwise, a string constant has its first and last
     * characters (the quotes) stripped, and any other constant is parsed as an integer.
     */
    private Term toTerm(PSLParser.ArgumentContext argument, Predicate predicate, int position) {
        if (argument.variable() != null) {
            return getVariable(argument.getText());
        }
        PSLParser.ConstantContext constant = argument.constant();
        String text = constant.getText();
        if (predicate.getArgumentType(position) == ArgumentType.UniqueID) {
            return ds.getUniqueID(text);
        }
        if (constant.strConstant() != null) {
            return new StringAttribute(text.substring(1, text.length() - 1));
        }
        return new IntegerAttribute(Integer.valueOf(text));
    }

    /** Returns the cached {@link Variable} for {@code str}, creating and caching it if absent. */
    private Variable getVariable(String str) {
        Variable variable = varMap.get(str);
        if (variable == null) {
            variable = new Variable(str);
            varMap.put(str, variable);
        }
        return variable;
    }

    /**
     * Adds a kernel for a weighted or constraint rule to the model.
     *
     * <p>A {@code CONSTRAINT} weight yields a {@link ConstraintRuleKernel}. Otherwise the numeric
     * weight and the presence of {@code SQUARED} configure a {@link CompatibilityRuleKernel}.
     *
     * @return always {@code null}
     */
    @Override
    public Formula visitKernel(PSLParser.KernelContext ctx) {
        PSLParser.WeightContext weight = ctx.weight();
        Formula formula = visit(ctx.expression());
        if (weight.CONSTRAINT() != null) {
            model.addKernel(new ConstraintRuleKernel(formula));
        } else {
            double numericWeight = Double.parseDouble(weight.NUMBER().getText());
            boolean squared = ctx.SQUARED() != null;
            model.addKernel(new CompatibilityRuleKernel(formula, numericWeight, squared));
        }
        return null;
    }

    /**
     * Adds a predicate constraint kernel to the model.
     *
     * <p>{@code Symmetric} adds only a {@link SymmetryConstraintKernel}. The functional variants
     * add only a {@link DomainRangeConstraintKernel} of the matching type.
     *
     * @return always {@code null}
     * @throws UnsupportedOperationException if the constraint type is not recognised
     */
    @Override
    public Formula visitConstraint(PSLParser.ConstraintContext ctx) {
        StandardPredicate predicate = pf.getPredicate(ctx.predicate().getText());
        String constraintType = ctx.constraintType().getText();
        if ("Symmetric".equals(constraintType)) {
            // Previously this fell through and also added a DomainRangeConstraintKernel with a
            // null type; a symmetry constraint is not a domain/range constraint.
            model.addKernel(new SymmetryConstraintKernel(predicate));
            return null;
        }
        model.addKernel(new DomainRangeConstraintKernel(predicate, parseDomainRangeConstraintType(constraintType)));
        return null;
    }

    /** Maps a functional-constraint keyword to its {@link DomainRangeConstraintType}. */
    private static DomainRangeConstraintType parseDomainRangeConstraintType(String typeName) {
        if ("Functional".equals(typeName)) {
            return DomainRangeConstraintType.Functional;
        }
        if ("InverseFunctional".equals(typeName)) {
            return DomainRangeConstraintType.InverseFunctional;
        }
        if ("PartialFunctional".equals(typeName)) {
            return DomainRangeConstraintType.PartialFunctional;
        }
        if ("PartialInverseFunctional".equals(typeName)) {
            return DomainRangeConstraintType.PartialInverseFunctional;
        }
        throw new UnsupportedOperationException("Unknown constraint type " + typeName);
    }

    /**
     * Parses the PSL program in file {@code str} and loads it into a new model.
     *
     * <p>If the file cannot be read, the error is logged and the (empty) model is still returned.
     *
     * @param str path of the PSL program file
     * @param dataStore store that new predicates are registered with
     * @return the loaded model
     */
    public static Model loadModel(String str, DataStore dataStore) {
        Model model = new Model();
        try {
            ParseTree program = new PSLParser(new CommonTokenStream(new PSLLexer(new ANTLRFileStream(str)))).program();
            PSLModelLoader loader = new PSLModelLoader(model, dataStore);
            loader.visit(program);
            log.debug(loader.getModel().toString());
        } catch (IOException e) {
            log.error("Could not read PSL model file " + str, e);
        }
        return model;
    }

    /**
     * Writes every predicate known to the {@link PredicateFactory} (one per line), followed by
     * {@code model.toString()}, to the file {@code str}, creating parent directories as needed.
     *
     * <p>I/O errors are logged rather than thrown; the file is always closed.
     *
     * @param str path of the output file
     * @param model model to write
     */
    public static void outputModel(String str, Model model) {
        File file = new File(str);
        File parent = file.getParentFile();
        if (parent != null) {
            parent.mkdirs();
        }
        BufferedWriter writer = null;
        try {
            writer = new BufferedWriter(new FileWriter(file));
            for (Object predicate : PredicateFactory.getFactory().getPredicates()) {
                writer.write(String.valueOf(predicate) + "\n");
            }
            writer.write(model.toString());
        } catch (IOException e) {
            log.error("Could not write PSL model file " + str, e);
        } finally {
            // Closing the BufferedWriter also closes the underlying FileWriter.
            if (writer != null) {
                try {
                    writer.close();
                } catch (IOException e) {
                    log.error("Could not close PSL model file " + str, e);
                }
            }
        }
    }
}
