/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.client.solrj.io.eval;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.eval.KnnEvaluator;
import org.apache.solr.client.solrj.io.eval.ManyValueWorker;
import org.apache.solr.client.solrj.io.eval.Matrix;
import org.apache.solr.client.solrj.io.eval.MinMaxScaleEvaluator;
import org.apache.solr.client.solrj.io.eval.RecursiveObjectEvaluator;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;

/**
 * Evaluator for the {@code knnRegress} stream function.
 *
 * <p>{@link #doWork(Object...)} validates its arguments and returns a {@link KnnRegressionTuple}
 * that bundles the training observations, their outcomes, {@code k} and the distance measure, and
 * can then predict an outcome for a new feature vector from its {@code k} nearest observations.
 *
 * <p>Recognised named parameters (any other name is rejected with an {@link IOException}):
 * <ul>
 *   <li>{@code scale} - stored on the tuple and exposed through
 *       {@link KnnRegressionTuple#getScale()}.</li>
 *   <li>{@code robust} - when {@code true}, predictions use the median of the neighbours'
 *       outcomes instead of their mean.</li>
 * </ul>
 */
public class KnnRegressionEvaluator
extends RecursiveObjectEvaluator
implements ManyValueWorker {
    protected static final long serialVersionUID = 1L;
    /** Whether predictions use the median (true) or the mean (false) of the neighbours' outcomes. */
    private boolean robust = false;
    /** Value reported by the produced tuple's {@code getScale()}. */
    private boolean scale = false;

    /**
     * Creates the evaluator and reads the optional {@code scale} and {@code robust} named
     * parameters; each value is trimmed and parsed with {@link Boolean#parseBoolean(String)}.
     *
     * @param expression the parsed stream expression
     * @param factory the stream factory used to read the expression's named operands
     * @throws IOException if an unrecognised named parameter is present
     */
    public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
        super(expression, factory);
        List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
        for (StreamExpressionNamedParameter namedParam : namedParams) {
            String name = namedParam.getName();
            if ("scale".equals(name)) {
                this.scale = parseFlag(namedParam);
            } else if ("robust".equals(name)) {
                this.robust = parseFlag(namedParam);
            } else {
                throw new IOException("Unexpected named parameter:" + name);
            }
        }
    }

    /** Parses a named parameter's trimmed value as a boolean (only "true", any case, is true). */
    private static boolean parseFlag(StreamExpressionNamedParameter namedParam) {
        return Boolean.parseBoolean(namedParam.getParameter().toString().trim());
    }

    /**
     * Builds a {@link KnnRegressionTuple} from the evaluated arguments.
     *
     * @param values {@code [observations (Matrix), outcomes (List of Number), k (Number)]}, with an
     *     optional {@link DistanceMeasure} as the fourth value when exactly four values are given;
     *     otherwise {@link EuclideanDistance} is used
     * @return the regression tuple; its fields record {@code k}, the observation row count, the
     *     feature column count, the distance measure's simple class name and the robust/scale flags
     * @throws IOException if fewer than three values are given, a value has the wrong type,
     *     {@code k} is less than 1, the outcome count differs from the observation row count, or an
     *     outcome is not a number
     */
    @Override
    public Object doWork(Object ... values) throws IOException {
        if (values.length < 3) {
            throw new IOException("knnRegress expects atleast three parameters: an observation matrix, an outcomes vector and k.");
        }
        if (!(values[0] instanceof Matrix)) {
            throw new IOException("The first parameter for knnRegress should be the observation matrix.");
        }
        if (!(values[1] instanceof List)) {
            throw new IOException("The second parameter for knnRegress should be outcome array. ");
        }
        if (!(values[2] instanceof Number)) {
            throw new IOException("The third parameter for knnRegress should be k. ");
        }
        Matrix observations = (Matrix) values[0];
        List<?> outcomes = (List<?>) values[1];
        int k = ((Number) values[2]).intValue();

        DistanceMeasure distanceMeasure = new EuclideanDistance();
        if (values.length == 4) {
            if (!(values[3] instanceof DistanceMeasure)) {
                throw new IOException("The fourth parameter for knnRegress should be a distance measure. ");
            }
            distanceMeasure = (DistanceMeasure) values[3];
        }

        // A non-positive k would select no neighbours and make every prediction NaN.
        if (k < 1) {
            throw new IOException("The k parameter for knnRegress must be at least 1, but was " + k + ".");
        }
        // predict() indexes outcomes by observation row, so the sizes must line up exactly.
        int rowCount = observations.getRowCount();
        if (outcomes.size() != rowCount) {
            throw new IOException("The outcome array for knnRegress must have one value per observation row: expected "
                    + rowCount + " but found " + outcomes.size() + ".");
        }
        double[] outcomeData = toOutcomeArray(outcomes);

        Map<String, Object> map = new HashMap<>();
        map.put("k", k);
        map.put("observations", rowCount);
        map.put("features", observations.getColumnCount());
        map.put("distance", distanceMeasure.getClass().getSimpleName());
        map.put("robust", this.robust);
        map.put("scale", this.scale);
        return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, this.scale, this.robust);
    }

    /** Converts the outcome list to a primitive array, rejecting non-numeric elements. */
    private static double[] toOutcomeArray(List<?> outcomes) throws IOException {
        double[] data = new double[outcomes.size()];
        for (int i = 0; i < data.length; ++i) {
            Object outcome = outcomes.get(i);
            if (!(outcome instanceof Number)) {
                throw new IOException("The outcome array for knnRegress must contain only numbers, but element "
                        + i + " was " + outcome + ".");
            }
            data[i] = ((Number) outcome).doubleValue();
        }
        return data;
    }

    /**
     * Tuple holding a k-nearest-neighbour regression model: training observations, one outcome
     * per observation row, {@code k}, the distance measure and the scale/robust flags.
     *
     * <p>The {@code scale(...)} methods overwrite the stored scaled observations without any
     * synchronization and {@link #predict(double[])} reads them, so one instance should not be
     * shared across concurrent scale/predict sequences.
     */
    public static class KnnRegressionTuple
    extends Tuple {
        private final Matrix observations;
        /** Observations produced by the most recent scale(...) call; null until scale is called. */
        private Matrix scaledObservations;
        /** Outcome for each observation row, indexed like the rows of {@link #observations}. */
        private final double[] outcomes;
        private final int k;
        private final DistanceMeasure distanceMeasure;
        private final boolean scale;
        private final boolean robust;

        /**
         * @param observations training observations, one row per sample and one column per feature
         * @param outcomes outcome for each observation row
         * @param k number of neighbours used per prediction
         * @param distanceMeasure distance used for the neighbour search
         * @param map tuple fields passed to {@link Tuple}
         * @param scale value returned by {@link #getScale()}
         * @param robust when true, {@link #predict(double[])} returns the median instead of the mean
         */
        public KnnRegressionTuple(Matrix observations, double[] outcomes, int k, DistanceMeasure distanceMeasure, Map<?, ?> map, boolean scale, boolean robust) {
            super(map);
            this.observations = observations;
            this.outcomes = outcomes;
            this.k = k;
            this.distanceMeasure = distanceMeasure;
            this.scale = scale;
            this.robust = robust;
        }

        /**
         * Returns the {@code scale} flag this tuple was created with; presumably callers use it
         * to decide whether to call {@code scale(...)} before {@link #predict(double[])}.
         */
        public boolean getScale() {
            return this.scale;
        }

        /**
         * Scales a single predictor vector jointly with the training observations.
         *
         * <p>For every feature column, the column's observation values plus the predictor's
         * value for that feature are passed together to
         * {@code MinMaxScaleEvaluator.scale(values, 0.0, 1.0)}. The scaled observations are kept
         * for subsequent {@link #predict(double[])} calls and the scaled predictor is returned.
         *
         * @param predictors one value per feature column of the observation matrix
         * @return the scaled predictor vector
         * @throws IllegalArgumentException if {@code predictors.length} differs from the number of
         *     feature columns
         */
        public double[] scale(double[] predictors) {
            double[][] featureColumns = transpose(this.observations.getData());
            requireFeatureCount(featureColumns.length, predictors.length);
            double[] scaledPredictors = new double[predictors.length];
            for (int f = 0; f < featureColumns.length; ++f) {
                double[] column = featureColumns[f];
                // The predictor value is appended as the last element so it shares the column's scale.
                double[] combined = new double[column.length + 1];
                System.arraycopy(column, 0, combined, 0, column.length);
                combined[column.length] = predictors[f];
                double[] scaled = MinMaxScaleEvaluator.scale(combined, 0.0, 1.0);
                System.arraycopy(scaled, 0, column, 0, column.length);
                scaledPredictors[f] = scaled[column.length];
            }
            this.scaledObservations = new Matrix(transpose(featureColumns));
            return scaledPredictors;
        }

        /**
         * Scales a matrix of predictor rows jointly with the training observations.
         *
         * <p>For every feature column, the observations' values followed by the predictors'
         * values for that feature are passed together to
         * {@code MinMaxScaleEvaluator.scale(values, 0.0, 1.0)}. The scaled observations are kept
         * for subsequent {@link #predict(double[])} calls and the scaled predictors are returned.
         *
         * @param predictors predictor rows with the same number of columns as the observations
         * @return the scaled predictor matrix
         * @throws IllegalArgumentException if the predictor column count differs from the number
         *     of feature columns
         */
        public Matrix scale(Matrix predictors) {
            double[][] observationColumns = transpose(this.observations.getData());
            double[][] predictorColumns = transpose(predictors.getData());
            requireFeatureCount(observationColumns.length, predictorColumns.length);
            for (int f = 0; f < observationColumns.length; ++f) {
                double[] observationColumn = observationColumns[f];
                double[] predictorColumn = predictorColumns[f];
                double[] combined = new double[observationColumn.length + predictorColumn.length];
                System.arraycopy(observationColumn, 0, combined, 0, observationColumn.length);
                System.arraycopy(predictorColumn, 0, combined, observationColumn.length, predictorColumn.length);
                double[] scaled = MinMaxScaleEvaluator.scale(combined, 0.0, 1.0);
                System.arraycopy(scaled, 0, observationColumn, 0, observationColumn.length);
                System.arraycopy(scaled, observationColumn.length, predictorColumn, 0, predictorColumn.length);
            }
            this.scaledObservations = new Matrix(transpose(observationColumns));
            return new Matrix(transpose(predictorColumns));
        }

        /**
         * Predicts an outcome for one feature vector.
         *
         * <p>Searches the scaled observations (if a {@code scale(...)} call has produced them) or
         * else the original observations for the {@code k} nearest rows via
         * {@link KnnEvaluator#search}, then returns the 50th percentile (median) of those rows'
         * outcomes when {@code robust} is set, or their arithmetic mean otherwise.
         *
         * @param values the feature vector, already scaled by the caller if scaling is in use
         * @return the predicted outcome
         */
        public double predict(double[] values) {
            Matrix searchSpace = this.scaledObservations != null ? this.scaledObservations : this.observations;
            Matrix neighbours = KnnEvaluator.search(searchSpace, values, this.k, this.distanceMeasure);
            List<?> indexes = (List<?>) neighbours.getAttribute("indexes");

            double[] neighbourOutcomes = new double[indexes.size()];
            int i = 0;
            for (Object index : indexes) {
                neighbourOutcomes[i++] = this.outcomes[((Number) index).intValue()];
            }

            if (this.robust) {
                return new Percentile().evaluate(neighbourOutcomes, 50.0);
            }
            double sum = 0.0;
            for (double outcome : neighbourOutcomes) {
                sum += outcome;
            }
            return sum / (double) neighbourOutcomes.length;
        }

        /** Returns a new array holding the transpose of the rectangular array {@code data}. */
        private static double[][] transpose(double[][] data) {
            return new Array2DRowRealMatrix(data).transpose().getData();
        }

        /** Rejects predictors whose feature count differs from the observation matrix's. */
        private static void requireFeatureCount(int expected, int actual) {
            if (expected != actual) {
                throw new IllegalArgumentException("Expected " + expected
                        + " predictor features to match the observation matrix, but found " + actual + ".");
            }
        }
    }
}

