package ciir.umass.edu.learning.tree;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * LambdaMART learning-to-rank model: a boosted ensemble of {@link RegressionTree}s.
 *
 * <p>Each boosting round in {@link #learn()} does the following:
 * <ul>
 *   <li>computes a pseudo-response ("lambda") and a weight for every training document from the
 *       scorer's pairwise swap changes;</li>
 *   <li>fits a regression tree to the pseudo-responses;</li>
 *   <li>sets every leaf output to (sum of pseudo-responses) / (sum of weights) over the leaf's
 *       samples;</li>
 *   <li>adds the tree, scaled by {@link #learningRate}, to the ensemble.</li>
 * </ul>
 * When validation data is present, the ensemble is finally rolled back to the round with the best
 * validation score.
 *
 * <p>All training documents are flattened into {@link #martSamples}. {@link #modelScores},
 * {@link #pseudoResponses} and {@link #weights} share that flat index, and each rank list occupies
 * one contiguous range of it.
 *
 * <p>Hyper-parameters are {@code public static} fields, so they are shared by every instance in the
 * JVM.
 *
 * <p>Originally recovered from {@code ciir/umass/edu/learning/tree/LambdaMART.class}.
 */
public class LambdaMART extends Ranker {
    /** Maximum number of boosting rounds (trees) trained by {@link #learn()}. */
    public static int nTrees = 1000;
    /** Shrinkage factor applied to every tree when it is added to the ensemble. */
    public static float learningRate = 0.1f;
    /**
     * Maximum number of split-threshold candidates per feature.
     *
     * <p>A feature with more distinct values than this gets an evenly spaced grid between its
     * minimum and maximum. A negative value (conventionally -1) keeps every distinct value.
     */
    public static int nThreshold = 256;
    /** Training stops once more than this many rounds have passed since the best validation round. */
    public static int nRoundToStopEarly = 100;
    /** Number of leaves requested for each regression tree. */
    public static int nTreeLeaves = 10;
    /** Minimum leaf support passed to each {@link RegressionTree}. */
    public static int minLeafSupport = 1;
    /** {@code System.gc()} is requested every {@code gcCycle} rounds; values {@code <= 0} disable it. */
    public static int gcCycle = 100;

    /**
     * Sentinel meaning "no best validation round yet".
     *
     * <p>It is large enough that early stopping never triggers and the final rollback keeps every
     * tree.
     */
    private static final int NO_BEST_MODEL = Integer.MAX_VALUE - 2;
    /** {@link Worker} mode: score the model on a range of validation rank lists. */
    private static final int WORKER_SCORE_VALIDATION = 3;
    /** {@link Worker} mode: score the model on a range of training rank lists. */
    private static final int WORKER_SCORE_TRAINING = 4;

    /** Per-feature split-threshold candidates, each array terminated by {@code Float.MAX_VALUE}. */
    protected float[][] thresholds = null;
    /** The model; {@code null} until {@link #learn()} or {@link #load(String)} runs. */
    protected Ensemble ensemble = null;
    /** Current model score of every flattened training sample. */
    protected double[] modelScores = null;
    /** Current model score of every validation sample, indexed [rank list][position]. */
    protected double[][] modelScoresOnValidation = null;
    /** Index of the last tree of the best-scoring ensemble on validation data. */
    protected int bestModelOnValidation = NO_BEST_MODEL;
    /** All training samples, flattened across rank lists. */
    protected DataPoint[] martSamples = null;
    /** Per-feature sample order by feature value; populated and released within {@link #init()}. */
    protected int[][] sortedIdx = null;
    /** Feature histogram passed to every regression tree. */
    protected FeatureHistogram hist = null;
    /** Per-sample gradient ("lambda") for the current round. */
    protected double[] pseudoResponses = null;
    /** Per-sample second-order weight for the current round. */
    protected double[] weights = null;

    /**
     * Computes pseudo-responses for the training rank lists {@code [rlStart, rlEnd]}.
     *
     * <p>The first sample of rank list {@code rlStart} sits at flat index {@code martStart}.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public class LambdaComputationWorker implements Runnable {
        LambdaMART ranker;
        int rlStart;
        int rlEnd;
        int martStart;

        LambdaComputationWorker(LambdaMART owner, int firstRankList, int lastRankList, int firstSample) {
            this.ranker = owner;
            this.rlStart = firstRankList;
            this.rlEnd = lastRankList;
            this.martStart = firstSample;
        }

        @Override // java.lang.Runnable
        public void run() {
            this.ranker.computePseudoResponses(this.rlStart, this.rlEnd, this.martStart);
        }
    }

    /** Sorts samples by each feature in the index range {@code [start, end]} of {@code features}. */
    class SortWorker implements Runnable {
        LambdaMART ranker;
        int start;
        int end;

        SortWorker(LambdaMART owner, int firstFeature, int lastFeature) {
            this.ranker = owner;
            this.start = firstFeature;
            this.end = lastFeature;
        }

        @Override // java.lang.Runnable
        public void run() {
            this.ranker.sortSamplesByFeature(this.start, this.end);
        }
    }

    /**
     * Scores the current model on a range of rank lists and stores the sum in {@link #score}.
     *
     * <p>The 3-argument constructor scores validation lists. The 4-argument constructor scores
     * training lists starting at flat sample index {@code martStart}.
     */
    class Worker implements Runnable {
        LambdaMART ranker;
        int rlStart;
        int rlEnd;
        int martStart = -1;
        int type;
        float score = 0.0f;

        Worker(LambdaMART owner, int firstRankList, int lastRankList) {
            this.type = WORKER_SCORE_VALIDATION;
            this.ranker = owner;
            this.rlStart = firstRankList;
            this.rlEnd = lastRankList;
        }

        Worker(LambdaMART owner, int firstRankList, int lastRankList, int firstSample) {
            this.type = WORKER_SCORE_TRAINING;
            this.ranker = owner;
            this.rlStart = firstRankList;
            this.rlEnd = lastRankList;
            this.martStart = firstSample;
        }

        @Override // java.lang.Runnable
        public void run() {
            if (this.type == WORKER_SCORE_TRAINING) {
                this.score = this.ranker.computeModelScoreOnTraining(this.rlStart, this.rlEnd, this.martStart);
            } else if (this.type == WORKER_SCORE_VALIDATION) {
                this.score = this.ranker.computeModelScoreOnValidation(this.rlStart, this.rlEnd);
            }
        }
    }

    /** Creates an unconfigured ranker (used by {@link #mo3clone()}). */
    public LambdaMART() {
    }

    /**
     * Creates a ranker over the given training data.
     *
     * @param samples  training rank lists
     * @param features feature ids to use
     * @param scorer   metric optimised during training
     */
    public LambdaMART(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /**
     * Prepares training state. In order, it:
     * <ol>
     *   <li>flattens the samples;</li>
     *   <li>sorts them by every feature;</li>
     *   <li>builds split thresholds;</li>
     *   <li>zero-initialises validation scores when validation data is set;</li>
     *   <li>constructs the feature histogram.</li>
     * </ol>
     * The per-feature sort order is released afterwards.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void init() {
        PRINT("Initializing... ");
        flattenSamples();

        // Order the samples by each feature's value, in parallel when the pool has several threads.
        this.sortedIdx = new int[this.features.length][];
        MyThreadPool pool = MyThreadPool.getInstance();
        if (pool.size() == 1) {
            sortSamplesByFeature(0, this.features.length - 1);
        } else {
            int[] partition = pool.partition(this.features.length);
            for (int p = 0; p < partition.length - 1; p++) {
                pool.execute(new SortWorker(this, partition[p], partition[p + 1] - 1));
            }
            pool.await();
        }

        this.thresholds = new float[this.features.length][];
        for (int f = 0; f < this.features.length; f++) {
            this.thresholds[f] = buildThresholds(f);
        }

        if (this.validationSamples != null) {
            this.modelScoresOnValidation = new double[this.validationSamples.size()][];
            for (int r = 0; r < this.validationSamples.size(); r++) {
                // Freshly allocated arrays are all zero: the empty ensemble scores 0.
                this.modelScoresOnValidation[r] = new double[this.validationSamples.get(r).size()];
            }
        }

        this.hist = new FeatureHistogram();
        this.hist.construct(this.martSamples, this.pseudoResponses, this.sortedIdx, this.features, this.thresholds);
        // Within this class the sort order is only needed up to this point; release it.
        this.sortedIdx = null;
        System.gc();
        PRINTLN("[Done]");
    }

    /**
     * Copies every training sample into {@link #martSamples}, rank list after rank list.
     *
     * <p>It also allocates the per-sample score, pseudo-response and weight arrays. These start at
     * zero because Java zero-initialises new arrays.
     */
    private void flattenSamples() {
        int total = 0;
        for (int r = 0; r < this.samples.size(); r++) {
            total += this.samples.get(r).size();
        }
        this.martSamples = new DataPoint[total];
        this.modelScores = new double[total];
        this.pseudoResponses = new double[total];
        this.weights = new double[total];
        int offset = 0;
        for (int r = 0; r < this.samples.size(); r++) {
            RankList rl = this.samples.get(r);
            for (int j = 0; j < rl.size(); j++) {
                this.martSamples[offset + j] = rl.get(j);
            }
            offset += rl.size();
        }
    }

    /**
     * Builds the split-threshold candidates for {@code features[f]} using {@link #sortedIdx}.
     *
     * <p>If there are at most {@link #nThreshold} distinct values (or {@code nThreshold} is
     * negative), the candidates are those distinct values. Otherwise they form an evenly spaced grid
     * of {@code nThreshold} points starting at the minimum. Either way the array ends with
     * {@code Float.MAX_VALUE}.
     *
     * <p>NOTE(review): the duplicate skipping below assumes
     * {@code MergeSorter.sort(values, true)} yields ascending order; confirm.
     *
     * @param f index into {@code features}
     * @return threshold candidates for that feature
     */
    private float[] buildThresholds(int f) {
        int fid = this.features[f];
        int[] order = this.sortedIdx[f];
        int n = this.martSamples.length;
        float[] distinct = new float[n];
        int count = 0;
        float max = Float.NEGATIVE_INFINITY;
        float min = Float.MAX_VALUE;
        int j = 0;
        while (j < n) {
            float fv = this.martSamples[order[j]].getFeatureValue(fid);
            distinct[count++] = fv;
            if (max < fv) {
                max = fv;
            }
            if (min > fv) {
                min = fv;
            }
            // Skip every following sample whose value does not exceed fv.
            j++;
            while (j < n && this.martSamples[order[j]].getFeatureValue(fid) <= fv) {
                j++;
            }
        }

        if (nThreshold < 0 || count <= nThreshold) {
            float[] exact = Arrays.copyOf(distinct, count + 1);
            exact[count] = Float.MAX_VALUE;
            return exact;
        }

        float step = Math.abs(max - min) / nThreshold;
        float[] grid = new float[nThreshold + 1];
        grid[0] = min;
        for (int k = 1; k < nThreshold; k++) {
            grid[k] = grid[k - 1] + step;
        }
        grid[nThreshold] = Float.MAX_VALUE;
        return grid;
    }

    /**
     * Runs up to {@link #nTrees} boosting rounds and prints per-round training and validation
     * scores.
     *
     * <p>Training stops early once more than {@link #nRoundToStopEarly} rounds have passed since the
     * best validation round. At the end the ensemble is truncated to that best round. Requires
     * {@link #init()} to have run.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void learn() {
        this.ensemble = new Ensemble();
        PRINTLN("---------------------------------");
        PRINTLN("Training starts...");
        PRINTLN("---------------------------------");
        PRINTLN(new int[]{7, 9, 9}, new String[]{"#iter", this.scorer.name() + "-T", this.scorer.name() + "-V"});
        PRINTLN("---------------------------------");
        for (int m = 0; m < nTrees; m++) {
            PRINT(new int[]{7}, new String[]{(m + 1) + ""});

            // Gradients for this round, then refresh the histogram with them.
            computePseudoResponses();
            this.hist.update(this.pseudoResponses);

            RegressionTree rt = new RegressionTree(nTreeLeaves, this.martSamples, this.pseudoResponses, this.hist, minLeafSupport);
            rt.fit();
            this.ensemble.add(rt, learningRate);
            updateTreeOutput(rt);

            // Apply the new tree to the cached training scores, leaf by leaf.
            List<Split> leaves = rt.leaves();
            for (int l = 0; l < leaves.size(); l++) {
                Split leaf = leaves.get(l);
                for (int k : leaf.getSamples()) {
                    this.modelScores[k] += learningRate * leaf.getOutput();
                }
            }
            rt.clearSamples();

            // gcCycle <= 0 disables the periodic GC request (0 used to divide by zero).
            if (gcCycle > 0 && m % gcCycle == 0) {
                System.gc();
            }

            this.scoreOnTrainingData = computeModelScoreOnTraining();
            PRINT(new int[]{9}, new String[]{SimpleMath.round(this.scoreOnTrainingData, 4) + ""});

            if (this.validationSamples != null) {
                updateValidationScores(rt);
                double validationScore = computeModelScoreOnValidation();
                PRINT(new int[]{9}, new String[]{SimpleMath.round(validationScore, 4) + ""});
                if (validationScore > this.bestScoreOnValidationData) {
                    this.bestScoreOnValidationData = validationScore;
                    this.bestModelOnValidation = this.ensemble.treeCount() - 1;
                }
            }
            PRINTLN("");

            if (m - this.bestModelOnValidation > nRoundToStopEarly) {
                break;
            }
        }

        // Roll back to the best model seen on validation data (a no-op without validation data).
        while (this.ensemble.treeCount() > this.bestModelOnValidation + 1) {
            this.ensemble.remove(this.ensemble.treeCount() - 1);
        }

        this.scoreOnTrainingData = this.scorer.score(rank(this.samples));
        PRINTLN("---------------------------------");
        PRINTLN("Finished sucessfully.");
        PRINTLN(this.scorer.name() + " on training data: " + SimpleMath.round(this.scoreOnTrainingData, 4));
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(rank(this.validationSamples));
            PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        PRINTLN("---------------------------------");
    }

    /** Adds the newly fitted tree's scaled output to every cached validation score. */
    private void updateValidationScores(RegressionTree rt) {
        for (int r = 0; r < this.modelScoresOnValidation.length; r++) {
            RankList rl = this.validationSamples.get(r);
            double[] scores = this.modelScoresOnValidation[r];
            for (int j = 0; j < scores.length; j++) {
                scores[j] += learningRate * rt.eval(rl.get(j));
            }
        }
    }

    /** Scores a single data point with the current ensemble. */
    @Override // ciir.umass.edu.learning.Ranker
    public double eval(DataPoint dataPoint) {
        return this.ensemble.eval(dataPoint);
    }

    /**
     * Returns a new, untrained LambdaMART.
     *
     * <p>The decompiler renamed this method from {@code clone}; the name must match the declaration
     * it overrides in {@code Ranker}.
     */
    @Override // ciir.umass.edu.learning.Ranker
    public Ranker mo3clone() {
        return new LambdaMART();
    }

    /** Returns the ensemble's textual form (throws if no model has been trained or loaded). */
    @Override // ciir.umass.edu.learning.Ranker
    public String toString() {
        return this.ensemble.toString();
    }

    /** Returns the serialised model: a {@code ##}-prefixed parameter header, a blank line, then the ensemble. */
    @Override // ciir.umass.edu.learning.Ranker
    public String model() {
        StringBuilder out = new StringBuilder();
        out.append("## ").append(name()).append("\n");
        out.append("## No. of trees = ").append(nTrees).append("\n");
        out.append("## No. of leaves = ").append(nTreeLeaves).append("\n");
        out.append("## No. of threshold candidates = ").append(nThreshold).append("\n");
        out.append("## Learning rate = ").append(learningRate).append("\n");
        out.append("## Stop early = ").append(nRoundToStopEarly).append("\n");
        out.append("\n");
        out.append(toString());
        return out.toString();
    }

    /**
     * Loads a model written by {@link #model()}.
     *
     * <p>Blank lines and lines starting with {@code ##} are ignored. The remaining trimmed lines are
     * concatenated and parsed as an {@link Ensemble}. Errors are reported on stdout and leave the
     * current model unchanged. The file is now closed even when reading fails.
     *
     * @param fn path of the model file (read as ASCII)
     */
    @Override // ciir.umass.edu.learning.Ranker
    public void load(String fn) {
        try {
            StringBuilder body = new StringBuilder();
            BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(fn), "ASCII"));
            try {
                String line;
                while ((line = in.readLine()) != null) {
                    String trimmed = line.trim();
                    if (trimmed.length() != 0 && !trimmed.startsWith("##")) {
                        body.append(trimmed);
                    }
                }
            } finally {
                in.close();
            }
            this.ensemble = new Ensemble(body.toString());
            this.features = this.ensemble.getFeatures();
        } catch (Exception e) {
            System.out.println("Error in LambdaMART::load(): " + e.toString());
        }
    }

    /** Prints the hyper-parameters in use. */
    @Override // ciir.umass.edu.learning.Ranker
    public void printParameters() {
        PRINTLN("No. of trees: " + nTrees);
        PRINTLN("No. of leaves: " + nTreeLeaves);
        PRINTLN("No. of threshold candidates: " + nThreshold);
        PRINTLN("Min leaf support: " + minLeafSupport);
        PRINTLN("Learning rate: " + learningRate);
        PRINTLN("Stop early: " + nRoundToStopEarly + " rounds without performance gain on validation data");
    }

    @Override // ciir.umass.edu.learning.Ranker
    public String name() {
        return "LambdaMART";
    }

    /** Returns the current ensemble ({@code null} before training or loading). */
    public Ensemble getEnsemble() {
        return this.ensemble;
    }

    /**
     * Recomputes {@link #pseudoResponses} and {@link #weights} for every training rank list.
     *
     * <p>When the shared thread pool has more than one thread, the rank lists are partitioned across
     * it. Each worker touches only its own contiguous range of the flat arrays.
     */
    protected void computePseudoResponses() {
        Arrays.fill(this.pseudoResponses, 0.0d);
        Arrays.fill(this.weights, 0.0d);
        MyThreadPool pool = MyThreadPool.getInstance();
        if (pool.size() == 1) {
            computePseudoResponses(0, this.samples.size() - 1, 0);
            return;
        }
        int[] partition = pool.partition(this.samples.size());
        int offset = 0;
        for (int p = 0; p < partition.length - 1; p++) {
            int first = partition[p];
            int last = partition[p + 1] - 1;
            pool.execute(new LambdaComputationWorker(this, first, last, offset));
            // Advance the flat offset past this partition's samples for the next worker.
            for (int r = first; r <= last; r++) {
                offset += this.samples.get(r).size();
            }
        }
        pool.await();
    }

    /**
     * Accumulates pseudo-responses and weights for the training rank lists {@code [start, end]}.
     *
     * <p>For each rank list, documents are ordered by current model score. For every pair where the
     * first document has the higher label and the absolute swap change {@code |d|} is non-zero:
     * <ul>
     *   <li>{@code rho = 1 / (1 + exp(s_hi - s_lo))};</li>
     *   <li>{@code rho*|d|} is added to the higher-labelled document's response and subtracted from
     *       the other's;</li>
     *   <li>{@code rho*(1-rho)*|d|} is added to both weights.</li>
     * </ul>
     * Pairs whose positions both exceed {@code scorer.getK()} are not visited.
     *
     * @param start   first rank list (inclusive)
     * @param end     last rank list (inclusive)
     * @param current flat index of the first sample of rank list {@code start}
     */
    protected void computePseudoResponses(int start, int end, int current) {
        int cutoff = this.scorer.getK();
        for (int r = start; r <= end; r++) {
            RankList orig = this.samples.get(r);
            int[] idx = MergeSorter.sort(this.modelScores, current, (current + orig.size()) - 1, false);
            RankList rl = new RankList(orig, idx, current);
            double[][] changes = this.scorer.swapChange(rl);
            for (int j = 0; j < rl.size(); j++) {
                DataPoint p1 = rl.get(j);
                int mj = idx[j];
                for (int k = 0; k < rl.size(); k++) {
                    if (j > cutoff && k > cutoff) {
                        break;
                    }
                    DataPoint p2 = rl.get(k);
                    int mk = idx[k];
                    if (p1.getLabel() > p2.getLabel()) {
                        double delta = Math.abs(changes[j][k]);
                        if (delta > 0.0d) {
                            double rho = 1.0d / (1.0d + Math.exp(this.modelScores[mj] - this.modelScores[mk]));
                            double lambda = rho * delta;
                            this.pseudoResponses[mj] += lambda;
                            this.pseudoResponses[mk] -= lambda;
                            double w = rho * (1.0d - rho) * delta;
                            this.weights[mj] += w;
                            this.weights[mk] += w;
                        }
                    }
                }
            }
            current += orig.size();
        }
    }

    /**
     * Sets each leaf's output to the sum of its samples' pseudo-responses divided by the sum of
     * their weights, or 0 when the weights sum to 0.
     *
     * <p>Sums are accumulated in {@code float}, as in the reference implementation.
     */
    protected void updateTreeOutput(RegressionTree regressionTree) {
        List<Split> leaves = regressionTree.leaves();
        for (int l = 0; l < leaves.size(); l++) {
            Split leaf = leaves.get(l);
            float sumResponses = 0.0f;
            float sumWeights = 0.0f;
            for (int k : leaf.getSamples()) {
                sumResponses += this.pseudoResponses[k];
                sumWeights += this.weights[k];
            }
            leaf.setOutput(sumWeights == 0.0f ? 0.0f : sumResponses / sumWeights);
        }
    }

    /** Returns the permutation of {@code points} sorted by feature {@code fid} (via {@code MergeSorter.sort(..., true)}). */
    protected int[] sortSamplesByFeature(DataPoint[] points, int fid) {
        double[] values = new double[points.length];
        for (int j = 0; j < points.length; j++) {
            values[j] = points[j].getFeatureValue(fid);
        }
        return MergeSorter.sort(values, true);
    }

    /**
     * Re-orders training rank list {@code rlIndex} by its cached model scores.
     *
     * @param rlIndex rank list index
     * @param current flat index of its first sample
     */
    protected RankList rank(int rlIndex, int current) {
        RankList orig = this.samples.get(rlIndex);
        double[] scores = new double[orig.size()];
        System.arraycopy(this.modelScores, current, scores, 0, scores.length);
        return new RankList(orig, MergeSorter.sort(scores, false));
    }

    /** Mean metric score of the current model over all training rank lists. */
    protected float computeModelScoreOnTraining() {
        return computeModelScoreOnTraining(0, this.samples.size() - 1, 0) / this.samples.size();
    }

    /** Summed metric score over training rank lists {@code [start, end]}, whose first sample is at flat index {@code current}. */
    protected float computeModelScoreOnTraining(int start, int end, int current) {
        float total = 0.0f;
        int offset = current;
        for (int r = start; r <= end; r++) {
            total += this.scorer.score(rank(r, offset));
            offset += this.samples.get(r).size();
        }
        return total;
    }

    /** Mean metric score of the current model over all validation rank lists. */
    protected float computeModelScoreOnValidation() {
        return computeModelScoreOnValidation(0, this.validationSamples.size() - 1) / this.validationSamples.size();
    }

    /** Summed metric score over validation rank lists {@code [start, end]}, ordered by cached validation scores. */
    protected float computeModelScoreOnValidation(int start, int end) {
        float total = 0.0f;
        for (int r = start; r <= end; r++) {
            RankList rl = this.validationSamples.get(r);
            total += this.scorer.score(new RankList(rl, MergeSorter.sort(this.modelScoresOnValidation[r], false)));
        }
        return total;
    }

    /** Fills {@code sortedIdx[f]} for every feature index {@code f} in {@code [start, end]}. */
    protected void sortSamplesByFeature(int start, int end) {
        for (int f = start; f <= end; f++) {
            this.sortedIdx[f] = sortSamplesByFeature(this.martSamples, this.features[f]);
        }
    }
}
