/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.evaluation;

import java.io.InputStreamReader;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.evaluation.Prediction;
import weka.classifiers.evaluation.TwoClassStats;
import weka.classifiers.functions.Logistic;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;

public class ThresholdCurve
implements RevisionHandler {
    public static final String RELATION_NAME = "ThresholdCurve";
    public static final String TRUE_POS_NAME = "True Positives";
    public static final String FALSE_NEG_NAME = "False Negatives";
    public static final String FALSE_POS_NAME = "False Positives";
    public static final String TRUE_NEG_NAME = "True Negatives";
    public static final String FP_RATE_NAME = "False Positive Rate";
    public static final String TP_RATE_NAME = "True Positive Rate";
    public static final String PRECISION_NAME = "Precision";
    public static final String RECALL_NAME = "Recall";
    public static final String FALLOUT_NAME = "Fallout";
    public static final String FMEASURE_NAME = "FMeasure";
    public static final String SAMPLE_SIZE_NAME = "Sample Size";
    public static final String LIFT_NAME = "Lift";
    public static final String THRESHOLD_NAME = "Threshold";

    public Instances getCurve(FastVector predictions) {
        if (predictions.size() == 0) {
            return null;
        }
        return this.getCurve(predictions, ((NominalPrediction)predictions.elementAt(0)).distribution().length - 1);
    }

    public Instances getCurve(FastVector predictions, int classIndex) {
        if (predictions.size() == 0 || ((NominalPrediction)predictions.elementAt(0)).distribution().length <= classIndex) {
            return null;
        }
        double totPos = 0.0;
        double totNeg = 0.0;
        double[] probs = this.getProbabilities(predictions, classIndex);
        for (int i = 0; i < probs.length; ++i) {
            NominalPrediction pred = (NominalPrediction)predictions.elementAt(i);
            if (pred.actual() == Prediction.MISSING_VALUE) {
                System.err.println(this.getClass().getName() + " Skipping prediction with missing class value");
                continue;
            }
            if (pred.weight() < 0.0) {
                System.err.println(this.getClass().getName() + " Skipping prediction with negative weight");
                continue;
            }
            if (pred.actual() == (double)classIndex) {
                totPos += pred.weight();
                continue;
            }
            totNeg += pred.weight();
        }
        Instances insts = this.makeHeader();
        int[] sorted = Utils.sort(probs);
        TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0.0, 0.0);
        double threshold = 0.0;
        double cumulativePos = 0.0;
        double cumulativeNeg = 0.0;
        for (int i = 0; i < sorted.length; ++i) {
            NominalPrediction pred;
            if (i == 0 || probs[sorted[i]] > threshold) {
                tc.setTruePositive(tc.getTruePositive() - cumulativePos);
                tc.setFalseNegative(tc.getFalseNegative() + cumulativePos);
                tc.setFalsePositive(tc.getFalsePositive() - cumulativeNeg);
                tc.setTrueNegative(tc.getTrueNegative() + cumulativeNeg);
                threshold = probs[sorted[i]];
                insts.add(this.makeInstance(tc, threshold));
                cumulativePos = 0.0;
                cumulativeNeg = 0.0;
                if (i == sorted.length - 1) break;
            }
            if ((pred = (NominalPrediction)predictions.elementAt(sorted[i])).actual() == Prediction.MISSING_VALUE) {
                System.err.println(this.getClass().getName() + " Skipping prediction with missing class value");
                continue;
            }
            if (pred.weight() < 0.0) {
                System.err.println(this.getClass().getName() + " Skipping prediction with negative weight");
                continue;
            }
            if (pred.actual() == (double)classIndex) {
                cumulativePos += pred.weight();
                continue;
            }
            cumulativeNeg += pred.weight();
        }
        return insts;
    }

    public static double getNPointPrecision(Instances tcurve, int n) {
        if (!RELATION_NAME.equals(tcurve.relationName()) || tcurve.numInstances() == 0) {
            return Double.NaN;
        }
        int recallInd = tcurve.attribute(RECALL_NAME).index();
        int precisInd = tcurve.attribute(PRECISION_NAME).index();
        double[] recallVals = tcurve.attributeToDoubleArray(recallInd);
        int[] sorted = Utils.sort(recallVals);
        double isize = 1.0 / (double)(n - 1);
        double psum = 0.0;
        for (int i = 0; i < n; ++i) {
            int pos = ThresholdCurve.binarySearch(sorted, recallVals, (double)i * isize);
            double recall = recallVals[sorted[pos]];
            double precis = tcurve.instance(sorted[pos]).value(precisInd);
            while (pos != 0 && pos < sorted.length - 1) {
                double recall2;
                if ((recall2 = recallVals[sorted[++pos]]) == recall) continue;
                double precis2 = tcurve.instance(sorted[pos]).value(precisInd);
                double slope = (precis2 - precis) / (recall2 - recall);
                double offset = precis - recall * slope;
                precis = isize * (double)i * slope + offset;
                break;
            }
            psum += precis;
        }
        return psum / (double)n;
    }

    public static double getROCArea(Instances tcurve) {
        int n = tcurve.numInstances();
        if (!RELATION_NAME.equals(tcurve.relationName()) || n == 0) {
            return Double.NaN;
        }
        int tpInd = tcurve.attribute(TRUE_POS_NAME).index();
        int fpInd = tcurve.attribute(FALSE_POS_NAME).index();
        double[] tpVals = tcurve.attributeToDoubleArray(tpInd);
        double[] fpVals = tcurve.attributeToDoubleArray(fpInd);
        double area = 0.0;
        double cumNeg = 0.0;
        double totalPos = tpVals[0];
        double totalNeg = fpVals[0];
        for (int i = 0; i < n; ++i) {
            double cin;
            double cip;
            if (i < n - 1) {
                cip = tpVals[i] - tpVals[i + 1];
                cin = fpVals[i] - fpVals[i + 1];
            } else {
                cip = tpVals[n - 1];
                cin = fpVals[n - 1];
            }
            area += cip * (cumNeg + 0.5 * cin);
            cumNeg += cin;
        }
        return area /= totalNeg * totalPos;
    }

    public static int getThresholdInstance(Instances tcurve, double threshold) {
        if (!RELATION_NAME.equals(tcurve.relationName()) || tcurve.numInstances() == 0 || threshold < 0.0 || threshold > 1.0) {
            return -1;
        }
        if (tcurve.numInstances() == 1) {
            return 0;
        }
        double[] tvals = tcurve.attributeToDoubleArray(tcurve.numAttributes() - 1);
        int[] sorted = Utils.sort(tvals);
        return ThresholdCurve.binarySearch(sorted, tvals, threshold);
    }

    private static int binarySearch(int[] index, double[] vals, double target) {
        int lo = 0;
        int hi = index.length - 1;
        while (hi - lo > 1) {
            int mid = lo + (hi - lo) / 2;
            double midval = vals[index[mid]];
            if (target > midval) {
                lo = mid;
                continue;
            }
            if (target < midval) {
                hi = mid;
                continue;
            }
            while (mid > 0 && vals[index[mid - 1]] == target) {
                --mid;
            }
            return mid;
        }
        return lo;
    }

    private double[] getProbabilities(FastVector predictions, int classIndex) {
        double[] probs = new double[predictions.size()];
        for (int i = 0; i < probs.length; ++i) {
            NominalPrediction pred = (NominalPrediction)predictions.elementAt(i);
            probs[i] = pred.distribution()[classIndex];
        }
        return probs;
    }

    private Instances makeHeader() {
        FastVector<Attribute> fv = new FastVector<Attribute>();
        fv.addElement(new Attribute(TRUE_POS_NAME));
        fv.addElement(new Attribute(FALSE_NEG_NAME));
        fv.addElement(new Attribute(FALSE_POS_NAME));
        fv.addElement(new Attribute(TRUE_NEG_NAME));
        fv.addElement(new Attribute(FP_RATE_NAME));
        fv.addElement(new Attribute(TP_RATE_NAME));
        fv.addElement(new Attribute(PRECISION_NAME));
        fv.addElement(new Attribute(RECALL_NAME));
        fv.addElement(new Attribute(FALLOUT_NAME));
        fv.addElement(new Attribute(FMEASURE_NAME));
        fv.addElement(new Attribute(SAMPLE_SIZE_NAME));
        fv.addElement(new Attribute(LIFT_NAME));
        fv.addElement(new Attribute(THRESHOLD_NAME));
        return new Instances(RELATION_NAME, fv, 100);
    }

    private Instance makeInstance(TwoClassStats tc, double prob) {
        int count = 0;
        double[] vals = new double[13];
        vals[count++] = tc.getTruePositive();
        vals[count++] = tc.getFalseNegative();
        vals[count++] = tc.getFalsePositive();
        vals[count++] = tc.getTrueNegative();
        vals[count++] = tc.getFalsePositiveRate();
        vals[count++] = tc.getTruePositiveRate();
        vals[count++] = tc.getPrecision();
        vals[count++] = tc.getRecall();
        vals[count++] = tc.getFallout();
        vals[count++] = tc.getFMeasure();
        double ss = (tc.getTruePositive() + tc.getFalsePositive()) / (tc.getTruePositive() + tc.getFalsePositive() + tc.getTrueNegative() + tc.getFalseNegative());
        vals[count++] = ss;
        double expectedByChance = ss * (tc.getTruePositive() + tc.getFalseNegative());
        vals[count++] = expectedByChance < 1.0 ? Utils.missingValue() : tc.getTruePositive() / expectedByChance;
        vals[count++] = prob;
        return new DenseInstance(1.0, vals);
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5987 $");
    }

    public static void main(String[] args) {
        try {
            Instances inst = new Instances(new InputStreamReader(System.in));
            inst.setClassIndex(inst.numAttributes() - 1);
            ThresholdCurve tc = new ThresholdCurve();
            EvaluationUtils eu = new EvaluationUtils();
            Logistic classifier = new Logistic();
            FastVector predictions = new FastVector();
            for (int i = 0; i < 2; ++i) {
                eu.setSeed(i);
                predictions.appendElements(eu.getCVPredictions(classifier, inst, 10));
            }
            Instances result = tc.getCurve(predictions);
            System.out.println(result);
        }
        catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}

