// components
import Field from "../../../../components/Form/Field";
import FlowCard from "../../../../components/Data/FlowCard";

// hooks
import { useState, useEffect, Dispatch, SetStateAction } from "react";

// types
import { ChangeEvent } from "react";
import { CreateModelForm, DescriptorType } from "../../../../types/models.types";
import Select from 'react-select';
import { Dataset } from "../../../../types/datasets.types";
import { FileType } from "../../../../types/files.types";
import { FieldMappingForm } from "../../../../types/all.types";



/**
 * Initial state of the model-creation form.
 *
 * - No dataset ids, name, label fields or methods.
 * - A single RDKit `ECFP` descriptor with `Radius` '3' and `Bits` '1024'.
 * - Deep-learning `dropout` 0.5 and `beta` (L2 regularization factor) 0.01.
 * - `test_set_size` 0, conformal prediction off, no resample method.
 */
export const defaultForm: CreateModelForm = {
    ds_ids: [],
    name: '',
    label_fields: [],
    descriptors: [
        {
            provider: 'RDKit',
            name: 'ECFP',
            params: { Radius: '3', Bits: '1024' },
        },
    ],
    methods: [],
    hyper_params: {
        DL: { dropout: 0.5, beta: 0.01 },
    },
    test_set_size: 0,
    conf_pred: false,
    resample_method: null,
};

/** Fingerprint descriptors offered in the "Descriptors" multi-select; each label equals its value. */
const descriptorOptions = ['ECFP', 'FCFP'].map(value => ({ value, label: value }));

/** Choices for a descriptor's "Bits" parameter, as strings. */
const bitOptions = ['512', '1024', '2048'].map(value => ({ value, label: value }));

/** Choices for a descriptor's "Radius" parameter, as strings. */
const radiusOptions = ['2', '3', '4'].map(value => ({ value, label: value }));

/**
 * Methods offered when the label is a single class label. `value` is the code
 * stored in `form.methods`; `label` is what the select displays.
 */
const classificationMethods = [
    {value: 'DL', label: 'Deep Learning'},
    {value: 'ada', label: 'AdaBoostClassifier'},
    {value: 'bnb', label: 'BernoulliNB'},
    {value: 'knn', label: 'KNeighborsClassifier'},
    {value: 'lreg', label: 'LogisticRegression'},
    {value: 'rf', label: 'RandomForestClassifier'},
    {value: 'svc', label: 'SVC'},
    {value: 'xgb', label: 'XGBClassifier'},
];

/** Methods offered when the label is a continuous value (same shape as above). */
const regressionMethods = [
    // Label fixed: the estimator is "AdaBoostRegressor" (was misspelled "AdaBoosterRegressor").
    {value: 'adar', label: 'AdaBoostRegressor'},
    {value: 'br', label: 'BayesianRidge'},
    {value: 'enr', label: 'ElasticNet'},
    {value: 'knnr', label: 'KNeighborsRegressor'},
    {value: 'rfr', label: 'RandomForestRegressor'},
    {value: 'svr', label: 'SVR'},
    {value: 'xgbr', label: 'XGBRegressor'},
];

/**
 * Every known method (classification first, then regression), used to map a
 * stored method code back to its labelled option. It is built from the two
 * lists above so the three can never drift apart.
 */
const allMethods = [...classificationMethods, ...regressionMethods];

/**
 * Choices for the "Sampling Type" select. The option whose value is `null`
 * ('None') means no resampling; the others label themselves.
 */
const samplingTypes = [
    { value: null, label: 'None' },
    ...['Undersampling', 'Oversampling'].map(value => ({ value, label: value })),
];

/** Resampling algorithms offered when the sampling type is 'Undersampling'. */
const underSamplingMethods = [
    { value: 'tomek_links', label: 'Tomek Links' },
    { value: 'random_under_sampler', label: 'Random Under Sampler' },
];

/** Resampling algorithms offered when the sampling type is 'Oversampling'. */
const overSamplingMethods = [
    { value: 'adasyn', label: 'ADASYN' },
    { value: 'smote', label: 'SMOTE' },
    { value: 'random_over_sampler', label: 'Random Over Sampler' },
];


/**
 * Step card that collects the training configuration of a new model.
 *
 * It covers the model name, descriptors with their ECFP/FCFP parameters,
 * deep-learning hyper-parameters and the methods to train. Conformal prediction
 * and class resampling are also offered, but only for single-class-label
 * targets.
 *
 * Whenever `selectedDataset` or `fieldsMapping` changes, the card pre-fills
 * `form` with:
 * - the model name, only while "Auto-generate" is checked;
 * - the dataset id;
 * - the label field;
 * - every method that matches the label type.
 *
 * @param number          Step number displayed on the card.
 * @param selectedDataset Dataset chosen in an earlier step, or null.
 * @param selectedFile    Uploaded file. When this prop is passed at all (even as
 *                        null), the card stays locked until a file is present.
 *                        Otherwise it stays locked until a dataset is selected.
 * @param form            Model-creation form state, owned by the parent.
 * @param setForm         State setter for `form`.
 * @param fieldsMapping   Column mapping of the uploaded file. It is only
 *                        consulted when no dataset is selected.
 */
export default function SetTrainingParameters({
    number,
    selectedDataset,
    selectedFile,
    form,
    setForm,
    fieldsMapping,
}: {
    number: number,
    selectedDataset: Dataset | null,
    selectedFile?: FileType | null,
    form: CreateModelForm,
    setForm: Dispatch<SetStateAction<CreateModelForm>>,
    fieldsMapping?: FieldMappingForm,
}) {
    type Option = { value: string; label: string };
    type SamplingOption = { value: string | null; label: string };
    type DescriptorParam = 'Bits' | 'Radius';

    const noSampling: SamplingOption = { value: null, label: 'None' };

    const [ isAuto, setIsAuto ] = useState(true);
    const [ valueType, setValueType ] = useState('');
    const [ sampler, setSampler ] = useState<SamplingOption>(noSampling);
    const [ samplerType, setSamplerType ] = useState<Option | null>(null);

    // Descriptor and method controls are derived from `form` on every render,
    // instead of being mirrored in local state, so they always show exactly
    // what will be submitted.
    const findDescriptor = (name: string): DescriptorType | null =>
        form.descriptors.find(d => d.name === name) ?? null;
    const descriptorList = form.descriptors.map(d => d.name);
    const ecfpDescriptor = findDescriptor('ECFP');
    const fcfpDescriptor = findDescriptor('FCFP');

    // Selected methods as select options, in the order they were picked.
    // Codes without a known label are skipped.
    const methodValues = form.methods
        .map(method => allMethods.find(m => m.value === method))
        .filter((option): option is Option => option !== undefined);

    // Resampling algorithms available for the currently chosen sampling type.
    const samplingMethods: Option[] =
        sampler.value === 'Undersampling' ? underSamplingMethods
        : sampler.value === 'Oversampling' ? overSamplingMethods
        : [];

    /** Builds a descriptor with the default parameters (Radius '3', Bits '1024'). */
    const defaultDescriptor = (name: string): DescriptorType => ({
        provider: 'RDKit',
        name,
        params: { Radius: '3', Bits: '1024' },
    });

    /** Clears the sampling-type and sampling-method selects. */
    const resetSampling = () => {
        setSampler(noSampling);
        setSamplerType(null);
    };

    /** Copies a text input into the top-level form field named after the input. */
    const handleInput = (e: ChangeEvent<HTMLInputElement>) => {
        const { name, value } = e.target;
        setForm(prev => ({ ...prev, [name]: value }));
    };

    /**
     * Updates one deep-learning hyper-parameter (`dropout` or `beta`). The value
     * is stored as typed text, and other `hyper_params` entries are kept.
     */
    const handleDLInput = (e: ChangeEvent<HTMLInputElement>) => {
        const { name, value } = e.target;
        setForm(prev => ({
            ...prev,
            hyper_params: {
                ...prev.hyper_params,
                DL: {
                    ...prev.hyper_params.DL,
                    [name]: value,
                },
            },
        }));
    };

    /** Copies a checkbox into the top-level boolean form field named after it. */
    const handleCheck = (e: ChangeEvent<HTMLInputElement>) => {
        const { name, checked } = e.target;
        setForm(prev => ({ ...prev, [name]: checked }));
    };

    /**
     * Replaces the selected descriptors. Descriptors that stay selected keep
     * their customised parameters; newly added ones start from the defaults.
     */
    const handleDescriptorsChange = (selected: readonly Option[] | null) => {
        const names = (selected ?? []).map(option => option.value);
        setForm(prev => ({
            ...prev,
            descriptors: names.map(name =>
                prev.descriptors.find(d => d.name === name) ?? defaultDescriptor(name)
            ),
        }));
    };

    /** Replaces the selected method codes. */
    const handleMethodsChange = (selected: readonly Option[] | null) => {
        const methods = (selected ?? []).map(option => option.value);
        setForm(prev => ({ ...prev, methods }));
    };

    /** Returns an onChange handler that sets `param` on the descriptor named `descriptor`. */
    const handleDescriptorParam = (descriptor: string, param: DescriptorParam) =>
        (selected: Option | null) => {
            if (!selected) return;
            const value = selected.value;
            setForm(prev => ({
                ...prev,
                descriptors: prev.descriptors.map(d =>
                    d.name === descriptor
                        ? { ...d, params: { ...d.params, [param]: value } }
                        : d
                ),
            }));
        };

    /** Switches the sampling type; any previously chosen method is cleared. */
    const handleSamplingType = (selected: SamplingOption | null) => {
        if (!selected) return;
        setSampler(selected);
        setSamplerType(null);
        setForm(prev => ({ ...prev, resample_method: null }));
    };

    /** Stores the chosen resampling algorithm in the form. */
    const handleSamplingMethod = (selected: Option | null) => {
        const resampleMethod = selected ? selected.value : null;
        setSamplerType(selected);
        setForm(prev => ({ ...prev, resample_method: resampleMethod }));
    };

    // Pre-fill the form whenever the data source (dataset, or mapped file) changes.
    useEffect(() => {
        if (!selectedDataset && !selectedFile) {
            setForm(defaultForm);
            setValueType('');
            resetSampling();
            return;
        }

        let labelType = '';
        let autoName = '';
        let dsIds: string[] | null = null;

        if (selectedDataset) {
            const labelField = selectedDataset.fields_mapping.find(m => ['single-class-label', 'continuous-value'].includes(m.type));
            labelType = labelField ? labelField.type : '';
            autoName = selectedDataset.name;
            dsIds = [selectedDataset._id.$oid];
        } else if (selectedFile) {
            const hasContinuous = Boolean(fieldsMapping?.['continuous-value']);
            const hasClassLabel = Boolean(fieldsMapping?.['single-class-label']);
            const hasSplitOnValue = Boolean(fieldsMapping?.['split-on-value']);
            if (hasContinuous && !hasClassLabel) {
                labelType = 'continuous-value';
            } else if (!hasContinuous && (hasClassLabel || hasSplitOnValue)) {
                labelType = 'single-class-label';
            }
            autoName = selectedFile.name.split('.')[0];
        }

        const isClassification = labelType === 'single-class-label';
        const methodOptions: Option[] =
            isClassification ? classificationMethods
            : labelType === 'continuous-value' ? regressionMethods
            : [];

        setValueType(labelType);
        setForm(prev => ({
            ...prev,
            ds_ids: dsIds ?? prev.ds_ids,
            // Keep a name the user typed after switching "Auto-generate" off.
            name: isAuto ? autoName : prev.name,
            // Without a recognised label, clear what a previous source left behind.
            label_fields: labelType ? [labelType] : [],
            methods: methodOptions.map(method => method.value),
            // Conformal prediction and resampling are only offered for classification;
            // don't submit hidden values left over from an earlier selection.
            conf_pred: isClassification ? prev.conf_pred : false,
            resample_method: isClassification ? prev.resample_method : null,
        }));
        if (!isClassification) {
            resetSampling();
        }
        // `isAuto` is deliberately read at run time and is not a dependency:
        // toggling it must not re-run this effect, which would reset the methods.
        //eslint-disable-next-line
    }, [selectedDataset, fieldsMapping]);

    /** Renders the Bits/Radius selects for one selected descriptor. */
    const renderDescriptorParams = (descriptor: DescriptorType) => (
        <div className="flex justify-between items-center gap-6">
            <div className="w-1/2">
                <Field label={`${descriptor.name} Bits`} isRequired={true}>
                    <Select onChange={handleDescriptorParam(descriptor.name, 'Bits')} options={bitOptions} value={{value: descriptor.params.Bits, label: descriptor.params.Bits}} className='text-left text-[14px] rounded w-full' placeholder='Bits'/>
                </Field>
            </div>

            <div className="w-1/2">
                <Field label={`${descriptor.name} Radius`} isRequired={true}>
                    <Select onChange={handleDescriptorParam(descriptor.name, 'Radius')} options={radiusOptions} value={{value: descriptor.params.Radius, label: descriptor.params.Radius}} className='text-left text-[14px] rounded w-full' placeholder='Radius'/>
                </Field>
            </div>
        </div>
    );

    const isClassificationTarget = valueType === 'single-class-label';
    const noSamplingChosen = sampler.value === null;

    return (
        <FlowCard label="Set Training Parameters" number={number} isLocked={selectedFile !== undefined ? !selectedFile : !selectedDataset}>
            <div className="flex flex-col gap-6 justify-start items-stretch">

                <Field label='Model Name' isRequired={true}>
                    <div className='rounded w-full flex justify-between items-stretch'>
                        <input onChange={handleInput} name='name' value={form.name} className='border-l border-t border-b border-primary  rounded-tl rounded-bl text-left text-[14px] p-2 grow' placeholder='Model Name'/>
                        <label className={`px-3 border ${isAuto ? 'border-secondary' : 'border-primary'} cursor-pointer flex items-center gap-2 rounded-tr rounded-br relative`}>
                            {isAuto && <div className="absolute top-0 left-0 right-0 bottom-0 bg-secondary bg-opacity-10" />}
                            <input onChange={e => setIsAuto(e.target.checked)} type='checkbox' checked={isAuto} className='accent-secondary'/>
                            <p className={`text-[12px] ${isAuto ? 'font-medium text-secondary' : 'font-regular text-primary'}`}>Auto-generate</p>
                        </label>
                    </div>
                </Field>

                <Field label='Descriptors' isRequired={true}>
                    <Select onChange={handleDescriptorsChange} isMulti options={descriptorOptions} value={descriptorList.map(d => ({value: d, label: d}))} className='text-left text-[14px] rounded w-full' placeholder='Descriptors'/>
                </Field>

                <div className="flex justify-between items-center gap-6">
                    <div className="w-1/2">
                        <Field label='Dropout rate (0-1)' isRequired={true}>
                            <input onChange={handleDLInput} name='dropout' value={form.hyper_params.DL.dropout} className='rounded border border-primary text-left text-[14px] p-2 w-full' placeholder='Dropout rate'/>
                        </Field>
                    </div>

                    <div className="w-1/2">
                        <Field label='L2 Regularization Factor' isRequired={true}>
                            <input onChange={handleDLInput} name='beta' value={form.hyper_params.DL.beta} className='rounded border border-primary text-left text-[14px] p-2 w-full' placeholder='L2'/>
                        </Field>
                    </div>
                </div>

                {ecfpDescriptor && renderDescriptorParams(ecfpDescriptor)}

                {fcfpDescriptor && renderDescriptorParams(fcfpDescriptor)}

                <div className="flex flex-col gap-2 justify-start items-start">
                    <Field label='Methods' isRequired={true}>
                        {valueType ? (
                            <Select onChange={handleMethodsChange} isMulti options={isClassificationTarget ? classificationMethods : regressionMethods} value={methodValues} className='text-left text-[14px] rounded w-full' placeholder='Methods'/>
                        ) : (
                            <Select className='text-left text-[14px] rounded w-full' placeholder='Methods'/>
                        )}
                    </Field>
                    {isClassificationTarget && (
                        <label className="flex gap-2">
                            <input className='accent-secondary' name='conf_pred' onChange={handleCheck} type='checkbox' checked={form.conf_pred} />
                            <span className="text-[12px]">Train model with conformal predictor version of each method</span>
                        </label>
                    )}
                </div>

                {isClassificationTarget && (
                <div id='sampling-section' className="flex justify-between items-center gap-6">
                    <div className="w-1/2">
                        <Field label='Sampling Type' isRequired={false}>
                            <Select maxMenuHeight={100} menuPortalTarget={document.body} onChange={handleSamplingType} options={samplingTypes} value={sampler} className='text-left text-[14px] rounded w-full' placeholder='Sampling Type'/>
                        </Field>
                    </div>
                    <div className="w-1/2">
                        <Field label='Sampling Method' isRequired={false}>
                            <Select maxMenuHeight={100} menuPortalTarget={document.body} isDisabled={noSamplingChosen} onChange={handleSamplingMethod} options={samplingMethods} value={samplerType} className='text-left text-[14px] rounded w-full' placeholder={noSamplingChosen ? 'Not Applicable' : 'Method'}/>
                        </Field>
                    </div>
                </div>
                )}

            </div>
        </FlowCard>
    )
}