/*
 * Decompiled with CFR 0.152.
 */
package net.algart.executors.modules.opencv.matrices.ml;

import net.algart.executors.api.data.Data;
import net.algart.executors.api.data.DataType;
import net.algart.executors.api.data.SMat;
import net.algart.executors.api.data.SNumbers;
import net.algart.executors.modules.opencv.matrices.ml.AbstractMLPredict;
import net.algart.executors.modules.opencv.matrices.ml.AbstractMLTrain;
import net.algart.executors.modules.opencv.matrices.ml.MLPredictor;
import net.algart.executors.modules.opencv.matrices.ml.MLTrainer;
import net.algart.executors.modules.opencv.matrices.ml.prediction.MLPredict;

/**
 * Form in which machine-learning samples are passed to OpenCV ML train/predict executors.
 * Each constant knows the {@link DataType} of its ports and how to read samples (and, for
 * training, responses) from an executor, run the trainer/predictor, and write results back
 * to the executor's output ports.
 */
public enum MLSamplesType {
    /**
     * Samples are an {@link SNumbers} array read from the {@code "samples"} port.
     * For training, responses are either the last column of the same array (combined mode,
     * see {@code AbstractMLTrain#isTrainingCombinedSamplesAndResponses()}) or a separate
     * {@code "training_responses"} numbers port.
     */
    NUMBERS(DataType.NUMBERS) {
        @Override
        void train(AbstractMLTrain executor, MLTrainer trainer) {
            final SNumbers numbers = executor.getInputNumbers("samples");
            final boolean combined = executor.isTrainingCombinedSamplesAndResponses();
            final SNumbers samples;
            final SNumbers responses;
            if (combined) {
                // Only the combined layout needs at least 2 columns (features + response);
                // separately supplied samples may legitimately have a single feature column.
                final int blockLength =
                        MLSamplesType.getBlockLengthAndCheckForCombinedSamplesAndResponses(numbers);
                samples = numbers.columnRange(0, blockLength - 1);
                responses = numbers.columnRange(blockLength - 1, 1);
            } else {
                samples = numbers;
                responses = executor.getInputNumbers("training_responses");
            }
            final SNumbers autoTestResult = new SNumbers();
            final double error = executor.trainNumbers(trainer, samples, responses, autoTestResult);
            // NOTE(review): presumably trainNumbers initializes autoTestResult only when it
            // performed a self-test; the error scalar is reported only in that case.
            if (autoTestResult.isInitialized()) {
                executor.getScalar("error").setTo(error);
            }
            if (executor.isTestPredictTrainedSamples()) {
                // Replace any self-test result with a prediction over the training samples.
                autoTestResult.exchange(MLPredict.predict(
                        (MLPredictor) trainer,
                        samples,
                        executor.isConvertCategoricalResponses(),
                        executor.isUseGPU()));
            }
            if (autoTestResult.isInitialized()) {
                executor.getNumbers().exchange(autoTestResult);
            }
        }

        @Override
        void predict(AbstractMLPredict executor, MLPredictor predictor) {
            final SNumbers samples = executor.getInputNumbers("samples");
            executor.getNumbers().setTo(executor.predictNumbers(predictor, samples));
        }
    },

    /**
     * Samples are an {@link SMat} matrix read from the {@code "samples"} port; training
     * responses always come from the separate {@code "training_responses"} matrix port.
     */
    PIXELS(DataType.MAT) {
        @Override
        void train(AbstractMLTrain executor, MLTrainer trainer) {
            final SMat samples = executor.getInputMat("samples");
            final SMat responses = executor.getInputMat("training_responses");
            final SMat autoTestResult = new SMat();
            final double error = executor.trainPixels(trainer, samples, responses, autoTestResult);
            // NOTE(review): presumably trainPixels initializes autoTestResult only when it
            // performed a self-test; the error scalar is reported only in that case.
            if (autoTestResult.isInitialized()) {
                executor.getScalar("error").setTo(error);
            }
            if (executor.isTestPredictTrainedSamples()) {
                // Replace any self-test result with a prediction over the training samples.
                autoTestResult.exchange((Data) MLPredict.predict(
                        (MLPredictor) trainer,
                        samples,
                        executor.isConvertCategoricalResponses(),
                        executor.isUseGPU()));
            }
            if (autoTestResult.isInitialized()) {
                executor.getMat().exchange((Data) autoTestResult);
            }
        }

        @Override
        void predict(AbstractMLPredict executor, MLPredictor predictor) {
            final SMat samples = executor.getInputMat("samples");
            executor.getMat().setTo(executor.predictPixels(predictor, samples));
        }
    };

    /** Data type of the samples / responses / result ports for this kind of samples. */
    final DataType portDataType;

    MLSamplesType(DataType portDataType) {
        this.portDataType = portDataType;
    }

    /**
     * Reads training samples and responses from the executor's input ports, trains the model
     * via the executor, and writes the optional error scalar and test result to its outputs.
     *
     * @param executor training executor providing inputs, options and output ports
     * @param trainer  model trainer; also used as a predictor when test prediction is enabled
     */
    abstract void train(AbstractMLTrain executor, MLTrainer trainer);

    /**
     * Reads samples from the executor's {@code "samples"} port, runs prediction and stores
     * the result in the executor's main output.
     *
     * @param executor prediction executor providing inputs and output ports
     * @param predictor trained model used for prediction
     */
    abstract void predict(AbstractMLPredict executor, MLPredictor predictor);

    /**
     * Returns the block length (number of columns) of samples that combine features and
     * responses, where the last column holds the responses.
     *
     * @param samples combined samples-and-responses array
     * @return the block length, always at least 2
     * @throws IllegalArgumentException if there are fewer than 2 columns
     */
    private static int getBlockLengthAndCheckForCombinedSamplesAndResponses(SNumbers samples) {
        final int blockLength = samples.getBlockLength();
        if (blockLength <= 1) {
            throw new IllegalArgumentException("Input samples must contain more than 1 column: the last column must contain training responses");
        }
        return blockLength;
    }
}

