package edu.umd.cs.psl.evaluation.debug;

import com.google.common.collect.HashBiMap;
import de.mathnbits.io.BasicUserInteraction;
import edu.umd.cs.psl.database.Database;
import edu.umd.cs.psl.database.DatabaseQuery;
import edu.umd.cs.psl.database.ResultList;
import edu.umd.cs.psl.model.argument.Term;
import edu.umd.cs.psl.model.argument.Variable;
import edu.umd.cs.psl.model.atom.Atom;
import edu.umd.cs.psl.model.atom.AtomCache;
import edu.umd.cs.psl.model.atom.GroundAtom;
import edu.umd.cs.psl.model.atom.QueryAtom;
import edu.umd.cs.psl.model.kernel.GroundCompatibilityKernel;
import edu.umd.cs.psl.model.kernel.GroundKernel;
import edu.umd.cs.psl.model.predicate.Predicate;
import edu.umd.cs.psl.model.predicate.PredicateFactory;
import edu.umd.cs.psl.model.predicate.StandardPredicate;
import edu.umd.cs.psl.util.database.Queries;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/umd/cs/psl/evaluation/debug/CmdDebugger.class */
/**
 * Interactive console debugger for inspecting the atoms stored in a {@link Database}.
 *
 * <p>Commands are read one line at a time from the console:
 * <ul>
 *   <li>{@code q} - exit the debugger</li>
 *   <li>{@code h} - print help</li>
 *   <li>{@code all <predicate>} - list all atoms of a predicate</li>
 *   <li>{@code atom <predicate> <arg>+} - list atoms of a predicate matching the given arguments</li>
 *   <li>{@code <number>} - show details of the atom with that handle from the latest listing</li>
 * </ul>
 *
 * <p>Every listing assigns numeric handles (starting at 1) to the atoms it prints. Entering one
 * of those numbers shows that atom's details and its registered ground kernels, which in turn
 * assigns fresh handles to the atoms those kernels touch.
 *
 * <p>Instances keep unsynchronized mutable state, and the class shares one static
 * {@link DecimalFormat}. Output goes to {@code System.out}.
 */
public class CmdDebugger implements Debugger {
    private static final String QUIT = "q";
    private static final String QUERY_PREDICATES = "all";
    private static final String QUERY_ATOM = "atom";
    private static final String HELP = "h";
    /** Formats kernel incompatibility values with at most two fraction digits. */
    private static final DecimalFormat VALUE_FORMATTER = new DecimalFormat("#.##");

    private final Database db;
    private final AtomCache cache;
    private final PredicateFactory predicateFactory = PredicateFactory.getFactory();
    /** Handles shown by the most recent listing; {@code null} until something has been listed. */
    private Map<Integer, Atom> atomHandles;

    /**
     * Creates a debugger over the given database.
     *
     * @param database the database to query; its atom cache is used to resolve query results
     */
    public CmdDebugger(Database database) {
        this.db = database;
        this.cache = database.getAtomCache();
    }

    /** Runs the read-eval-print loop until the user enters {@code q}. */
    @Override // edu.umd.cs.psl.evaluation.debug.Debugger
    public void start() {
        println("Debugger started. Enter 'q' to exit and 'h' for help.");
        while (true) {
            String cmd = readCmd();
            if (cmd.equalsIgnoreCase(QUIT)) {
                println("Exiting Debugger.");
                return;
            }
            Integer handle = parseHandle(cmd);
            if (handle != null) {
                showHandle(handle);
            } else if (hasCommandPrefix(cmd, QUERY_PREDICATES)) {
                queryPredicate(cmd.substring(QUERY_PREDICATES.length()).trim());
            } else if (hasCommandPrefix(cmd, QUERY_ATOM)) {
                queryAtom(cmd.substring(QUERY_ATOM.length()).trim());
            } else if (cmd.equalsIgnoreCase(HELP)) {
                printHelp();
            } else {
                error("Unrecognized command!");
            }
        }
    }

    /** Returns the command as an atom handle, or {@code null} if it is not an integer. */
    private static Integer parseHandle(String cmd) {
        try {
            return Integer.valueOf(cmd);
        } catch (NumberFormatException notANumber) {
            // Expected for keyword commands; the caller falls through to keyword matching.
            return null;
        }
    }

    /** Prints the atom registered under {@code handle} in the current context, or an error. */
    private void showHandle(Integer handle) {
        if (atomHandles == null) {
            error("No atom handles defined in context!");
        } else if (atomHandles.containsKey(handle)) {
            printAtom(atomHandles.get(handle));
        } else {
            error("Atom handle [" + handle + "] not defined in context!");
        }
    }

    /**
     * Checks whether {@code cmd} starts with {@code keyword} as a whole word.
     *
     * <p>The comparison ignores case, and the keyword must be followed by whitespace or the end
     * of input, so {@code "allfoo"} is not taken as {@code "all foo"}.
     */
    private static boolean hasCommandPrefix(String cmd, String keyword) {
        int len = keyword.length();
        if (!cmd.regionMatches(true, 0, keyword, 0, len)) {
            return false;
        }
        return cmd.length() == len || Character.isWhitespace(cmd.charAt(len));
    }

    /** Prints an atom's details and, for ground atoms, the ground kernels registered on it. */
    private void printAtom(Atom atom) {
        println(AtomPrinter.atomDetails(atom));
        if (atom instanceof GroundAtom) {
            printGroundKernels(((GroundAtom) atom).getRegisteredGroundKernels());
        }
    }

    /**
     * Renders a ground kernel as text.
     *
     * <p>For a compatibility kernel, the text has its incompatibility appended as {@code " V=x"}.
     */
    private static String formatGroundKernel(GroundKernel kernel) {
        String text = kernel.toString();
        if (kernel instanceof GroundCompatibilityKernel) {
            text += " V=" + VALUE_FORMATTER.format(((GroundCompatibilityKernel) kernel).getIncompatibility());
        }
        return text;
    }

    /**
     * Prints each kernel followed by the atoms it affects, numbering those atoms as new handles.
     *
     * <p>An atom that appears in several kernels keeps one handle. The handles replace the
     * current context, even when {@code kernels} is empty.
     */
    private void printGroundKernels(Collection<GroundKernel> kernels) {
        HashBiMap<Integer, Atom> handles = HashBiMap.create();
        int nextHandle = 1;
        for (GroundKernel kernel : kernels) {
            String kernelText = formatGroundKernel(kernel);
            StringBuilder affected = new StringBuilder("--> Affected Atoms: ");
            boolean first = true;
            for (GroundAtom atom : kernel.getAtoms()) {
                Integer handle = handles.inverse().get(atom);
                if (handle == null) {
                    handle = Integer.valueOf(nextHandle++);
                    handles.put(handle, atom);
                }
                String tag = " [" + handle + "]";
                // Strings are immutable: the annotated text must be assigned back, otherwise
                // the handle tags never show up in the printed kernel.
                kernelText = kernelText.replace(atom.toString(), atom.toString() + tag);
                if (!first) {
                    affected.append(" , ");
                }
                affected.append(AtomPrinter.atomDetails(atom)).append(tag);
                first = false;
            }
            println(kernelText);
            println(affected.toString());
        }
        atomHandles = handles;
    }

    /**
     * Looks up a predicate by name and prints an error if it is unknown.
     *
     * @return the predicate, or {@code null} if the factory returned none
     */
    private Predicate lookupPredicate(String name) {
        // NOTE(review): unclear whether getPredicate returns null or throws for unknown names;
        // both cases are handled (callers catch IllegalArgumentException).
        Predicate predicate = predicateFactory.getPredicate(name);
        if (predicate == null) {
            error("Unknown predicate: " + name);
        }
        return predicate;
    }

    /** Handles {@code all <predicate>}: lists every cached atom of the predicate. */
    private void queryPredicate(String predicateName) {
        try {
            Predicate predicate = lookupPredicate(predicateName);
            if (predicate != null) {
                printAtoms(getConsideredAtoms(predicate));
            }
        } catch (IllegalArgumentException e) {
            error(e.getMessage());
        }
    }

    /**
     * Prints the atoms, numbered from 1, and makes those numbers the current handles.
     *
     * <p>If {@code atoms} is empty, prints a notice and leaves the previous handles in place.
     */
    private void printAtoms(List<GroundAtom> atoms) {
        if (atoms.isEmpty()) {
            println("No atoms found for query");
            return;
        }
        Map<Integer, Atom> handles = new HashMap<Integer, Atom>(atoms.size());
        int handle = 1;
        for (GroundAtom atom : atoms) {
            println(AtomPrinter.atomDetails(atom) + "  [" + handle + "]");
            handles.put(Integer.valueOf(handle), atom);
            handle++;
        }
        atomHandles = handles;
    }

    /**
     * Handles {@code atom <predicate> <arg>+}.
     *
     * <p>Arguments are separated by whitespace. Tokens that parse as integers are passed as
     * {@link Integer}; all others are passed as strings.
     */
    private void queryAtom(String query) {
        // Split on runs of whitespace so repeated spaces do not produce empty arguments.
        String[] tokens = query.split("\\s+");
        if (tokens.length < 2) {
            error("Invalid atom query!");
            return;
        }
        Object[] args = new Object[tokens.length - 1];
        for (int i = 1; i < tokens.length; i++) {
            args[i - 1] = parseArgument(tokens[i]);
        }
        try {
            Predicate predicate = lookupPredicate(tokens[0]);
            if (predicate != null) {
                printAtoms(getConsideredAtoms(predicate, Queries.convertArguments(db, predicate, args)));
            }
        } catch (IllegalArgumentException e) {
            error(e.getMessage());
        }
    }

    /** Returns the token as an {@link Integer} if it parses as one, otherwise the token itself. */
    private static Object parseArgument(String token) {
        try {
            return Integer.valueOf(token);
        } catch (NumberFormatException notAnInteger) {
            return token;
        }
    }

    /** Returns all cached atoms of the predicate by querying with one fresh variable per argument. */
    private List<GroundAtom> getConsideredAtoms(Predicate predicate) {
        Term[] terms = new Term[predicate.getArity()];
        for (int i = 0; i < terms.length; i++) {
            terms[i] = new Variable("Arg_" + i);
        }
        return getConsideredAtoms(predicate, terms);
    }

    /**
     * Runs a database query for the predicate and arguments and resolves each result row.
     *
     * <p>Each row is resolved through the atom cache. Rows with no cached atom are skipped.
     *
     * @throws IllegalArgumentException if {@code predicate} is not a {@link StandardPredicate}
     */
    private List<GroundAtom> getConsideredAtoms(Predicate predicate, Term[] terms) {
        if (!(predicate instanceof StandardPredicate)) {
            throw new IllegalArgumentException("Only StandardPredicates can be retrieved.");
        }
        ResultList results = db.executeQuery(new DatabaseQuery(new QueryAtom(predicate, terms)));
        List<GroundAtom> atoms = new ArrayList<GroundAtom>(results.size());
        for (int i = 0; i < results.size(); i++) {
            GroundAtom atom = cache.getCachedAtom(new QueryAtom(predicate, results.get(i)));
            if (atom != null) {
                atoms.add(atom);
            }
        }
        return atoms;
    }

    /** Prints a summary of the supported commands. */
    private void printHelp() {
        println("'q'\t\t\t-- exit the debugger");
        println("'h'\t\t\t-- display this help");
        println("'all <predicate>'\t-- display all atoms of that predicate");
        println("'atom <predicate> <atomArgument>+'\t-- display all atoms of that predicate with matching arguments. Note that the number of arguments must match the predicate arity. Use * as an argument wildcard.");
        println("'<number>'\t\t\t-- display atom with this number handle presented on the current screen");
    }

    /** Prints the prompt and reads one trimmed command line. */
    private String readCmd() {
        return readCmd(null);
    }

    /**
     * Prints the {@code >>} prompt, followed by {@code prompt} if it is non-null, and reads one line.
     *
     * @return the trimmed line, or the quit command if no line was read
     */
    private String readCmd(String prompt) {
        System.out.print(">> ");
        if (prompt != null) {
            System.out.print(prompt + " ");
        }
        String line = BasicUserInteraction.readline();
        // NOTE(review): assumes readline() yields null at end of input - confirm. Treating that as
        // quit ends the loop cleanly instead of throwing a NullPointerException from trim().
        return line == null ? QUIT : line.trim();
    }

    private void println(String str) {
        System.out.println(str);
    }

    private void error(String str) {
        println("ERROR: " + str);
    }
}
