/*
 * Decompiled with CFR 0.152.
 */
package net.algart.executors.modules.opencv.matrices.ml.training;

import java.util.Locale;
import net.algart.executors.api.data.SNumbers;
import net.algart.executors.modules.opencv.matrices.ml.MLKind;
import net.algart.executors.modules.opencv.matrices.ml.MLSamplesType;
import net.algart.executors.modules.opencv.matrices.ml.MLStatModelTrainer;
import net.algart.executors.modules.opencv.matrices.ml.training.MLTrainDTrees;
import net.algart.executors.modules.opencv.util.O2SMat;
import net.algart.executors.modules.opencv.util.OTools;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.TermCriteria;
import org.bytedeco.opencv.opencv_ml.DTrees;
import org.bytedeco.opencv.opencv_ml.RTrees;
import org.bytedeco.opencv.opencv_ml.StatModel;

/**
 * Trains an OpenCV random-trees ({@link RTrees}) statistical model on numeric or pixel samples.
 *
 * <p>In addition to the decision-tree settings inherited from {@link MLTrainDTrees}, this executor
 * configures the size of the random feature subset ({@code activeVarCount}), optional calculation
 * of variable importance, and the training termination criteria. When variable importance is
 * requested, it is published in the {@value #OUTPUT_VAR_IMPORTANCE} numbers output after training.
 */
public final class MLTrainRTrees
extends MLTrainDTrees {
    /** Name of the numbers output that receives the variable-importance values. */
    public static final String OUTPUT_VAR_IMPORTANCE = "var_importance";
    private int activeVarCount = 0;
    private boolean calculateVarImportance = false;
    private int terminationMaxCount = 0;
    private double terminationEpsilon = 0.0;

    private MLTrainRTrees(MLSamplesType inputType) {
        super(inputType);
        this.addOutputNumbers(OUTPUT_VAR_IMPORTANCE);
    }

    /**
     * Creates a trainer whose samples are supplied as numbers.
     *
     * @return new trainer for {@link MLSamplesType#NUMBERS} input.
     */
    public static MLTrainRTrees newTrainNumbers() {
        return new MLTrainRTrees(MLSamplesType.NUMBERS);
    }

    /**
     * Creates a trainer whose samples are supplied as pixels.
     *
     * @return new trainer for {@link MLSamplesType#PIXELS} input.
     */
    public static MLTrainRTrees newTrainPixels() {
        return new MLTrainRTrees(MLSamplesType.PIXELS);
    }

    /** @return value passed to {@link RTrees#setActiveVarCount(int)} during training. */
    public int getActiveVarCount() {
        return this.activeVarCount;
    }

    /**
     * Sets the value passed to {@link RTrees#setActiveVarCount(int)}.
     * It is stored as-is, without validation.
     *
     * @param activeVarCount new value.
     * @return this instance.
     */
    public MLTrainRTrees setActiveVarCount(int activeVarCount) {
        this.activeVarCount = activeVarCount;
        return this;
    }

    /** @return whether variable importance is calculated and written to the output. */
    public boolean isCalculateVarImportance() {
        return this.calculateVarImportance;
    }

    /**
     * Enables or disables calculation of variable importance. When enabled, {@link #process()}
     * copies the model's variable importance into the {@value #OUTPUT_VAR_IMPORTANCE} output.
     *
     * @param calculateVarImportance new flag value.
     * @return this instance.
     */
    public MLTrainRTrees setCalculateVarImportance(boolean calculateVarImportance) {
        this.calculateVarImportance = calculateVarImportance;
        return this;
    }

    /** @return maximum-count component of the termination criteria. */
    public int getTerminationMaxCount() {
        return this.terminationMaxCount;
    }

    /**
     * Sets the maximum-count component of the termination criteria.
     * The value is passed through the inherited {@code nonNegative} check.
     *
     * @param terminationMaxCount new value.
     * @return this instance.
     */
    public MLTrainRTrees setTerminationMaxCount(int terminationMaxCount) {
        this.terminationMaxCount = MLTrainRTrees.nonNegative(terminationMaxCount);
        return this;
    }

    /** @return epsilon component of the termination criteria. */
    public double getTerminationEpsilon() {
        return this.terminationEpsilon;
    }

    /**
     * Sets the epsilon component of the termination criteria.
     * The value is passed through the inherited {@code nonNegative} check.
     *
     * @param terminationEpsilon new value.
     * @return this instance.
     */
    public MLTrainRTrees setTerminationEpsilon(double terminationEpsilon) {
        this.terminationEpsilon = MLTrainRTrees.nonNegative(terminationEpsilon);
        return this;
    }

    /**
     * Creates a new {@link RTrees} model, configures it from this executor's parameters,
     * trains it and stores the result via {@code writeTrainer}. If
     * {@link #isCalculateVarImportance()} is set, the model's variable-importance matrix is
     * also copied into the {@value #OUTPUT_VAR_IMPORTANCE} output.
     */
    public void process() {
        try (RTrees model = this.newStatModel();
             Mat priors = this.priors();
             TermCriteria termCriteria = OTools.termCriteria(
                     this.terminationMaxCount, this.terminationEpsilon, true)) {
            model.setActiveVarCount(this.activeVarCount);
            model.setCalculateVarImportance(this.calculateVarImportance);
            // OTools.termCriteria may return null; then the model keeps its own default criteria.
            if (termCriteria != null) {
                model.setTermCriteria(termCriteria);
            }
            this.customizeRTrees((DTrees) model, priors);
            MLTrainRTrees.logDebug(() -> "Training " + this.modelKind().modelName() + ": "
                    + MLTrainRTrees.toString(model));
            MLStatModelTrainer trainer = new MLStatModelTrainer((StatModel) model, this.modelKind());
            this.setTrainingFlags(trainer);
            this.train(trainer);
            this.writeTrainer(trainer);
            if (this.calculateVarImportance) {
                // The returned Mat is owned here and released right after copying its values.
                try (Mat varImportance = model.getVarImportance()) {
                    // NOTE(review): second argument presumably is the block length, so all
                    // values form a single block - confirm against O2SMat.toRawNumbers.
                    this.getNumbers(OUTPUT_VAR_IMPORTANCE).exchange(
                            O2SMat.toRawNumbers(varImportance, varImportance.cols() * varImportance.rows()));
                }
            }
        }
    }

    /**
     * Describes the RTrees-specific parameters of a model, prefixed by the description produced
     * by {@link MLTrainDTrees#toString(DTrees)}.
     *
     * @param model model to describe.
     * @return human-readable description (formatted with {@link Locale#US}).
     */
    public static String toString(RTrees model) {
        // getTermCriteria() returns a fresh native object; close it instead of waiting for GC.
        try (TermCriteria termCriteria = model.getTermCriteria()) {
            return String.format(Locale.US, "%s, activeVarCount=%s, calculateVarImportance=%s, %s",
                    MLTrainDTrees.toString((DTrees) model),
                    model.getActiveVarCount(),
                    model.getCalculateVarImportance(),
                    OTools.toString(termCriteria));
        }
    }

    /** @return {@link MLKind.StatModelBased#R_TREES}. */
    @Override
    protected MLKind modelKind() {
        return MLKind.StatModelBased.R_TREES;
    }

    /** @return {@code true}: responses are treated as categorical. */
    @Override
    protected boolean categoricalResponses() {
        return true;
    }

    /** Creates a fresh RTrees model with OpenCV defaults and logs its parameters at debug level. */
    private RTrees newStatModel() {
        RTrees result = RTrees.create();
        MLTrainRTrees.logDebug(() -> "Creating RTrees: " + MLTrainRTrees.toString(result));
        return result;
    }

    /**
     * Manual smoke test. Trains a tiny two-sample model with priors and variable importance
     * enabled, then prints the variable importance, the priors and a max-vote prediction.
     * All native objects are released when the method exits.
     */
    public static void main(String[] args) {
        try (RTrees model = RTrees.create()) {
            SNumbers priors = SNumbers.ofArray((Object) new double[]{2.0, 1.0}, 1);
            try (Mat priorsInput = O2SMat.numbersToMulticolumnMat(priors)) {
                model.setPriors(priorsInput);
            }
            model.setCalculateVarImportance(true);
            MLTrainRTrees training = new MLTrainRTrees(MLSamplesType.NUMBERS);
            training.setUseGPU(false);
            training.trainNumbers(
                    new MLStatModelTrainer((StatModel) model, MLKind.StatModelBased.R_TREES),
                    SNumbers.ofArray((Object) new float[]{10.0f, 30.0f}, 1),
                    SNumbers.ofArray((Object) new int[]{2, 3}, 1),
                    null);
            System.out.println("OK: " + model.isClassifier());
            try (Mat varImportanceMat = model.getVarImportance()) {
                System.out.println(OTools.toString(varImportanceMat));
                System.out.println("varImportance: "
                        + O2SMat.multicolumnMatToNumbers(varImportanceMat).toString(true));
            }
            try (Mat priorsMat = model.getPriors()) {
                System.out.println(OTools.toString(priorsMat));
                System.out.println("priors: " + O2SMat.multicolumnMatToNumbers(priorsMat).toString(true));
            }
            try (Mat samplesMat = O2SMat.numbersToMulticolumn32BitMat(
                    SNumbers.ofArray((Object) new float[]{10.0f, 30.0f}, 1), false);
                 Mat resultMat = new Mat()) {
                model.predict(samplesMat, resultMat, DTrees.PREDICT_MAX_VOTE);
                System.out.println("prediction: " + OTools.toString(resultMat));
            }
        }
    }
}

