## Copyright (C) 2024 Ruchika Sonagote <ruchikasonagote2003@gmail.com>
## Copyright (C) 2024 Pallav Purbia <pallavpurbia@gmail.com>
## Copyright (C) 2024-2025 Andreas Bertsatos <abertsatos@biol.uoa.gr>
##
## This file is part of the statistics package for GNU Octave.
##
## This program is free software; you can redistribute it and/or modify it under
## the terms of the GNU General Public License as published by the Free Software
## Foundation; either version 3 of the License, or (at your option) any later
## version.
##
## This program is distributed in the hope that it will be useful, but WITHOUT
## ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
## FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
## details.
##
## You should have received a copy of the GNU General Public License along with
## this program; if not, see <http://www.gnu.org/licenses/>.

classdef ClassificationPartitionedModel
## -*- texinfo -*-
## @deftp {statistics} ClassificationPartitionedModel
##
## Cross-validated classification model
##
## The @code{ClassificationPartitionedModel} class stores cross-validated
## classification models trained on different partitions of the data.
## It can predict responses for observations not used for training using
## the @code{kfoldPredict} method.
##
## The cross-validated model can be a @code{ClassificationDiscriminant},
## @code{ClassificationGAM}, @code{ClassificationKNN},
## @code{ClassificationNeuralNetwork}, or @code{ClassificationSVM} object.
##
## Create a @code{ClassificationPartitionedModel} object by using the
## @code{crossval} function.
##
## @seealso{crossval}
## @end deftp

  properties
    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} BinEdges
    ##
    ## Bin edges
    ##
    ## A cell array specifying the bin edges for binned predictors.
    ## This property is read-only.
    ##
    ## @end deftp
    BinEdges                     = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} CategoricalPredictors
    ##
    ## Indices of categorical predictors
    ##
    ## A vector of positive integers specifying the indices of categorical
    ## predictors.  This property is read-only.
    ##
    ## @end deftp
    CategoricalPredictors        = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} X
    ##
    ## Predictor data
    ##
    ## A numeric matrix containing the unstandardized predictor data.  Each
    ## column of @var{X} represents one predictor (variable), and each row
    ## represents one observation.  This property is read-only.
    ##
    ## @end deftp
    X                            = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} Y
    ##
    ## Class labels
    ##
    ## Specified as a logical or numeric column vector, or as a character array
    ## or a cell array of character vectors with the same number of rows as the
    ## predictor data.  Each row in @var{Y} is the observed class label for
    ## the corresponding row in @var{X}.  This property is read-only.
    ##
    ## @end deftp
    Y                            = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} ClassNames
    ##
    ## Names of classes in the response variable
    ##
    ## An array of unique values of the response variable @var{Y}, which has the
    ## same data types as the data in @var{Y}.  This property is read-only.
    ## @qcode{ClassNames} can have any of the following datatypes:
    ##
    ## @itemize
    ## @item Cell array of character vectors
    ## @item Character array
    ## @item Logical vector
    ## @item Numeric vector
    ## @end itemize
    ##
    ## @end deftp
    ClassNames                   = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} Cost
    ##
    ## Cost of Misclassification
    ##
    ## A square matrix specifying the cost of misclassification of a point.
    ## @qcode{Cost(i,j)} is the cost of classifying a point into class @qcode{j}
    ## if its true class is @qcode{i} (that is, the rows correspond to the true
    ## class and the columns correspond to the predicted class).  The order of
    ## the rows and columns in @qcode{Cost} corresponds to the order of the
    ## classes in @qcode{ClassNames}.  The number of rows and columns in
    ## @qcode{Cost} is the number of unique classes in the response.  By
    ## default, @qcode{Cost(i,j) = 1} if @qcode{i != j}, and
    ## @qcode{Cost(i,j) = 0} if @qcode{i = j}.  In other words, the cost is 0
    ## for correct classification and 1 for incorrect classification.
    ## It is empty for cross-validated @qcode{ClassificationNeuralNetwork} and
    ## @qcode{ClassificationSVM} models.  This property is read-only.
    ##
    ## @end deftp
    Cost                         = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} CrossValidatedModel
    ##
    ## Cross-validated model class
    ##
    ## A character vector specifying the class of the cross-validated model.
    ## This field contains the type of model that was used for the training,
    ## e.g., @qcode{"ClassificationKNN"}.  This property is read-only.
    ##
    ## @end deftp
    CrossValidatedModel          = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} KFold
    ##
    ## Number of cross-validated folds
    ##
    ## A positive integer value specifying the number of cross-validated folds.
    ## It equals the @qcode{NumTestSets} property of the @code{cvpartition}
    ## object, so it is 1 for a holdout partition.  This property is read-only.
    ##
    ## @end deftp
    KFold                        = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} ModelParameters
    ##
    ## Model parameters
    ##
    ## A structure containing the model parameters used during training.
    ## This includes any model-specific parameters that were configured prior
    ## to training.  This property is read-only.
    ##
    ## @end deftp
    ModelParameters              = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} NumObservations
    ##
    ## Number of observations
    ##
    ## A positive integer value specifying the number of observations in the
    ## training dataset used for training the cross-validated model.
    ## This property is read-only.
    ##
    ## @end deftp
    NumObservations              = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} Partition
    ##
    ## Partition configuration
    ##
    ## A @code{cvpartition} object specifying the partition configuration used
    ## for cross-validation.  This field stores the cvpartition instance that
    ## describes how the data was split into training and validation sets.
    ## This property is read-only.
    ##
    ## @end deftp
    Partition                    = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} PredictorNames
    ##
    ## Names of predictor variables
    ##
    ## A cell array of character vectors specifying the names of the predictor
    ## variables.  The names are in the order in which they appear in the
    ## training dataset.  This property is read-only.
    ##
    ## @end deftp
    PredictorNames               = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} Prior
    ##
    ## Prior probability for each class
    ##
    ## A numeric vector specifying the prior probabilities for each class.  The
    ## order of the elements in @qcode{Prior} corresponds to the order of the
    ## classes in @qcode{ClassNames}.  It is empty for cross-validated
    ## @qcode{ClassificationNeuralNetwork} and @qcode{ClassificationSVM}
    ## models.  This property is read-only.
    ##
    ## @end deftp
    Prior                        = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} ResponseName
    ##
    ## Response variable name
    ##
    ## A character vector specifying the name of the response variable @var{Y}.
    ## This property is read-only.
    ##
    ## @end deftp
    ResponseName                 = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} ScoreTransform
    ##
    ## Transformation function for classification scores
    ##
    ## Specified as a function handle for transforming the classification
    ## scores.  This property is read-only.
    ##
    ## @end deftp
    ScoreTransform               = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} Standardize
    ##
    ## Standardize predictors flag
    ##
    ## A logical scalar specifying whether to standardize the predictors.
    ## It is empty for cross-validated @qcode{ClassificationDiscriminant} and
    ## @qcode{ClassificationGAM} models.  This property is read-only.
    ##
    ## @end deftp
    Standardize                  = [];

    ## -*- texinfo -*-
    ## @deftp {ClassificationPartitionedModel} {property} Trained
    ##
    ## Models trained on each fold
    ##
    ## A cell array with one trained model per fold.  The @var{k}-th model is
    ## trained on the training set of the @var{k}-th partition, i.e. on all
    ## observations except those held out in the @var{k}-th test set.  Except
    ## for @qcode{ClassificationKNN} models, the stored models are the compact
    ## versions of the trained models.  This property is read-only.
    ##
    ## @end deftp
    Trained                      = [];
  endproperties

  properties (Access = private, Hidden)
    ## Copied from Mdl.STname in the constructor (presumably the name of the
    ## score transformation applied by the trained model).
    STname = 'none';
  endproperties

  methods (Access = public)
    ## -*- texinfo -*-
    ## @deftypefn  {ClassificationPartitionedModel} {@var{this} =} ClassificationPartitionedModel (@var{Mdl}, @var{Partition})
    ##
    ## Create a @code{ClassificationPartitionedModel} class object for cross-validation
    ## of classification models.
    ##
    ## @code{@var{this} = ClassificationPartitionedModel (@var{Mdl},
    ## @var{Partition})} returns a ClassificationPartitionedModel object, with
    ## @var{Mdl} as the trained classification model object and
    ## @var{Partition} as the partitioning object obtained using @code{cvpartition}
    ## function.  One model of the same type as @var{Mdl} is retrained on the
    ## training set of each fold of @var{Partition}, reusing the parameters of
    ## @var{Mdl}.
    ##
    ## @var{Mdl} must be a @code{ClassificationDiscriminant},
    ## @code{ClassificationGAM}, @code{ClassificationKNN},
    ## @code{ClassificationNeuralNetwork}, or @code{ClassificationSVM} object,
    ## and @var{Partition} must be a @code{cvpartition} object; otherwise an
    ## error is raised.
    ##
    ## @seealso{cvpartition}
    ## @end deftypefn
    function this = ClassificationPartitionedModel (Mdl, Partition)

      ## Check input arguments
      if (nargin < 2)
        error ("ClassificationPartitionedModel: too few input arguments.");
      endif

      ## Check for valid Classification object
      validTypes = {'ClassificationDiscriminant', 'ClassificationGAM', ...
                    'ClassificationKNN', 'ClassificationNeuralNetwork', ...
                    'ClassificationSVM'};
      if (! any (strcmp (class (Mdl), validTypes)))
        error ("ClassificationPartitionedModel: unsupported model type.");
      endif

      ## Check for valid cvpartition object
      if (! strcmp (class (Partition), "cvpartition"))
        error ("ClassificationPartitionedModel: invalid 'cvpartition' object.");
      endif

      ## Set properties
      this.X = Mdl.X;
      this.Y = Mdl.Y;
      this.KFold = Partition.NumTestSets;
      this.Trained = cell (this.KFold, 1);
      this.ClassNames = Mdl.ClassNames;
      this.ResponseName = Mdl.ResponseName;
      this.NumObservations = Mdl.NumObservations;
      this.PredictorNames = Mdl.PredictorNames;
      this.Partition = Partition;
      this.CrossValidatedModel = class (Mdl);
      this.ScoreTransform = Mdl.ScoreTransform;
      this.STname = Mdl.STname;
      ## Prior and Cost are only copied from Discriminant, GAM and KNN models
      if (ismember (class (Mdl), validTypes(1:3)))
        this.Prior = Mdl.Prior;
        this.Cost = Mdl.Cost;
      endif
      ## Standardize is only copied from KNN, NeuralNetwork and SVM models
      if (ismember (class (Mdl), validTypes(3:5)))
        this.Standardize = Mdl.Standardize;
      endif

      ## Switch Classification object types
      switch (this.CrossValidatedModel)

        case "ClassificationDiscriminant"
          ## Arguments to pass in fitcdiscr
          args = {};
          ## List of acceptable parameters for fitcdiscr
          DiscrParams = {'PredictorNames', 'ResponseName', 'ClassNames', ...
                         'Cost', 'DiscrimType', 'Gamma'};
          ## Set parameters (empty values are left to fitcdiscr defaults)
          for i = 1:numel (DiscrParams)
            paramName = DiscrParams{i};
            paramValue = Mdl.(paramName);
            if (! isempty (paramValue))
              args = [args, {paramName, paramValue}];
            endif
          endfor
          ## The original model was trained without coefficients, so do not
          ## compute them for the fold models either
          if (isempty (Mdl.Coeffs))
            args = [args, {'FillCoeffs', 'off'}];
          endif

          ## Train model according to partition object
          for k = 1:this.KFold
            idx = training (this.Partition, k);
            tmp = fitcdiscr (this.X(idx, :), this.Y(idx,:), args{:});
            this.Trained{k} = compact (tmp);
          endfor

          ## Store ModelParameters to ClassificationPartitionedModel object
          params = struct();
          paramList = {'DiscrimType', 'FillCoeffs', 'Gamma'};
          for i = 1:numel (paramList)
            paramName = paramList{i};
            if (isprop (Mdl, paramName))
              params.(paramName) = Mdl.(paramName);
            endif
          endfor
          this.ModelParameters = params;

        case "ClassificationGAM"
          ## Arguments to pass in fitcgam
          args = {};
          ## List of acceptable parameters for fitcgam
          GAMparams = {'PredictorNames', 'ResponseName', 'ClassNames', ...
                       'Cost', 'Formula', 'Interactions', 'Knots', 'Order', ...
                       'LearningRate', 'NumIterations'};
          ## Set parameters (empty values are left to fitcgam defaults)
          for i = 1:numel (GAMparams)
            paramName = GAMparams{i};
            paramValue = Mdl.(paramName);
            if (! isempty (paramValue))
              args = [args, {paramName, paramValue}];
            endif
          endfor

          ## Train model according to partition object
          for k = 1:this.KFold
            idx = training (this.Partition, k);
            tmp = fitcgam (this.X(idx, :), this.Y(idx,:), args{:});
            this.Trained{k} = compact (tmp);
          endfor

          ## Store ModelParameters to ClassificationPartitionedModel object
          params = struct();
          paramList = {'Formula', 'Interactions', 'Knots', 'Order', 'DoF', ...
                       'LearningRate', 'NumIterations'};
          for i = 1:numel (paramList)
            paramName = paramList{i};
            if (isprop (Mdl, paramName))
              params.(paramName) = Mdl.(paramName);
            endif
          endfor
          this.ModelParameters = params;

        case 'ClassificationKNN'
          ## Arguments to pass in fitcknn
          args = {};
          ## List of acceptable parameters for fitcknn
          KNNparams = {'PredictorNames', 'ResponseName', 'ClassNames', ...
                       'Prior', 'Cost', 'ScoreTransform', 'BreakTies', ...
                       'NSMethod', 'BucketSize', 'NumNeighbors', 'Exponent', ...
                       'Scale', 'Cov', 'Distance', 'DistanceWeight', ...
                       'IncludeTies'};
          ## Set parameters
          for i = 1:numel (KNNparams)
            paramName = KNNparams{i};
            if (isprop (Mdl, paramName))
              paramValue = Mdl.(paramName);
              if (! isempty (paramValue))
                args = [args, {paramName, paramValue}];
              endif
            else
              ## 'Cov', 'Exponent' and 'Scale' are not properties of the
              ## model; recover them from DistParameter for the matching
              ## distance metric
              switch (paramName)
                case 'Cov'
                  if (strcmpi (Mdl.Distance, 'mahalanobis') && ...
                      (! isempty (Mdl.DistParameter)))
                    args = [args, {'Cov', Mdl.DistParameter}];
                  endif
                case 'Exponent'
                  if (strcmpi (Mdl.Distance,'minkowski') && ...
                      (! isempty (Mdl.DistParameter)))
                    args = [args, {'Exponent', Mdl.DistParameter}];
                  endif
                case 'Scale'
                  if (strcmpi (Mdl.Distance,'seuclidean') && ...
                      (! isempty (Mdl.DistParameter)))
                    args = [args, {'Scale', Mdl.DistParameter}];
                  endif
              endswitch
            endif
          endfor

          ## Train model according to partition object (KNN models are
          ## stored as full ClassificationKNN objects, not compacted)
          for k = 1:this.KFold
            idx = training (this.Partition, k);
            this.Trained{k} = fitcknn (this.X(idx, :), this.Y(idx,:), args{:});
          endfor

          ## Store ModelParameters to ClassificationPartitionedModel object
          params = struct();
          paramList = {'NumNeighbors', 'Distance', 'DistParameter', ...
                       'NSMethod', 'DistanceWeight', 'Standardize'};
          for i = 1:numel (paramList)
            paramName = paramList{i};
            if (isprop (Mdl, paramName))
              params.(paramName) = Mdl.(paramName);
            endif
          endfor
          this.ModelParameters = params;

        case 'ClassificationNeuralNetwork'
          ## Arguments to pass in fitcnet
          args = {};
          ## List of acceptable parameters for fitcnet
          NNparams = {'PredictorNames', 'ResponseName', 'ClassNames', ...
                      'ScoreTransform', 'Standardize', 'LayerSizes', ...
                      'Activations', 'OutputLayerActivation', ...
                      'LearningRate', 'IterationLimit', 'DisplayInfo'};
          ## Set parameters (empty values are left to fitcnet defaults)
          for i = 1:numel (NNparams)
            paramName = NNparams{i};
            paramValue = Mdl.(paramName);
            if (! isempty (paramValue))
              args = [args, {paramName, paramValue}];
            endif
          endfor

          ## Train model according to partition object
          for k = 1:this.KFold
            idx = training (this.Partition, k);
            tmp = fitcnet (this.X(idx, :), this.Y(idx,:), args{:});
            this.Trained{k} = compact (tmp);
          endfor

          ## Store ModelParameters to ClassificationPartitionedModel object
          params = struct();
          paramList = {'LayerSizes', 'Activations', 'OutputLayerActivation', ...
                       'LearningRate', 'IterationLimit', 'Solver'};
          for i = 1:numel (paramList)
            paramName = paramList{i};
            if (isprop (Mdl, paramName))
              params.(paramName) = Mdl.(paramName);
            endif
          endfor
          this.ModelParameters = params;

        case 'ClassificationSVM'
          ## Get ModelParameters structure from ClassificationSVM object
          params = Mdl.ModelParameters;

          ## Train model according to partition object
          for k = 1:this.KFold
            idx = training (this.Partition, k);
            ## Pass all arguments directly to fitcsvm
            tmp = fitcsvm (this.X(idx, :), this.Y(idx,:), ...
                           'Standardize', Mdl.Standardize, ...
                           'PredictorNames', Mdl.PredictorNames, ...
                           'ResponseName', Mdl.ResponseName, ...
                           'ClassNames', Mdl.ClassNames, ...
                           'SVMtype', params.SVMtype, ...
                           'KernelFunction', params.KernelFunction, ...
                           'PolynomialOrder', params.PolynomialOrder, ...
                           'KernelScale', params.KernelScale, ...
                           'KernelOffset', params.KernelOffset, ...
                           'BoxConstraint', params.BoxConstraint, ...
                           'Nu', params.Nu, ...
                           'CacheSize', params.CacheSize, ...
                           'Tolerance', params.Tolerance, ...
                           'Shrinking', params.Shrinking);
            this.Trained{k} = compact (tmp);
          endfor

          ## Store ModelParameters to ClassificationPartitionedModel object
          this.ModelParameters = params;

      endswitch
    endfunction

    ## -*- texinfo -*-
    ## @deftypefn  {ClassificationPartitionedModel} {@var{label} =} kfoldPredict (@var{this})
    ## @deftypefnx {ClassificationPartitionedModel} {[@var{label}, @var{Score}, @var{Cost}] =} kfoldPredict (@var{this})
    ##
    ## Predict responses for observations not used for training in a
    ## cross-validated classification model.
    ##
    ## @code{[@var{label}, @var{Score}, @var{Cost}] = kfoldPredict (@var{this})}
    ## returns the predicted class labels, classification scores, and
    ## classification costs for the data used
    ## to train the cross-validated model @var{this}.
    ##
    ## @var{this} is a @code{ClassificationPartitionedModel} object.
    ## The function predicts the response for each observation that was
    ## held out during training in the cross-validation process, using the
    ## model trained on the fold in which that observation was held out.
    ##
    ## When @var{this} was created with a holdout partition (@qcode{KFold} is
    ## 1), only the observations in the test set receive out-of-fold
    ## predictions.  The observations in the training set are assigned the
    ## most frequent class label in @var{Y}, and their rows in @var{Score}
    ## and @var{Cost} are set to @qcode{NaN}.
    ##
    ## @multitable @columnfractions 0.28 0.02 0.7
    ## @headitem @var{Output} @tab @tab @var{Description}
    ##
    ## @item @qcode{label} @tab @tab Predicted class labels, returned as a
    ## vector, character array, or cell array.  The type of @var{label}
    ## matches the type of @var{Y} in the original training data.  Each row
    ## of @var{label} corresponds to the predicted class label for the
    ## corresponding row in @var{X}.
    ##
    ## @item @qcode{Score} @tab @tab Classification scores, returned as a
    ## numeric matrix.  Each row of @var{Score} corresponds to an observation,
    ## and each column corresponds to a class.  The value in row @var{i} and
    ## column @var{j} is the
    ## classification score for class @var{j} for observation @var{i}.
    ##
    ## @item @qcode{Cost} @tab @tab Classification costs, returned as a
    ## numeric matrix.  Each row of @var{Cost} corresponds to an observation,
    ## and each column corresponds to a class.  The value in row @var{i}
    ## and column @var{j} is the classification cost for class @var{j} for
    ## observation @var{i}.  This output is optional and only returned if
    ## requested.  Requesting it from a cross-validated
    ## @qcode{ClassificationNeuralNetwork} or @qcode{ClassificationSVM} model
    ## is an error.
    ## @end multitable
    ##
    ## @seealso{ClassificationKNN, ClassificationSVM,
    ## ClassificationPartitionedModel}
    ## @end deftypefn
    function [label, Score, Cost] = kfoldPredict (this)

      ## Input validation: for these model types predict is called without
      ## a cost output, so 'Cost' cannot be returned
      no_cost_models = {'ClassificationNeuralNetwork', 'ClassificationSVM'};
      no_cost = any (strcmp (this.CrossValidatedModel, no_cost_models));
      if (no_cost && nargout > 2)
        error (strcat ("ClassificationPartitionedModel.kfoldPredict:", ...
                       " 'Cost' output is not supported for %s cross", ...
                       " validated models."), this.CrossValidatedModel);
      endif

      ## Initialize the label vector based on the type of Y
      if (iscellstr (this.Y))
        label = cell (this.NumObservations, 1);
      elseif (islogical (this.Y))
        label = false (this.NumObservations, 1);
      elseif (isnumeric (this.Y))
        label = zeros (this.NumObservations, 1);
      elseif (ischar (this.Y))
        ## Blank-padded rows, so that predicted labels narrower than Y are
        ## padded the same way as the rows of a character array
        label = repmat (" ", this.NumObservations, size (this.Y, 2));
      endif

      ## Initialize the score and cost matrices (rows that never receive a
      ## prediction stay NaN)
      Score = nan (this.NumObservations, numel (this.ClassNames));
      Cost = nan (this.NumObservations, numel (this.ClassNames));

      ## Predict label, score, and cost (if applicable) for each KFold partition
      for k = 1:this.KFold

        ## Get held-out data and trained model for this fold
        testIdx = test (this.Partition, k);
        model = this.Trained{k};

        ## Predict the held-out observations
        if (no_cost)
          [predictedLabel, score] = predict (model, this.X(testIdx, :));
        else
          [predictedLabel, score, cost] = predict (model, this.X(testIdx, :));
        endif

        ## Convert cell array of labels to the type of Y (if applicable);
        ## cellstr labels for cellstr Y are used as returned
        if (iscell (predictedLabel))
          if (isnumeric (this.Y))
            predictedLabel = cellfun (@str2num, predictedLabel);
          elseif (islogical (this.Y))
            predictedLabel = cellfun (@logical, predictedLabel);
          elseif (ischar (this.Y))
            predictedLabel = char (predictedLabel);
          endif
        endif

        ## Store labels, score, and cost (if applicable)
        if (ischar (this.Y))
          ## Character labels are rows of a matrix: assign whole rows
          label(testIdx, 1:size (predictedLabel, 2)) = predictedLabel;
        else
          label(testIdx) = predictedLabel;
        endif
        Score(testIdx, :) = score;
        if (nargout > 2)
          Cost(testIdx, :) = cost;
        endif

      endfor

      ## Handle single fold case (holdout): only the test set was predicted
      ## above.  The training observations have no out-of-fold prediction, so
      ## they get the most frequent class label of Y and NaN score/cost.
      if (this.KFold == 1)
        trainIdx = training (this.Partition, 1);
        y = grp2idx (this.Y);
        ## Row indexing keeps a whole row for character-array labels
        modeLabel = this.Y(find (y == mode (y), 1), :);
        if (ischar (this.Y))
          label(trainIdx, 1:size (modeLabel, 2)) = ...
            repmat (modeLabel, sum (trainIdx), 1);
        else
          label(trainIdx) = modeLabel;
        endif
        Score(trainIdx, :) = NaN;
        Cost(trainIdx, :) = NaN;
      endif

    endfunction

  endmethods

endclassdef


## Demo: cross-validate a KNN classifier with an explicit 5-fold cvpartition
%!demo
%!
%! load fisheriris
%! x = meas;
%! y = species;
%!
%! ## Create a KNN classifier model
%! obj = fitcknn (x, y, "NumNeighbors", 5, "Standardize", 1);
%!
%! ## Create a partition for 5-fold cross-validation
%! partition = cvpartition (y, "KFold", 5);
%!
%! ## Create the ClassificationPartitionedModel object
%! cvModel = crossval (obj, 'cvPartition', partition)

## Demo: estimate the cross-validated accuracy of a KNN classifier
%!demo
%!
%! load fisheriris
%! x = meas;
%! y = species;
%!
%! ## Create a KNN classifier model
%! obj = fitcknn (x, y, "NumNeighbors", 5, "Standardize", 1);
%!
%! ## Create the ClassificationPartitionedModel object
%! cvModel = crossval (obj);
%!
%! ## Predict the class labels for the observations not used for training
%! [label, score, cost] = kfoldPredict (cvModel);
%! ## Accuracy = fraction of out-of-fold labels that match the true labels
%! fprintf ("Cross-validated accuracy = %1.2f%% (%d/%d)\n", ...
%!          sum (strcmp (label, y)) / numel (y) *100, ...
%!          sum (strcmp (label, y)), numel (y))

## Tests
## ClassificationDiscriminant: k-fold and holdout partitions via crossval
%!test
%! load fisheriris
%! a = fitcdiscr (meas, species, "gamma", 0.3);
%! cvModel = crossval (a, "KFold", 5);
%! assert (class (cvModel), "ClassificationPartitionedModel");
%! assert (cvModel.NumObservations, 150);
%! assert (numel (cvModel.Trained), 5);
%! assert (class (cvModel.Trained{1}), "CompactClassificationDiscriminant");
%! assert (cvModel.CrossValidatedModel, "ClassificationDiscriminant");
%! assert (cvModel.KFold, 5);
%!test
%! load fisheriris
%! a = fitcdiscr (meas, species, "gamma", 0.5, "fillcoeffs", "off");
%! cvModel = crossval (a, "HoldOut", 0.3);
%! assert (class (cvModel), "ClassificationPartitionedModel");
%! assert ({cvModel.X, cvModel.Y}, {meas, species});
%! assert (cvModel.NumObservations, 150);
%! ## A holdout partition has a single test set, hence a single model
%! assert (numel (cvModel.Trained), 1);
%! assert (class (cvModel.Trained{1}), "CompactClassificationDiscriminant");
%! assert (cvModel.CrossValidatedModel, "ClassificationDiscriminant");
## ClassificationGAM: k-fold and leave-one-out partitions via crossval
%!test
%! x = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1];
%! y = ["a"; "a"; "b"; "b"];
%! a = fitcgam (x, y, "Interactions", "all");
%! cvModel = crossval (a, "KFold", 2);
%! assert (class (cvModel), "ClassificationPartitionedModel");
%! assert (cvModel.NumObservations, 4);
%! assert (numel (cvModel.Trained), 2);
%! assert (class (cvModel.Trained{1}), "CompactClassificationGAM");
%! assert (cvModel.CrossValidatedModel, "ClassificationGAM");
%! assert (cvModel.KFold, 2);
%!test
%! x = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1];
%! y = ["a"; "a"; "b"; "b"];
%! a = fitcgam (x, y);
%! cvModel = crossval (a, "LeaveOut", "on");
%! assert (class (cvModel), "ClassificationPartitionedModel");
%! assert ({cvModel.X, cvModel.Y}, {x, y});
%! assert (cvModel.NumObservations, 4);
%! ## Leave-one-out yields one model per observation
%! assert (numel (cvModel.Trained), 4);
%! assert (class (cvModel.Trained{1}), "CompactClassificationGAM");
%! assert (cvModel.CrossValidatedModel, "ClassificationGAM");
## ClassificationKNN: constructor called directly with a cvpartition object;
## KNN fold models are stored uncompacted
%!test
%! x = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1];
%! y = ["a"; "a"; "b"; "b"];
%! a = fitcknn (x, y);
%! partition = cvpartition (y, "KFold", 2);
%! cvModel = ClassificationPartitionedModel (a, partition);
%! assert (class (cvModel), "ClassificationPartitionedModel");
%! assert (class (cvModel.Trained{1}), "ClassificationKNN");
%! assert (cvModel.NumObservations, 4);
%! assert (cvModel.ModelParameters.NumNeighbors, 1);
%! assert (cvModel.ModelParameters.NSMethod, "kdtree");
%! assert (cvModel.ModelParameters.Distance, "euclidean");
%! assert (! cvModel.ModelParameters.Standardize);
%!test
%! x = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1];
%! y = ["a"; "a"; "b"; "b"];
%! a = fitcknn (x, y, "NSMethod", "exhaustive");
%! partition = cvpartition (y, "HoldOut", 0.2);
%! cvModel = ClassificationPartitionedModel (a, partition);
%! assert (class (cvModel), "ClassificationPartitionedModel");
%! assert (class (cvModel.Trained{1}), "ClassificationKNN");
%! assert ({cvModel.X, cvModel.Y}, {x, y});
%! assert (cvModel.NumObservations, 4);
%! assert (cvModel.ModelParameters.NumNeighbors, 1);
%! assert (cvModel.ModelParameters.NSMethod, "exhaustive");
%! assert (cvModel.ModelParameters.Distance, "euclidean");
%! assert (! cvModel.ModelParameters.Standardize);
%!test
%! x = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1];
%! y = ["a"; "a"; "b"; "b"];
%! k = 2;
%! a = fitcknn (x, y, "NumNeighbors" ,k);
%! partition = cvpartition (numel (y), "LeaveOut");
%! cvModel = ClassificationPartitionedModel (a, partition);
%! assert (class (cvModel), "ClassificationPartitionedModel");
%! assert (class (cvModel.Trained{1}), "ClassificationKNN");
%! assert ({cvModel.X, cvModel.Y}, {x, y});
%! assert (cvModel.NumObservations, 4);
%! assert (cvModel.ModelParameters.NumNeighbors, k);
%! assert (cvModel.ModelParameters.NSMethod, "kdtree");
%! assert (cvModel.ModelParameters.Distance, "euclidean");
%! assert (! cvModel.ModelParameters.Standardize);
## ClassificationNeuralNetwork: k-fold and leave-one-out partitions via crossval
%!test
%! x = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1];
%! y = {"a"; "a"; "b"; "b"};
%! a = fitcnet (x, y, "IterationLimit", 50);
%! cvModel = crossval (a, "KFold", 2);
%! assert (class (cvModel), "ClassificationPartitionedModel");
%! assert (cvModel.NumObservations, 4);
%! assert (numel (cvModel.Trained), 2);
%! assert (class (cvModel.Trained{1}), "CompactClassificationNeuralNetwork");
%! assert (cvModel.CrossValidatedModel, "ClassificationNeuralNetwork");
%! assert (cvModel.KFold, 2);
%!test
%! x = [1, 2, 3; 4, 5, 6; 7, 8, 9; 3, 2, 1];
%! y = {"a"; "a"; "b"; "b"};
%! a = fitcnet (x, y, "LayerSizes", [5, 3]);
%! cvModel = crossval (a, "LeaveOut", "on");
%! assert (class (cvModel), "ClassificationPartitionedModel");
%! assert ({cvModel.X, cvModel.Y}, {x, y});
%! assert (cvModel.NumObservations, 4);
%! assert (numel (cvModel.Trained), 4);
%! assert (class (cvModel.Trained{1}), "CompactClassificationNeuralNetwork");
%! assert (cvModel.CrossValidatedModel, "ClassificationNeuralNetwork");
## ClassificationSVM (two-class subset of fisheriris): k-fold, holdout and
## leave-one-out partitions via crossval
%!test
%! load fisheriris
%! inds = ! strcmp (species, 'setosa');
%! x = meas(inds, 3:4);
%! y = grp2idx (species(inds));
%! SVMModel = fitcsvm (x,y);
%! CVMdl = crossval (SVMModel, "KFold", 5);
%! assert (class (CVMdl), "ClassificationPartitionedModel")
%! assert ({CVMdl.X, CVMdl.Y}, {x, y})
%! assert (CVMdl.KFold == 5)
%! assert (class (CVMdl.Trained{1}), "CompactClassificationSVM")
%! assert (CVMdl.CrossValidatedModel, "ClassificationSVM");
%!test
%! load fisheriris
%! inds = ! strcmp (species, 'setosa');
%! x = meas(inds, 3:4);
%! y = grp2idx (species(inds));
%! obj = fitcsvm (x, y);
%! CVMdl = crossval (obj, "HoldOut", 0.2);
%! assert (class (CVMdl), "ClassificationPartitionedModel")
%! assert ({CVMdl.X, CVMdl.Y}, {x, y})
%! assert (class (CVMdl.Trained{1}), "CompactClassificationSVM")
%! assert (CVMdl.CrossValidatedModel, "ClassificationSVM");
%!test
%! load fisheriris
%! inds = ! strcmp (species, 'setosa');
%! x = meas(inds, 3:4);
%! y = grp2idx (species(inds));
%! obj = fitcsvm (x, y);
%! CVMdl = crossval (obj, "LeaveOut", 'on');
%! assert (class (CVMdl), "ClassificationPartitionedModel")
%! assert ({CVMdl.X, CVMdl.Y}, {x, y})
%! assert (class (CVMdl.Trained{1}), "CompactClassificationSVM")
%! assert (CVMdl.CrossValidatedModel, "ClassificationSVM");

## Test input validation for ClassificationPartitionedModel
## Both the model and the cvpartition object are required
%!error<ClassificationPartitionedModel: too few input arguments.> ...
%! ClassificationPartitionedModel ()
%!error<ClassificationPartitionedModel: too few input arguments.> ...
%! ClassificationPartitionedModel (ClassificationKNN (ones (4,2), ones (4,1)))
## A regression model is not a supported model type
%!error<ClassificationPartitionedModel: unsupported model type.> ...
%! ClassificationPartitionedModel (RegressionGAM (ones (40,2), ...
%! randi ([1, 2], 40, 1)), cvpartition (randi ([1, 2], 40, 1), 'Holdout', 0.3))
## The second argument must be a cvpartition object, not a string
%!error<ClassificationPartitionedModel: invalid 'cvpartition' object.> ...
%! ClassificationPartitionedModel (ClassificationKNN (ones (4,2), ...
%! ones (4,1)), 'Holdout')

## Test for kfoldPredict
%!test
%! load fisheriris
%! a = fitcdiscr (meas, species, "gamma", 0.5, "fillcoeffs", "off");
%! cvModel = crossval (a, "Kfold", 4);
%! [label, score, cost] = kfoldPredict (cvModel);
%! assert (class(cvModel), "ClassificationPartitionedModel");
%! assert ({cvModel.X, cvModel.Y}, {meas, species});
%! assert (cvModel.NumObservations, 150);
%! ## With k-fold partitioning every observation is held out exactly once,
%! ## so each one gets a label and one score/cost per class (no NaN rows)
%! assert (iscellstr (label));
%! assert (size (label), [150, 1]);
%! assert (all (ismember (label, unique (species))));
%! assert (size (score), [150, 3]);
%! assert (size (cost), [150, 3]);
%! assert (! any (isnan (score(:))));
%!test
%! x = ones(4, 11);
%! y = {"a"; "a"; "b"; "b"};
%! k = 3;
%! a = fitcknn (x, y, "NumNeighbors", k);
%! partition = cvpartition (numel (y), "LeaveOut");
%! cvModel = ClassificationPartitionedModel (a, partition);
%! [label, score, cost] = kfoldPredict (cvModel);
%! assert (class(cvModel), "ClassificationPartitionedModel");
%! assert ({cvModel.X, cvModel.Y}, {x, y});
%! assert (cvModel.NumObservations, 4);
%! assert (cvModel.ModelParameters.NumNeighbors, k);
%! assert (cvModel.ModelParameters.NSMethod, "exhaustive");
%! assert (cvModel.ModelParameters.Distance, "euclidean");
%! assert (! cvModel.ModelParameters.Standardize);
%! ## All predictors are identical, so each held-out observation is voted on
%! ## by the remaining three and the minority class of its fold wins 2-to-1
%! assert (label, {"b"; "b"; "a"; "a"});
%! assert (score, [0.3333, 0.6667; 0.3333, 0.6667; 0.6667, 0.3333; ...
%!          0.6667, 0.3333], 1e-4);
%! assert (cost, [0.6667, 0.3333; 0.6667, 0.3333; 0.3333, 0.6667; ...
%!          0.3333, 0.6667], 1e-4);

## Test input validation for kfoldPredict
## Requesting the third output ('Cost') is rejected for SVM and neural
## network cross-validated models
%!error<ClassificationPartitionedModel.kfoldPredict: 'Cost' output is not supported for ClassificationSVM cross validated models.> ...
%! [label, score, cost] = kfoldPredict (crossval (ClassificationSVM (ones (40,2), randi ([1, 2], 40, 1))))
%!error<ClassificationPartitionedModel.kfoldPredict: 'Cost' output is not supported for ClassificationNeuralNetwork cross validated models.> ...
%! [label, score, cost] = kfoldPredict (crossval (ClassificationNeuralNetwork (ones (40,2), randi ([1, 2], 40, 1))))
