package edu.umd.cs.psl.optimizer.conic.partition;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import edu.umd.cs.psl.config.ConfigBundle;
import edu.umd.cs.psl.optimizer.conic.program.Cone;
import edu.umd.cs.psl.optimizer.conic.program.ConeType;
import edu.umd.cs.psl.optimizer.conic.program.ConicProgram;
import edu.umd.cs.psl.optimizer.conic.program.ConicProgramEvent;
import edu.umd.cs.psl.optimizer.conic.program.ConicProgramListener;
import edu.umd.cs.psl.optimizer.conic.program.Entity;
import edu.umd.cs.psl.optimizer.conic.program.LinearConstraint;
import edu.umd.cs.psl.optimizer.conic.program.NonNegativeOrthantCone;
import edu.umd.cs.psl.optimizer.conic.program.SecondOrderCone;
import edu.umd.cs.psl.optimizer.conic.program.Variable;
import edu.umd.cs.psl.util.graph.Node;
import edu.umd.cs.psl.util.graph.Relationship;
import edu.umd.cs.psl.util.graph.memory.MemoryGraph;
import edu.umd.cs.psl.util.graph.partition.hierarchical.HyperPartitioning;
import edu.umd.cs.psl.util.graph.weight.RelationshipWeighter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:edu/umd/cs/psl/optimizer/conic/partition/HierarchicalPartitioner.class */
public abstract class HierarchicalPartitioner extends AbstractCompletePartitioner implements ConicProgramListener {
    protected BiMap<Cone, Node> coneMap;
    protected BiMap<LinearConstraint, Node> lcMap;
    protected Set<LinearConstraint> alwaysCutConstraints;
    protected Set<LinearConstraint> restrictedConstraints;
    protected int p;
    private static final String LC_REL = "lcRel";
    private static final Logger log = LoggerFactory.getLogger(HierarchicalPartitioner.class);
    private static final ArrayList<ConeType> supportedCones = new ArrayList<>(2);

    static {
        supportedCones.add(ConeType.NonNegativeOrthantCone);
        supportedCones.add(ConeType.SecondOrderCone);
    }

    public HierarchicalPartitioner(ConfigBundle configBundle) {
    }

    @Override // edu.umd.cs.psl.optimizer.conic.partition.AbstractCompletePartitioner, edu.umd.cs.psl.optimizer.conic.partition.CompletePartitioner
    public void setConicProgram(ConicProgram conicProgram) {
        if (this.program != null) {
            this.program.unregisterForConicProgramEvents(this);
        }
        super.setConicProgram(conicProgram);
        this.program.registerForConicProgramEvents(this);
    }

    @Override // edu.umd.cs.psl.optimizer.conic.partition.CompletePartitioner
    public boolean supportsConeTypes(Collection<ConeType> collection) {
        return supportedCones.containsAll(collection);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    @Override // edu.umd.cs.psl.optimizer.conic.partition.AbstractCompletePartitioner
    public void doPartition() {
        this.partitions.clear();
        int ceil = (int) Math.ceil(this.program.getNumLinearConstraints() / 5000.0d);
        List<List<Node>> list = null;
        this.alwaysCutConstraints = new HashSet();
        this.restrictedConstraints = new HashSet();
        HyperPartitioning hyperPartitioning = new HyperPartitioning();
        hyperPartitioning.setSize(ceil);
        this.p = 0;
        while (true) {
            boolean z = false;
            try {
                MemoryGraph memoryGraph = new MemoryGraph();
                memoryGraph.createRelationshipType(LC_REL);
                this.coneMap = HashBiMap.create();
                this.lcMap = HashBiMap.create();
                Iterator<Cone> it = this.program.getCones().iterator();
                while (it.hasNext()) {
                    this.coneMap.put(it.next(), memoryGraph.createNode());
                }
                HashSet hashSet = new HashSet();
                for (LinearConstraint linearConstraint : this.program.getConstraints()) {
                    Node createNode = memoryGraph.createNode();
                    this.lcMap.put(linearConstraint, createNode);
                    Iterator<Variable> it2 = linearConstraint.getVariables().keySet().iterator();
                    while (it2.hasNext()) {
                        hashSet.add(it2.next().getCone());
                    }
                    Iterator it3 = hashSet.iterator();
                    while (it3.hasNext()) {
                        createNode.createRelationship(LC_REL, (Node) this.coneMap.get((Cone) it3.next()));
                    }
                    hashSet.clear();
                }
                list = hyperPartitioning.partition(memoryGraph, memoryGraph.getNodeSnapshot(), new RelationshipWeighter() { // from class: edu.umd.cs.psl.optimizer.conic.partition.HierarchicalPartitioner.1
                    @Override // edu.umd.cs.psl.util.graph.weight.RelationshipWeighter
                    public double getWeight(Relationship relationship) {
                        if (!relationship.getRelationshipType().equals(HierarchicalPartitioner.LC_REL)) {
                            return Double.POSITIVE_INFINITY;
                        }
                        return HierarchicalPartitioner.this.getWeight((LinearConstraint) HierarchicalPartitioner.this.lcMap.inverse().get(relationship.getStart()), (Cone) HierarchicalPartitioner.this.coneMap.inverse().get(relationship.getEnd()));
                    }
                });
            } catch (IllegalArgumentException e) {
                log.debug("Caught illegal argument exception.");
                if (this.restrictedConstraints.size() <= 1) {
                    throw e;
                }
                int min = Math.min((int) Math.ceil(this.restrictedConstraints.size() / 3), this.restrictedConstraints.size() - 1);
                Iterator<LinearConstraint> it4 = this.restrictedConstraints.iterator();
                for (int i = 0; i < min; i++) {
                    it4.next();
                    it4.remove();
                }
                z = true;
            }
            log.trace("Partition finished. Checking for balance.");
            boolean z2 = true;
            if (!z && ceil > 1) {
                int i2 = 0;
                Iterator<List<Node>> it5 = list.iterator();
                while (it5.hasNext()) {
                    i2 += it5.next().size();
                }
                Iterator<List<Node>> it6 = list.iterator();
                while (true) {
                    if (!it6.hasNext()) {
                        break;
                    }
                    List<Node> next = it6.next();
                    if (next.size() > 2 * (i2 - next.size())) {
                        log.debug("{} > {}", Integer.valueOf(next.size()), Integer.valueOf(2 * (i2 - next.size())));
                        z2 = false;
                    }
                    if (!z2) {
                        z = true;
                        break;
                    }
                }
                if (!z2) {
                    z = true;
                    if (this.restrictedConstraints.size() > 1 && this.restrictedConstraints.size() > this.alwaysCutConstraints.size() / 2) {
                        Iterator<LinearConstraint> it7 = this.restrictedConstraints.iterator();
                        while (this.restrictedConstraints.size() > this.alwaysCutConstraints.size() / 1.5d) {
                            it7.next();
                            it7.remove();
                        }
                    }
                }
            }
            if (z) {
                log.debug("Redoing partition {}.", Integer.valueOf(this.p));
            } else {
                Vector vector = new Vector();
                for (int i3 = 0; i3 < list.size(); i3++) {
                    HashSet hashSet2 = new HashSet();
                    for (Node node : list.get(i3)) {
                        if (this.coneMap.containsValue(node)) {
                            hashSet2.add((Cone) this.coneMap.inverse().get(node));
                        }
                    }
                    vector.add(hashSet2);
                }
                ConicProgramPartition conicProgramPartition = new ConicProgramPartition(this.program, vector);
                log.debug("Size of cut constraints: {}", Integer.valueOf(conicProgramPartition.getCutConstraints().size()));
                this.partitions.add(conicProgramPartition);
                if (this.p == 0) {
                    this.alwaysCutConstraints.addAll(conicProgramPartition.getCutConstraints());
                    this.restrictedConstraints.addAll(conicProgramPartition.getCutConstraints());
                } else {
                    this.alwaysCutConstraints.retainAll(conicProgramPartition.getCutConstraints());
                    this.restrictedConstraints.clear();
                    this.restrictedConstraints.addAll(this.alwaysCutConstraints);
                }
                HashSet hashSet3 = new HashSet();
                ArrayList arrayList = new ArrayList();
                for (LinearConstraint linearConstraint2 : conicProgramPartition.getCutConstraints()) {
                    hashSet3.clear();
                    arrayList.clear();
                    Iterator<Variable> it8 = linearConstraint2.getVariables().keySet().iterator();
                    while (it8.hasNext()) {
                        Cone cone = it8.next().getCone();
                        hashSet3.add(conicProgramPartition.getElement(cone));
                        if (isSingleton(cone)) {
                            conicProgramPartition.removeCone(cone);
                            arrayList.add(cone);
                        }
                    }
                    if (arrayList.size() < hashSet3.size()) {
                        log.warn("Not enough singletons to cut constraint. Needed {}.", Integer.valueOf(hashSet3.size()));
                    }
                    Iterator it9 = hashSet3.iterator();
                    Iterator it10 = arrayList.iterator();
                    while (it10.hasNext()) {
                        conicProgramPartition.addCone((Cone) it10.next(), (Integer) it9.next());
                        if (!it9.hasNext()) {
                            it9 = hashSet3.iterator();
                        }
                    }
                }
                processAcceptedPartition();
                log.debug("Number of always cut constraints: {}", Integer.valueOf(this.alwaysCutConstraints.size()));
                this.p++;
            }
            if (this.alwaysCutConstraints.size() <= 0 && !z) {
                return;
            }
        }
    }

    @Override // edu.umd.cs.psl.optimizer.conic.program.ConicProgramListener
    public void notify(ConicProgram conicProgram, ConicProgramEvent conicProgramEvent, Entity entity, Object... objArr) {
    }

    protected abstract double getWeight(LinearConstraint linearConstraint, Cone cone);

    protected abstract void processAcceptedPartition();

    /* JADX INFO: Access modifiers changed from: protected */
    public boolean isSingleton(Cone cone) {
        if (cone instanceof NonNegativeOrthantCone) {
            return ((NonNegativeOrthantCone) cone).getVariable().getLinearConstraints().size() == 1;
        }
        if (!(cone instanceof SecondOrderCone)) {
            throw new IllegalStateException();
        }
        LinearConstraint linearConstraint = null;
        Iterator<Variable> it = ((SecondOrderCone) cone).getVariables().iterator();
        while (it.hasNext()) {
            Set<LinearConstraint> linearConstraints = it.next().getLinearConstraints();
            if (linearConstraints.size() > 1) {
                return false;
            }
            if (linearConstraints.size() == 1) {
                if (linearConstraint == null) {
                    linearConstraint = linearConstraints.iterator().next();
                } else if (!linearConstraint.equals(linearConstraints.iterator().next())) {
                    return false;
                }
            }
        }
        return true;
    }
}
