import { Dispatch, SetStateAction, useEffect, useState, MouseEvent } from "react";
import { ParsedValidationStats, ValidationStats } from "../../types/predictions.types";
import { convertDecimal, searchItems, sortByColumn } from "../../data/functions";
import { Header } from "../../types/all.types";
import { singleClassificationModelHeaders } from "../../data/models";
import { rgbColors, trColors } from "../../data/functions";


/** Metrics plotted on the bar and radar charts, in display order. */
const PLOTTED_METRICS = ['auc', 'acc', 'cohens_kappa', 'f1score', 'mcc', 'precision', 'recall', 'specificity'];

/** Graphs each method's carousel can page through with `next` / `back`. */
const METHOD_GRAPHS = ['overlay', 'singleRoc', 'truthTable'];

/** Per-method carousel state: `index` selects the currently shown entry of `graphs`. */
type MethodPlot = { name: string, index: number, graphs: string[] };

/** Case-insensitive ascending comparator on `method_name`. */
const byMethodName = (a: ParsedValidationStats, b: ParsedValidationStats): number => {
    const left = a.method_name.toLowerCase();
    const right = b.method_name.toLowerCase();
    if (left < right) return -1;
    if (left > right) return 1;
    return 0;
};

/**
 * Derives table rows, chart traces and per-method graph carousels from a list of
 * validation results.
 *
 * Each `valid_stats[i].model` is split on '/': segment 0 (of the first entry) is used
 * as the model name and segment 1 as the method name. When several entries share a
 * method name, only the first one is used.
 *
 * @param valid_stats - raw validation results; an empty/missing list leaves state untouched.
 * @returns table state (`stats`, `modifiedStats`, `headers`, `setHeaders`), chart traces
 *          (`barData`, `radarData`), carousel state and navigation (`methodPlots`,
 *          `next`, `back`), `modelName`, and `modifyRecords` for search + column sort.
 */
const useResultsetValidations = ({ valid_stats }: { valid_stats: ValidationStats[] }) => {
    const [ stats, setStats ] = useState<ParsedValidationStats[]>([]);
    const [ modifiedStats, setModifiedStats ] = useState<ParsedValidationStats[]>([]);
    const [ headers, setHeaders ] = useState<Header[]>([]);
    const [ barData, setBarData ] = useState<{x: string[], y: number[], name: string, type: string}[]>([]);
    const [ radarData, setRadarData ] = useState<{r: any[], theta: any[], name: string, type: string, fill: string}[]>([]);
    const [ methodPlots, setMethodPlots ] = useState<MethodPlot[]>([]);
    const [ modelName, setModelName ] = useState('');

    // Parse raw results into table rows and carousel state.
    useEffect(() => {
        if (!valid_stats || valid_stats.length === 0) return;

        setModelName(valid_stats[0].model.split('/')[0]);

        // Single pass: keep the first result seen for each method name.
        const firstByMethod = new Map<string, ValidationStats>();
        for (const stat of valid_stats) {
            const method = stat.model.split('/')[1];
            if (!firstByMethod.has(method)) {
                firstByMethod.set(method, stat);
            }
        }

        const records: ParsedValidationStats[] = [];
        const plots: MethodPlot[] = [];

        firstByMethod.forEach((match, method) => {
            const { metrics } = match;

            records.push({
                method_name: method,
                acc: convertDecimal(metrics.acc),
                auc: convertDecimal(metrics.auc),
                cohens_kappa: convertDecimal(metrics.cohens_kappa),
                f1score: convertDecimal(metrics.f1score),
                mcc: convertDecimal(metrics.mcc),
                precision: convertDecimal(metrics.precision),
                recall: convertDecimal(metrics.recall),
                specificity: convertDecimal(metrics.specificity),
                fn: metrics.fn,
                fp: metrics.fp,
                fpr: metrics.fpr,
                tn: metrics.tn,
                tp: metrics.tp,
                tpr: metrics.tpr,
                y_pred: metrics.y_pred,
                y_prob: metrics.y_prob,
                y_true: metrics.y_true,
            });

            // Carousels keep first-appearance order; each starts on its first graph.
            plots.push({ name: method, index: 0, graphs: [...METHOD_GRAPHS] });
        });

        // Sort the freshly built array here instead of sorting `stats` in place later:
        // mutating state behind React's back is unsafe.
        records.sort(byMethodName);

        setStats(records);
        // Separate copy so in-place operations on the table rows can never alias `stats`.
        setModifiedStats([...records]);
        setMethodPlots(plots);

        setHeaders(singleClassificationModelHeaders);
    }, [valid_stats]);

    // Build bar and radar traces from the parsed rows.
    useEffect(() => {
        if (stats.length === 0) return;

        // `filter` returns a new array, so the following `sort` cannot mutate `stats`.
        // Rows named 'min' remain in the table but are excluded from the charts.
        const plotted = stats
            .filter(m => m.method_name !== 'min')
            .sort(byMethodName);

        // `keyof object` keeps the original (loose) typing for indexing rows by metric name.
        const metricValues = (m: ParsedValidationStats) =>
            PLOTTED_METRICS.map(metric => m[metric as keyof object]);
        const lineColor = (m: ParsedValidationStats) => trColors[m.method_name as keyof object];

        const bars = plotted.map(m => ({
            x: PLOTTED_METRICS,
            y: metricValues(m),
            name: m.method_name,
            type: 'bar',
            marker: {
                color: lineColor(m),
            },
        }));

        const radars = plotted.map(m => ({
            type: 'scatterpolar',
            r: metricValues(m),
            theta: PLOTTED_METRICS,
            fill: 'toself',
            name: m.method_name,
            marker: {
                color: lineColor(m),
            },
            fillcolor: rgbColors[m.method_name as keyof object],
        }));

        setRadarData(radars);
        setBarData(bars);
    }, [stats]);

    /**
     * Moves the carousel for `name` by `step` graphs. Moves past either end are ignored
     * (the previous state object is returned, so React skips the re-render).
     * Uses a functional update so rapid clicks never read a stale `methodPlots`.
     */
    const shiftPlot = (name: string, step: number) => {
        setMethodPlots(prev => {
            const current = prev.find(plot => plot.name === name);
            if (!current) return prev;

            const target = current.index + step;
            if (target < 0 || target > current.graphs.length - 1) return prev;

            return prev.map(plot => (plot.name === name ? { ...plot, index: plot.index + step } : plot));
        });
    };

    /** Advances `name`'s carousel to its next graph (no-op on the last graph). */
    const next = (e: MouseEvent<HTMLButtonElement>, name: string) => {
        e.preventDefault();
        shiftPlot(name, 1);
    };

    /** Steps `name`'s carousel back to its previous graph (no-op on the first graph). */
    const back = (e: MouseEvent<HTMLButtonElement>, name: string) => {
        e.preventDefault();
        shiftPlot(name, -1);
    };

    /**
     * Filters rows by `method_name` using `searchInput`, then sorts them by `selectedHeader`.
     * `updateHeaders`, when given, is forwarded to `sortByColumn`.
     */
    const modifyRecords = (searchInput: string, selectedHeader: string, updateHeaders?: Dispatch<SetStateAction<Header[]>>) => {
        const searchArr = searchItems(searchInput, stats, 'method_name');
        const sortedArr = sortByColumn(selectedHeader, headers, true, searchArr, updateHeaders);
        setModifiedStats(sortedArr);
    };

    return {
        modifiedStats,
        stats,
        modelName,
        modifyRecords,
        headers,
        setHeaders,
        barData,
        radarData,
        back,
        next,
        methodPlots
    };
};

export default useResultsetValidations;