Supervised Learning (Machine Learning) Workflow and Algorithms

Steps in Supervised Learning (Machine Learning)

Supervised learning (machine learning) takes a known set of input data and known responses to the data, and seeks to build a predictor model that generates reasonable predictions for the response to new data.

Suppose you want to predict whether someone will have a heart attack within a year. You have data on previous patients, including age, weight, height, and blood pressure, and you know whether each of them had a heart attack within a year of the measurements. The problem is to combine the existing data into a model that can predict whether a new person will have a heart attack within a year.

Supervised learning splits into two broad categories:

  • Classification for responses that can have just a few known values, such as 'true' or 'false'. Classification algorithms apply to nominal, not ordinal response values.

  • Regression for responses that are a real number, such as miles per gallon for a particular car.

It can be hard to decide whether you have a classification problem or a regression problem. In that case, create a regression model first, because regression models are often more computationally efficient.

While there are many Statistics Toolbox™ algorithms for supervised learning, most use the same basic workflow for obtaining a predictor model. (Detailed instruction on the steps for ensemble learning is in Framework for Ensemble Learning.) The steps for supervised learning are:

  1. Prepare Data

  2. Choose an Algorithm

  3. Fit a Model

  4. Choose a Validation Method

  5. Examine Fit and Update Until Satisfied

  6. Use Fitted Model for Predictions

Prepare Data

All supervised learning methods start with an input data matrix, usually called X here. Each row of X represents one observation. Each column of X represents one variable, or predictor. Represent missing entries with NaN values in X. Statistics Toolbox supervised learning algorithms can handle NaN values, either by ignoring them or by ignoring any row with a NaN value.

You can use various data types for response data Y. Each element in Y represents the response to the corresponding row of X. Observations with missing Y data are ignored.

  • For regression, Y must be a numeric vector with the same number of elements as the number of rows of X.

  • For classification, Y can be any of these data types. This table also contains the method of including missing entries.

    Data Type                 Missing Entry
    Numeric vector            NaN
    Categorical vector        <undefined>
    Character array           Row of spaces
    Cell array of strings     ''
    Logical vector            (Cannot represent)
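
As a rough illustration of these conventions, the following sketch builds a small predictor matrix and a classification response with missing entries. The values and the column meanings (age, weight, blood pressure) are made up for illustration and are not from any real data set.

```matlab
% Predictor matrix: one row per observation, one column per variable.
% Hypothetical columns: age, weight (kg), systolic blood pressure.
X = [52 81.5 130;
     47 NaN  118;
     66 92.0 145;
     59 77.3 NaN];

% Classification response as a cell array of strings; '' marks a missing entry.
Ycell = {'true'; 'false'; ''; 'false'};

% The same response as a categorical vector; '' becomes <undefined>.
Ycat = categorical(Ycell);

rowsWithNaN = any(isnan(X), 2);   % rows of X that contain a missing predictor
missingY    = isundefined(Ycat);  % observations with a missing response
```

Here rowsWithNaN flags the second and fourth rows, and missingY flags the third observation, which the fitting functions would ignore.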

Choose an Algorithm

There are tradeoffs between several characteristics of algorithms, such as:

  • Speed of training

  • Memory usage

  • Predictive accuracy on new data

  • Transparency or interpretability, meaning how easily you can understand the reasons an algorithm makes its predictions

Details of the algorithms appear in Characteristics of Algorithms. More detail about ensemble algorithms is in Choose an Applicable Ensemble Method.
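
One informal way to get a feel for these tradeoffs is to fit two candidate algorithms to the same data and compare fitting time and cross-validated error. The sketch below assumes the fisheriris sample data set that ships with the toolbox; the choice of 5 neighbors is arbitrary, and timings on such a small data set are only indicative.

```matlab
load fisheriris   % meas: 150-by-4 predictors, species: class labels
rng(1)            % fix the seed so cross-validation partitions repeat

% Time the fit of a classification tree.
tic
tree = fitctree(meas, species);
tTree = toc;

% Time the fit of a k-nearest neighbor classifier (k = 5 chosen arbitrarily).
tic
knn = fitcknn(meas, species, 'NumNeighbors', 5);
tKnn = toc;

% Estimate predictive accuracy on new data with 10-fold cross-validation.
errTree = kfoldLoss(crossval(tree));
errKnn  = kfoldLoss(crossval(knn));

fprintf('Tree: fit %.3f s, CV error %.3f\n', tTree, errTree);
fprintf('KNN:  fit %.3f s, CV error %.3f\n', tKnn, errKnn);
```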

Fit a Model

The fitting function you use depends on the algorithm you choose.

Algorithm                                            Fitting Function
Classification Trees                                 fitctree
Regression Trees                                     fitrtree
Discriminant Analysis (classification)               fitcdiscr
K-Nearest Neighbors (classification)                 fitcknn
Naive Bayes (classification)                         fitNaiveBayes
Classification or Regression Ensembles               fitensemble
Classification or Regression Ensembles in Parallel   TreeBagger
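
For example, fitting a classification tree and a regression tree might look like the following sketch. It assumes the fisheriris and carsmall sample data sets that ship with the toolbox; the choice of Weight and Horsepower as predictors is only for illustration.

```matlab
% Classification: predict iris species from four flower measurements.
load fisheriris                     % meas (150-by-4), species (cell array of strings)
ctree = fitctree(meas, species);

% Regression: predict miles per gallon from weight and horsepower.
load carsmall                       % Weight, Horsepower, MPG, and other variables
Xcars = [Weight Horsepower];
rtree = fitrtree(Xcars, MPG);       % observations with missing MPG are ignored
```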

Choose a Validation Method

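The three main methods to examine the accuracy of the resulting fitted model are:

  • Examine the resubstitution error.

  • Examine the cross-validation error.

  • Examine the out-of-bag error for bagged decision trees.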
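
A minimal sketch of common accuracy checks, assuming the fisheriris sample data set: resubstitution error from resubLoss, cross-validation error from crossval and kfoldLoss, and out-of-bag error from oobLoss on a bagged ensemble. The ensemble size of 50 trees is an arbitrary choice for illustration.

```matlab
load fisheriris
rng(1)   % fix the random seed so partitions and bootstrap samples repeat

% Resubstitution error: misclassification rate on the training data itself.
tree = fitctree(meas, species);
resubErr = resubLoss(tree);

% Cross-validation error: 10-fold by default.
cvtree = crossval(tree);
cvErr = kfoldLoss(cvtree);

% Out-of-bag error for bagged decision trees (50 trees, chosen arbitrarily).
bag = fitensemble(meas, species, 'Bag', 50, 'Tree', 'Type', 'classification');
oobErr = oobLoss(bag);

fprintf('Resubstitution: %.3f  Cross-validation: %.3f  Out-of-bag: %.3f\n', ...
    resubErr, cvErr, oobErr);
```

Resubstitution error is usually optimistic, because the model is scored on the same data it was fitted to, so the cross-validation and out-of-bag estimates are generally better guides to accuracy on new data.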

Examine Fit and Update Until Satisfied

After validating the model, you might want to change it for better accuracy, better speed, or to use less memory.

When you are satisfied with a model, you can often trim it using its compact method, which is available for classification trees, classification ensembles, regression trees, regression ensembles, and discriminant analysis. compact removes training data and pruning information, so the model uses less memory.
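
The effect of compact can be sketched as follows, again assuming the fisheriris sample data set. The check on the property list is just one informal way to see that the training data is gone.

```matlab
load fisheriris
tree = fitctree(meas, species);

% Create the compact version of the fitted tree.
ctree = compact(tree);

% The full model stores the training data in property X; the compact one does not.
fullHasX    = ismember('X', properties(tree));
compactHasX = ismember('X', properties(ctree));

% Both models give the same predictions.
samePredictions = isequal(predict(tree, meas), predict(ctree, meas));

whos tree ctree   % compare the memory used by the two objects
```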

Use Fitted Model for Predictions

To predict classification or regression response for most fitted models, use the predict method:

Ypredicted = predict(obj,Xnew)
  • obj is the fitted model object.

  • Xnew is the new input data.

  • Ypredicted is the predicted response, either classification or regression.

For classregtree, use the eval method instead of predict.
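
As a short sketch of the predict call, assuming a classification tree fitted to the fisheriris sample data set; the two rows of Xnew are made-up measurements for illustration.

```matlab
load fisheriris
obj = fitctree(meas, species);      % fitted model object

% New observations: same four columns as the training data.
Xnew = [5.9 3.0 5.1 1.8;
        5.0 3.4 1.5 0.2];

% One predicted class label per row of Xnew.
Ypredicted = predict(obj, Xnew);
disp(Ypredicted)
```

Because species is a cell array of strings, Ypredicted is also a cell array of strings, with one label per row of Xnew.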

Characteristics of Algorithms

This table shows typical characteristics of the various supervised learning algorithms. The characteristics in any particular case can vary from the listed ones. Use the table as a guide for your initial choice of algorithms, but be aware that the table can be inaccurate for some problems.

Characteristics of Supervised Learning Algorithms

Algorithm               Predictive Accuracy   Fitting Speed   Prediction Speed   Memory Usage   Easy to Interpret   Handles Categorical Predictors
Trees                   Low                   Fast            Fast               Low            Yes                 Yes
SVM                     High                  Medium          *                  *              *                   No
Naive Bayes             Low                   **              **                 **             Yes                 Yes
Nearest Neighbor        ***                   Fast***         Medium             High           No                  Yes***
Discriminant Analysis   ****                  Fast            Fast               Low            Yes                 No
Ensembles               See Suggestions for Choosing an Appropriate Ensemble Algorithm and General Characteristics of Ensemble Algorithms

* — SVM prediction speed and memory usage are good if there are few support vectors, but can be poor if there are many support vectors. When you use a kernel function, it can be difficult to interpret how SVM classifies data, though the default linear scheme is easy to interpret.

** — Naive Bayes speed and memory usage are good for simple distributions, but can be poor for kernel distributions and large data sets.

*** — Nearest Neighbor usually has good predictions in low dimensions, but can have poor predictions in high dimensions. For linear search, Nearest Neighbor does not perform any fitting. For kd-trees, Nearest Neighbor does perform fitting. Nearest Neighbor can have either continuous or categorical predictors, but not both.

**** — Discriminant Analysis is accurate when the modeling assumptions are satisfied (multivariate normal by class). Otherwise, the predictive accuracy varies.
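
The choice between linear search and kd-trees mentioned in the Nearest Neighbor note can be made explicit when fitting. This sketch assumes the fisheriris sample data set and uses the 'NSMethod' option of fitcknn.

```matlab
load fisheriris

% kd-tree search: a search structure is built when the model is fitted.
knnKd = fitcknn(meas, species, 'NSMethod', 'kdtree');

% Exhaustive (linear) search: no search structure is built at fitting time.
knnEx = fitcknn(meas, species, 'NSMethod', 'exhaustive');

disp(knnKd.NSMethod)
disp(knnEx.NSMethod)
```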
