/*
 * Decompiled with CFR 0.152.
 */
package org.talend.dataquality.parsing.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.talend.dataquality.parsing.util.OverallMetrics;
import org.talend.dataquality.parsing.util.TagWiseMetrics;

/**
 * Evaluates sequence-labelling (CRF) predictions against reference tags and reports per-tag
 * precision / recall / F1 together with overall item accuracy, instance accuracy and macro /
 * micro averages.
 *
 * <p>Typical usage: {@link #init} with aligned prediction/reference lists, then
 * {@link #compute()}, then read the results through {@link #resultsAsString()},
 * {@link #getMetrics()} or {@link #getTagWiseMetricsMap()}. {@link #reset()} discards all data
 * and results.
 *
 * <p>All state lives in unsynchronized static fields, so evaluations must not run concurrently.
 */
public class CrfEvaluator {
    private static List<List<String>> predictions = new ArrayList<List<String>>();
    private static List<List<String>> references = new ArrayList<List<String>>();
    private static boolean hasComputed = false;
    private static final Logger LOGGER = LoggerFactory.getLogger(CrfEvaluator.class);
    private static Map<String, TagWiseMetrics> tagWiseMetricsMap = new HashMap<String, TagWiseMetrics>();
    private static OverallMetrics metrics = new OverallMetrics();

    /**
     * Appends a batch of predicted tag sequences and their reference sequences to the data that
     * the next {@link #compute()} call evaluates. Successive calls accumulate until
     * {@link #reset()} is called.
     *
     * @param predictions predicted tag sequences, one list per instance
     * @param references reference tag sequences, aligned index-by-index with {@code predictions}
     * @throws IllegalArgumentException if either list is null or empty, or if their sizes differ
     */
    public static void init(List<List<String>> predictions, List<List<String>> references) {
        if (predictions == null || predictions.isEmpty()) {
            throw new IllegalArgumentException("NULL or empty list of predictions");
        }
        if (references == null || references.isEmpty()) {
            throw new IllegalArgumentException("NULL or empty list of references");
        }
        if (references.size() != predictions.size()) {
            throw new IllegalArgumentException("the size of prediction list should equals to the size of real tag list");
        }
        CrfEvaluator.predictions.addAll(predictions);
        CrfEvaluator.references.addAll(references);
        // The instance count covers every instance added since the last reset().
        metrics.setInstTotalNum(CrfEvaluator.references.size());
    }

    /**
     * Discards all stored instances and every computed result, returning the evaluator to its
     * initial state.
     */
    public static void reset() {
        predictions.clear();
        references.clear();
        metrics.setInstTotalNum(0);
        clearComputedMetrics();
    }

    /**
     * Clears everything {@link #compute()} derives: the per-tag counters and all aggregate
     * figures except the instance count, which is owned by {@link #init} / {@link #reset()}.
     */
    private static void clearComputedMetrics() {
        hasComputed = false;
        tagWiseMetricsMap.clear();
        metrics.setNumTags(0);
        metrics.setItemTotalCorrect(0);
        metrics.setItemTotalNum(0);
        metrics.setItemTotalObservation(0);
        metrics.setItemTotalPrediction(0);
        metrics.setItemAccuracy(0.0);
        metrics.setInstTotalCorrect(0);
        metrics.setInstAccuracy(0.0);
        metrics.setMacroPrecision(0.0);
        metrics.setMacroRecall(0.0);
        metrics.setMacroFMeasure(0.0);
        metrics.setMicroPrecision(0.0);
        metrics.setMicroRecall(0.0);
        metrics.setMicroFMeasure(0.0);
    }

    /**
     * Computes per-tag and overall metrics over every instance passed to {@link #init} since the
     * last {@link #reset()}.
     *
     * <p>Results are rebuilt from scratch on each call, so calling this repeatedly (for example
     * after further {@code init} calls) never double-counts. An instance whose prediction or
     * reference list is null, or whose two lists differ in length, is logged and excluded from
     * the item and per-tag counts. It still counts in the instance-accuracy denominator. If no
     * tag could be counted at all, an error is logged and the results remain unavailable.
     */
    public static void compute() {
        clearComputedMetrics();
        for (int i = 0; i < metrics.getInstTotalNum(); ++i) {
            List<String> predTags = predictions.get(i);
            List<String> refTags = references.get(i);
            if (predTags == null || refTags == null) {
                LOGGER.warn("Exclude instance " + i + ": NULL list of predict or reference tags");
                continue;
            }
            if (predTags.size() != refTags.size()) {
                LOGGER.warn("Exclude instance " + i + ": {Predict tags = " + predTags + ", Reference tags = " + refTags + "}\nCaused by #Predicts not equals to #References (" + predTags.size() + " != " + refTags.size() + ")");
                continue;
            }
            accumulate(predTags, refTags);
        }
        if (tagWiseMetricsMap.isEmpty()) {
            LOGGER.error("Exit! Impossible to compute the metrics, all instances are invalid.");
            return;
        }
        computeTagWiseAndMacroMetrics();
        computeMicroAndAccuracyMetrics();
        hasComputed = true;
    }

    /**
     * Derives per-tag precision / recall / F1 from the raw counters, sums the item totals, and
     * stores the macro averages (divided by the number of distinct tags seen).
     */
    private static void computeTagWiseAndMacroMetrics() {
        double macroPrecision = 0.0;
        double macroRecall = 0.0;
        double macroFMeasure = 0.0;
        metrics.setNumTags(tagWiseMetricsMap.size());
        for (TagWiseMetrics tagWiseMetrics : tagWiseMetricsMap.values()) {
            int numCorrect = tagWiseMetrics.getNum_correct();
            int numPrediction = tagWiseMetrics.getNum_prediction();
            int numObservation = tagWiseMetrics.getNum_observation();
            metrics.setItemTotalCorrect(metrics.getItemTotalCorrect() + numCorrect);
            metrics.setItemTotalPrediction(metrics.getItemTotalPrediction() + numPrediction);
            metrics.setItemTotalObservation(metrics.getItemTotalObservation() + numObservation);
            if (numObservation == 0) {
                // A tag that only appears in predictions contributes to the micro totals above
                // but adds nothing to the macro sums (it is still part of the macro denominator).
                continue;
            }
            double precision = 0.0;
            if (numPrediction > 0) {
                precision = (double) numCorrect / (double) numPrediction;
                tagWiseMetrics.setPrecision(precision);
            }
            double recall = (double) numCorrect / (double) numObservation;
            tagWiseMetrics.setRecall(recall);
            // 2*correct / (pred + obs) equals the harmonic mean of precision and recall.
            double fMeasure = (double) (2 * numCorrect) / (double) (numPrediction + numObservation);
            tagWiseMetrics.setFmeasure(fMeasure);
            macroPrecision += precision;
            macroRecall += recall;
            macroFMeasure += fMeasure;
        }
        metrics.setMacroPrecision(macroPrecision / (double) metrics.getNumTags());
        metrics.setMacroRecall(macroRecall / (double) metrics.getNumTags());
        metrics.setMacroFMeasure(macroFMeasure / (double) metrics.getNumTags());
    }

    /**
     * Computes micro-averaged precision / recall / F1 and the item and instance accuracies from
     * the totals already stored in {@link #metrics}. Each ratio is skipped when its denominator is 0.
     */
    private static void computeMicroAndAccuracyMetrics() {
        if (metrics.getItemTotalPrediction() > 0) {
            metrics.setMicroPrecision((double) metrics.getItemTotalCorrect() / (double) metrics.getItemTotalPrediction());
        }
        if (metrics.getItemTotalObservation() > 0) {
            metrics.setMicroRecall((double) metrics.getItemTotalCorrect() / (double) metrics.getItemTotalObservation());
        }
        if (metrics.getItemTotalPrediction() + metrics.getItemTotalObservation() > 0) {
            metrics.setMicroFMeasure((double) (2 * metrics.getItemTotalCorrect()) / (double) (metrics.getItemTotalPrediction() + metrics.getItemTotalObservation()));
        }
        if (metrics.getItemTotalNum() > 0) {
            metrics.setItemAccuracy((double) metrics.getItemTotalCorrect() / (double) metrics.getItemTotalNum());
        }
        if (metrics.getInstTotalNum() > 0) {
            metrics.setInstAccuracy((double) metrics.getInstTotalCorrect() / (double) metrics.getInstTotalNum());
        }
    }

    /**
     * Adds one instance's aligned tags to the per-tag counters and the item count. The instance
     * counts as correct when every position matches.
     *
     * @param predTags predicted tags; same length as {@code refTags}
     * @param refTags reference tags
     */
    private static void accumulate(List<String> predTags, List<String> refTags) {
        boolean allCorrect = true;
        for (int i = 0; i < refTags.size(); ++i) {
            String ref = refTags.get(i);
            String pred = predTags.get(i);
            increaseTagWiseMetricsMap(ref, "num_observation");
            increaseTagWiseMetricsMap(pred, "num_prediction");
            // Null-safe comparison: a null reference tag must not throw.
            if (ref == null ? pred == null : ref.equals(pred)) {
                increaseTagWiseMetricsMap(ref, "num_correct");
            } else {
                allCorrect = false;
            }
            metrics.setItemTotalNum(metrics.getItemTotalNum() + 1);
        }
        if (allCorrect) {
            metrics.setInstTotalCorrect(metrics.getInstTotalCorrect() + 1);
        }
    }

    /**
     * Increments the named counter of {@code tag}, creating its metrics entry on first use.
     *
     * @param tag the tag whose counter is increased
     * @param numName counter name passed through to {@code TagWiseMetrics.increaseNum}
     */
    private static void increaseTagWiseMetricsMap(String tag, String numName) {
        tagWiseMetricsMap.computeIfAbsent(tag, key -> new TagWiseMetrics()).increaseNum(numName);
    }

    /**
     * Renders the computed metrics as a human-readable table.
     *
     * @return the formatted report, or an empty string (after logging an error) when
     *     {@link #compute()} has not completed successfully since the last reset
     */
    public static String resultsAsString() {
        if (!hasComputed) {
            LOGGER.error("The metrics haven't been computed or have been reset, please call Evaluator.compute()");
            return "";
        }
        String lineSep = System.getProperty("line.separator");
        StringBuilder outDisplay = new StringBuilder(lineSep);
        outDisplay.append("+---------------------------------------------------------------------------------------------------------+\n");
        outDisplay.append("|                                            Model Performance                                            |\n");
        outDisplay.append("+---------------------------------------------------------------------------------------------------------+\n");
        outDisplay.append(" Overall accuracy:");
        outDisplay.append(String.format(" - Item accuracy: %d / %d (%1.3f)%n", metrics.getItemTotalCorrect(), metrics.getItemTotalNum(), metrics.getItemAccuracy()));
        outDisplay.append(String.format(" - Instance accuracy: %d / %d (%1.3f)%n", metrics.getInstTotalCorrect(), metrics.getInstTotalNum(), metrics.getInstAccuracy()));
        outDisplay.append(lineSep);
        outDisplay.append(" Performance by label:\n ");
        outDisplay.append(String.format(" %15s%15s%15s%15s%15s%15s%15s%n", "   ", "Precision", "Recall", "F1-measure", "#Match", "#Predict", "#Ref"));
        tagWiseMetricsMap.forEach((tagName, tagMetrics) -> {
            if (tagMetrics != null) {
                outDisplay.append(String.format(" %15s%15.3f%15.3f%15.3f%15d%15d%15d%n", "\"" + tagName + "\"", tagMetrics.getPrecision(), tagMetrics.getRecall(), tagMetrics.getFmeasure(), tagMetrics.getNum_correct(), tagMetrics.getNum_prediction(), tagMetrics.getNum_observation()));
            }
        });
        outDisplay.append(String.format("%n %15s%15.3f%15.3f%15.3f%15s%15s%15s%n", "Macro-average", metrics.getMacroPrecision(), metrics.getMacroRecall(), metrics.getMacroFMeasure(), "-", "-", "-"));
        outDisplay.append(String.format(" %15s%15.3f%15.3f%15.3f%15s%15s%15s%n", "Micro-average", metrics.getMicroPrecision(), metrics.getMicroRecall(), metrics.getMicroFMeasure(), "-", "-", "-"));
        outDisplay.append("===========================================================================================================");
        outDisplay.append(lineSep);
        return outDisplay.toString();
    }

    /**
     * Logs and returns the overall metrics object (the live instance, not a copy).
     *
     * @return the shared {@code OverallMetrics} instance
     */
    public static OverallMetrics getMetrics() {
        LOGGER.info(metrics.toString());
        return metrics;
    }

    /**
     * Logs and returns the per-tag metrics map (the live map, not a copy).
     *
     * @return the shared tag-to-metrics map
     */
    public static Map<String, TagWiseMetrics> getTagWiseMetricsMap() {
        LOGGER.info(tagWiseMetricsMap.toString());
        return tagWiseMetricsMap;
    }
}

