A classification task when labels are known - determining the country of origin of cars given the cars' characteristics

In this exercise we are provided with several technical characteristics (mpg, horsepower, weight, model year...) for several car models, together with the country of origin of each model, and we would like to create a machine learning model such that the country of origin can be accurately predicted given the technical characteristics. As the variable to predict is categorical with more than two possible values, this is a multi-class [classification](https://en.wikipedia.org/wiki/Statistical_classification) task. It is a challenging exercise due to the simultaneous presence of three factors: (1) missing data; (2) unbalanced data - 254 out of 406 cars are US made; (3) a small dataset.
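Because of the class imbalance, a trivial classifier that always answers "US" would already be right for 254 of the 406 cars. This baseline is a simple worked computation added for reference (it is not part of the original workflow), and it is worth keeping in mind when reading the accuracies reported later, since they are computed on a random test subset and the share of US cars there will only be roughly the same:

```julia
n_us, n_total = 254, 406                 # class counts quoted above
majority_baseline = n_us / n_total       # accuracy of always predicting "US"
println(round(majority_baseline, digits = 3))   # prints 0.626
```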

Data origin:

Field description:

  1. mpg: continuous
  2. cylinders: multi-valued discrete
  3. displacement: continuous
  4. horsepower: continuous
  5. weight: continuous
  6. acceleration: continuous
  7. model year: multi-valued discrete
  8. origin: multi-valued discrete
  9. car name: string (unique for each instance)

The car name is not used in this tutorial, so that the country is inferred only from technical data. As this field also includes the car maker, and there are several car models from the same car maker, a more sophisticated machine learning model could exploit this information, e.g. using a bag-of-words encoding.
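As a hypothetical illustration of that idea (nothing below is used elsewhere in this tutorial), the car maker can be taken as the first word of the car name, and a simple bag-of-words matrix can count the words of each name. The helpers car_maker and bag_of_words are invented for this sketch:

```julia
# Hypothetical helper: the maker is assumed to be the first word of the name
car_maker(name::AbstractString) = String(first(split(name)))

# Hypothetical helper: one row per name, one column per distinct word, entries are word counts
function bag_of_words(names::AbstractVector{<:AbstractString})
    vocab  = sort(unique(vcat(split.(names)...)))          # distinct words, sorted
    col    = Dict(w => j for (j, w) in enumerate(vocab))   # word -> column index
    counts = zeros(Int, length(names), length(vocab))
    for (i, n) in enumerate(names), w in split(n)
        counts[i, col[w]] += 1
    end
    return counts, vocab
end

names_demo = ["amc ambassador brougham", "vw rabbit custom", "vw rabbit"]
counts, vocab = bag_of_words(names_demo)
car_maker.(names_demo)   # ["amc", "vw", "vw"]
```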

Library loading and initialisation

Activating the local environment specific to BetaML documentation

using Pkg
Pkg.activate(joinpath(@__DIR__,"..","..",".."))
  Activating environment at `~/work/BetaML.jl/BetaML.jl/docs/Project.toml`

We load a bunch of packages that we'll use during this tutorial:

using Random, HTTP, Plots, CSV, DataFrames, BenchmarkTools, StableRNGs, BetaML
import DecisionTree, Flux
import Pipe: @pipe

Machine learning workflows include stochastic components in several steps: in the data sampling, in the model initialisation and often in the models' own algorithms (and sometimes also in the prediction step). To simplify reproducibility, BetaML provides a fixed random number generator (RNG), FIXEDRNG. This is nothing else than an instance of StableRNG(123) defined in the BetaML.Utils sub-module, but you can of course choose your own "fixed" RNG. See the Dealing with stochasticity section in the Getting started tutorial for details.

Here we are explicit and we use our own fixed RNG:

seed = 123 # The table at the end of this tutorial has been obtained with seeds 123, 1000 and 10000
AFIXEDRNG = StableRNG(seed)
StableRNGs.LehmerRNG(state=0x000000000000000000000000000000f7)
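Throughout the tutorial the generator is passed as copy(AFIXEDRNG) rather than AFIXEDRNG itself. Copying gives each step the same starting state, so a step's results do not depend on how many random numbers earlier steps consumed. A minimal sketch of this behaviour, using the standard library's MersenneTwister in place of StableRNG (an assumption made only to keep the example dependency-free):

```julia
using Random

rng = MersenneTwister(123)        # stands in for StableRNG(123)

a = rand(copy(rng), 3)            # each copy starts from the same state...
b = rand(copy(rng), 3)
c = rand(rng, 3)                  # ...while drawing from `rng` directly advances it
d = rand(rng, 3)
println(a == b, " ", c == d)      # prints "true false"
```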

Data loading and preparation

To load the data from the internet our workflow is (1) Retrieve the data –> (2) Clean it –> (3) Load it –> (4) Output it as a DataFrame.

For step (1) we use HTTP.get(), for step (2) we use replace!, for steps (3) and (4) we use the CSV package, and we chain these operations with the pipe operator |> (through the @pipe macro of the Pipe package, where _ stands for the output of the previous step), so that no file is ever saved on disk:

urlDataOriginal = "https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data-original"
data = @pipe HTTP.get(urlDataOriginal).body                                                |>
             replace!(_, UInt8('\t') => UInt8(' '))                                        |> # the original dataset has mixed field delimiters !
             CSV.File(_, delim=' ', missingstring="NA", ignorerepeated=true, header=false) |>
             DataFrame;
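The cleaning step works directly on the raw bytes returned by HTTP.get(...).body. The sketch below reproduces it on a made-up record (the text is hypothetical, only formatted like the dataset) so that it runs without a network connection; runs of spaces are then handled by CSV.File through ignorerepeated=true:

```julia
raw = Vector{UInt8}("18.0   8\t\"chevrolet chevelle malibu\"")   # hypothetical raw record as bytes
replace!(raw, UInt8('\t') => UInt8(' '))   # turn the tab delimiter into a space, in place
line = String(copy(raw))                   # String(v) takes ownership of v, so we pass a copy
```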

This results in a table where the rows are the observations (the various car models) and the columns are the fields. All BetaML models expect this layout.

As the dataset is ordered, we randomly shuffle the data.

idx = randperm(copy(AFIXEDRNG),size(data,1))
data = data[idx, :] # indexing returns a shuffled copy, so we assign it back to `data`
describe(data)
9×7 DataFrame
 Row | variable  mean     min                      median  max               nmissing  eltype
     | Symbol    Union…   Any                      Union…  Any               Int64     Type
-----+-----------------------------------------------------------------------------------------------------
   1 | Column1   23.5146  9.0                      23.0    46.6                     8  Union{Missing, Float64}
   2 | Column2   5.47537  3.0                      4.0     8.0                      0  Float64
   3 | Column3   194.78   68.0                     151.0   455.0                    0  Float64
   4 | Column4   105.082  46.0                     95.0    230.0                    6  Union{Missing, Float64}
   5 | Column5   2979.41  1613.0                   2822.5  5140.0                   0  Float64
   6 | Column6   15.5197  8.0                      15.5    24.8                     0  Float64
   7 | Column7   75.9212  70.0                     76.0    82.0                     0  Float64
   8 | Column8   1.56897  1.0                      1.0     3.0                      0  Float64
   9 | Column9            amc ambassador brougham          vw rabbit custom         0  String
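Indexing a table or a matrix with a permutation returns a new, shuffled copy and leaves the original untouched, so the result must be assigned to a variable to take effect. A small sketch on a toy matrix (the data is made up for illustration):

```julia
using Random

m = [1 10; 2 20; 3 30; 4 40]                       # toy "dataset": 4 records, 2 fields
idx = randperm(MersenneTwister(123), size(m, 1))   # a random permutation of the row indices
m[idx, :]                                          # returns a shuffled copy; `m` itself is unchanged
shuffled = m[idx, :]                               # assign it to keep the shuffled order
```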

Columns 1 to 7 contain characteristics of the car, while column 8 encodes the country of origin ("1" -> US, "2" -> EU, "3" -> Japan). That's the variable we want to be able to predict.

Column 9 contains the car name, but we are not going to use this information in this tutorial. Note also that some fields (column 1, mpg, and column 4, horsepower) have missing data.

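Our first step is hence to divide the dataset into features (the x) and the labels (the y) we want to predict. The x is then a standard Julia Matrix of 406 rows by 7 columns and the y is a vector of the 406 observations. We also scale the features with a Scaler model, as some of the models we will use (notably neural networks) work best with scaled inputs: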

x     = Matrix{Union{Missing,Float64}}(data[:,1:7]);
y     = Vector{Int64}(data[:,8]);
x     = fit!(Scaler(),x)
406×7 Matrix{Union{Missing, Float64}}:
 -0.706439   1.47635    1.07088    0.643526   0.620107   -1.25708   -1.58146
 -1.09075    1.47635    1.48121    1.54744    0.843522   -1.43566   -1.58146
 -0.706439   1.47635    1.17584    1.16005    0.539725   -1.61424   -1.58146
 -0.962647   1.47635    1.04225    1.16005    0.536179   -1.25708   -1.58146
 -0.834543   1.47635    1.02316    0.901788   0.555092   -1.79281   -1.58146
 -1.09075    1.47635    2.23507    2.39971    1.60951    -1.97139   -1.58146
 -1.21885    1.47635    2.47364    2.96789    1.62488    -2.32855   -1.58146
 -1.21885    1.47635    2.34004    2.83876    1.57523    -2.50712   -1.58146
 -1.21885    1.47635    2.48318    3.09702    1.70881    -1.97139   -1.58146
 -1.09075    1.47635    1.86291    2.1931     1.02911    -2.50712   -1.58146
  ⋮                                                       ⋮         
 -0.194023   0.306793   0.35518    0.178653  -0.17071    -0.292762   1.62356
  1.08702   -0.862764  -0.484569  -0.234567  -0.371665   -0.578486   1.62356
  1.59943   -0.862764  -0.570453  -0.544482  -0.720381   -0.899925   1.62356
  0.446497  -0.862764  -0.417771  -0.389524  -0.0347697   0.635842   1.62356
  0.446497  -0.862764  -0.52274   -0.492829  -0.223904    0.028678   1.62356
  2.62426   -0.862764  -0.933072  -1.37092   -1.00408     3.24307    1.62356
  1.08702   -0.862764  -0.570453  -0.544482  -0.809037   -1.39994    1.62356
  0.574601  -0.862764  -0.713592  -0.673613  -0.418948    1.10014    1.62356
  0.958913  -0.862764  -0.723135  -0.596135  -0.30665     1.38587    1.62356
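For reference, a minimal sketch of column-wise standardisation (mean 0, standard deviation 1) that skips missing values when computing the statistics and keeps them missing in the output. That the default Scaler behaves this way is an assumption based on the neural-network section of this tutorial; the helper name standardise is hypothetical and uses the sample (n-1) standard deviation:

```julia
using Statistics

# Hypothetical helper: standardise each column, ignoring missing values in mean/std
function standardise(x::AbstractMatrix)
    out = similar(x)
    for j in axes(x, 2)
        col = collect(skipmissing(x[:, j]))      # observed values of column j
        μ, σ = mean(col), std(col)
        out[:, j] = (x[:, j] .- μ) ./ σ          # missing entries stay missing
    end
    return out
end

xs = Union{Missing,Float64}[1.0 2.0; 3.0 missing; 5.0 6.0]
zs = standardise(xs)
```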

Some algorithms that we will use today don't accept missing data, so we need to impute them. BetaML provides several imputation models in the Imputation module. Note that many of these imputation models can also be used for collaborative filtering / recommendation systems. Models such as GaussianMixtureImputer have an advantage over traditional algorithms such as k-nearest neighbours (KNN): a Gaussian mixture model (GMM) can "detect" the hidden structure of the observed data, where some observations can be similar to a certain pool of other observations for some characteristics, but similar to another pool of observations for other characteristics.

Here we use RandomForestImputer. While the model allows for reproducible multiple imputations (with the parameter multiple_imputation=an_integer) and multiple passages through the various columns (fields) containing missing data (with the option recursive_passages=an_integer), we use here just a single imputation and a single passage.

As all BetaML models, RandomForestImputer follows the pattern m=ModelConstruction(pars); fit!(m,x,[y]); est = predict(m,x), where est can be an estimation of some labels or some characteristics of x itself (the imputed version, as in this case, or a reprojected version, as in PCAEncoder), depending on whether the model is supervised or not. See the API user documentation for more details. For imputers, the output of predict is the matrix with the imputed values replacing the missing ones. Here we write the model in a single line using a convenience feature: when the default cache parameter is used in the model constructor, the fit! function itself returns the prediction over the trained data:

x = fit!(RandomForestImputer(rng=copy(AFIXEDRNG)),x) # Same as `m = RandomForestImputer(rng=copy(AFIXEDRNG)); fit!(m,x); x= predict(m,x)`
406×7 Matrix{Float64}:
 -0.706439   1.47635    1.07088    0.643526   0.620107   -1.25708   -1.58146
 -1.09075    1.47635    1.48121    1.54744    0.843522   -1.43566   -1.58146
 -0.706439   1.47635    1.17584    1.16005    0.539725   -1.61424   -1.58146
 -0.962647   1.47635    1.04225    1.16005    0.536179   -1.25708   -1.58146
 -0.834543   1.47635    1.02316    0.901788   0.555092   -1.79281   -1.58146
 -1.09075    1.47635    2.23507    2.39971    1.60951    -1.97139   -1.58146
 -1.21885    1.47635    2.47364    2.96789    1.62488    -2.32855   -1.58146
 -1.21885    1.47635    2.34004    2.83876    1.57523    -2.50712   -1.58146
 -1.21885    1.47635    2.48318    3.09702    1.70881    -1.97139   -1.58146
 -1.09075    1.47635    1.86291    2.1931     1.02911    -2.50712   -1.58146
  ⋮                                                       ⋮         
 -0.194023   0.306793   0.35518    0.178653  -0.17071    -0.292762   1.62356
  1.08702   -0.862764  -0.484569  -0.234567  -0.371665   -0.578486   1.62356
  1.59943   -0.862764  -0.570453  -0.544482  -0.720381   -0.899925   1.62356
  0.446497  -0.862764  -0.417771  -0.389524  -0.0347697   0.635842   1.62356
  0.446497  -0.862764  -0.52274   -0.492829  -0.223904    0.028678   1.62356
  2.62426   -0.862764  -0.933072  -1.37092   -1.00408     3.24307    1.62356
  1.08702   -0.862764  -0.570453  -0.544482  -0.809037   -1.39994    1.62356
  0.574601  -0.862764  -0.713592  -0.673613  -0.418948    1.10014    1.62356
  0.958913  -0.862764  -0.723135  -0.596135  -0.30665     1.38587    1.62356
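To make the construct / fit! / predict pattern concrete, here is a toy model that follows it. It is a deliberately naive column-mean imputer, not RandomForestImputer and not BetaML code; ToyMeanImputer, toy_fit! and toy_predict are hypothetical names, chosen so as not to clash with BetaML's own fit! and predict. Like BetaML models with the default cache, toy_fit! returns the predictions on the training data:

```julia
using Statistics

# Hypothetical toy model: learned parameters are stored in the model object
mutable struct ToyMeanImputer
    means::Union{Nothing,Vector{Float64}}   # nothing until the model is fitted
end
ToyMeanImputer() = ToyMeanImputer(nothing)

# Replace each missing entry with the learned mean of its column
function toy_predict(m::ToyMeanImputer, x::AbstractMatrix)
    m.means === nothing && error("model not fitted")
    return [coalesce(x[i, j], m.means[j]) for i in axes(x, 1), j in axes(x, 2)]
end

# Learn the column means ignoring missing values, then return the imputed training data
function toy_fit!(m::ToyMeanImputer, x::AbstractMatrix)
    m.means = [mean(skipmissing(x[:, j])) for j in axes(x, 2)]
    return toy_predict(m, x)
end

xm  = Union{Missing,Float64}[1.0 missing; 3.0 4.0; 5.0 8.0]
imp = ToyMeanImputer()
xi  = toy_fit!(imp, xm)        # same as toy_fit!(imp, xm); toy_predict(imp, xm)
```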

Further, some models don't work with categorical data either, so we need to represent our y as a matrix with a separate column for each possible categorical value (the so-called "one-hot" representation). For example, within a three-class field, the individual value 2 (or "Europe", for what it matters) would be represented as the vector [0 1 0], while 3 (or "Japan") would become the vector [0 0 1]. To encode as one-hot we use the OneHotEncoder in BetaML.Utils, using the same shortcut as for the imputer we used earlier:

y_oh  = fit!(OneHotEncoder(),y)
406×3 Matrix{Bool}:
 1  0  0
 1  0  0
 1  0  0
 1  0  0
 1  0  0
 1  0  0
 1  0  0
 1  0  0
 1  0  0
 1  0  0
 ⋮     
 1  0  0
 0  0  1
 1  0  0
 1  0  0
 1  0  0
 0  1  0
 1  0  0
 1  0  0
 1  0  0
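The encoding itself is simple enough to write by hand, and taking the position of the 1 in each row recovers the original labels. A minimal sketch, assuming integer labels 1..n_classes (onehot is a hypothetical helper, not BetaML's OneHotEncoder):

```julia
# Hypothetical helper: one row per record, one column per class
function onehot(y::AbstractVector{<:Integer}, ncl::Integer = maximum(y))
    oh = falses(length(y), ncl)
    for (i, c) in enumerate(y)
        oh[i, c] = true          # set the column of the record's class
    end
    return oh
end

oh = onehot([2, 3, 1])                       # rows [0 1 0], [0 0 1], [1 0 0]
decoded = [argmax(r) for r in eachrow(oh)]   # back to [2, 3, 1]
```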

In supervised machine learning it is good practice to partition the available data into training, validation and test subsets, where the first one is used to train the ML algorithm, the second one to tune any "hyper-parameters" of the algorithm, and the test subset is finally used to evaluate the quality of the algorithm. Here, for brevity, we use only the train and the test subsets, implicitly assuming we already know the best hyper-parameters. Please refer to the regression tutorial for examples of the auto-tune feature of BetaML models to "automatically" tune the hyper-parameters (hint: in most cases just add the parameter autotune=true in the model constructor), or to the clustering tutorial for an example of using the cross_validation function to do it manually.

We then use the partition function in BetaML.Utils, where we can specify the different data to partition (each matrix or vector to partition must have the same number of observations) and the share of observations that we want in each subset. Here we keep 80% of the observations for training (xtrain and ytrain) and we use 20% of them for testing (xtest and ytest), with the one-hot labels split consistently (ytrain_oh and ytest_oh):

((xtrain,xtest),(ytrain,ytest),(ytrain_oh,ytest_oh)) = partition([x,y,y_oh],[0.8,1-0.8],rng=copy(AFIXEDRNG));
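Conceptually, a shuffled split boils down to permuting the record indices once and cutting the permutation at the desired share, then applying the same indices to x, y and y_oh so that features and labels stay aligned. A minimal sketch (split_indices is a hypothetical helper, and BetaML's partition may shuffle and round differently):

```julia
using Random

# Hypothetical helper: shuffled train/test index split
function split_indices(n::Integer, share::Real; rng = MersenneTwister(123))
    idx = randperm(rng, n)                  # shuffle all record indices once
    ntrain = round(Int, n * share)          # number of training records
    return idx[1:ntrain], idx[ntrain+1:end]
end

itrain, itest = split_indices(406, 0.8)     # 325 training and 81 test records
# the same index vectors would then be applied to every array, e.g. x[itrain, :] and y[itrain]
```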

We finally set up a dataframe to store the accuracies of the various models we'll use.

results = DataFrame(model=String[],train_acc=Float64[],test_acc=Float64[])
0×3 DataFrame
 Row | model   train_acc  test_acc
     | String  Float64    Float64
-----+------------------------------

Random Forests

We are now ready to use our first model, the RandomForestEstimator. Random forests build a "forest" of decision tree models and then average their predictions in order to make an overall prediction, whether a regression or a classification.
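One simple way to see how class probabilities arise from a forest: if each of the 30 trees cast a single vote, the probability of a class would be its share of the votes, and 26, 2 and 2 votes out of 30 give exactly the 0.867 / 0.067 / 0.067 pattern of ŷtrain[1] shown further on. This hard-voting sketch is an assumption made for illustration (a forest may instead average per-tree class probabilities), and vote_shares is a hypothetical helper:

```julia
# Hypothetical helper: turn a vector of per-tree class votes into class shares
function vote_shares(tree_preds::AbstractVector{T}) where {T}
    counts = Dict{T,Int}()
    for p in tree_preds
        counts[p] = get(counts, p, 0) + 1     # count the votes for each class
    end
    n = length(tree_preds)
    return Dict(k => v / n for (k, v) in counts)
end

votes = vcat(fill(3, 26), fill(2, 2), fill(1, 2))   # hypothetical votes of 30 trees
shares = vote_shares(votes)                         # approximately 3 => 0.867, 2 => 0.067, 1 => 0.067
```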

While here the missing data has been imputed and the dataset is comprised of only numerical values, one attractive feature of the BetaML RandomForestEstimator is that it can work directly with missing and categorical data without any prior processing.

However, as the labels are encoded using integers, we also need to specify the parameter force_classification=true, otherwise the model would perform a regression instead.

rfm      = RandomForestEstimator(force_classification=true, rng=copy(AFIXEDRNG))
RandomForestEstimator - A 30 trees Random Forest model (unfitted)

Contrary to the RandomForestImputer and OneHotEncoder models used earlier, to train a RandomForestEstimator model we need to provide it with both the training feature matrix and the associated "true" training labels. We use the same shortcut to get the training predictions directly from the fit! function. In this case the predictions are estimates of the labels:

ŷtrain   = fit!(rfm,xtrain,ytrain)
325-element Vector{Dict{Int64, Float64}}:
 Dict(2 => 0.06666666666666667, 3 => 0.8666666666666666, 1 => 0.06666666666666667)
 Dict(1 => 0.9999999999999999)
 Dict(2 => 0.9999999999999999)
 Dict(2 => 0.05, 3 => 0.16666666666666666, 1 => 0.7833333333333332)
 Dict(2 => 0.8666666666666666, 3 => 0.1, 1 => 0.03333333333333333)
 Dict(1 => 0.9999999999999999)
 Dict(2 => 0.2333333333333333, 3 => 0.6666666666666666, 1 => 0.1)
 Dict(1 => 0.9999999999999999)
 Dict(1 => 0.9999999999999999)
 Dict(2 => 0.3, 3 => 0.03333333333333333, 1 => 0.6666666666666666)
 ⋮
 Dict(1 => 0.9999999999999999)
 Dict(2 => 0.08333333333333334, 3 => 0.9166666666666665)
 Dict(3 => 0.05, 1 => 0.9499999999999998)
 Dict(1 => 0.9999999999999999)
 Dict(1 => 0.9999999999999999)
 Dict(2 => 0.8999999999999999, 3 => 0.08333333333333334, 1 => 0.016666666666666666)
 Dict(1 => 0.9999999999999999)
 Dict(3 => 0.9999999999999999)
 Dict(2 => 0.05, 3 => 0.9499999999999998)

You can notice that for each record the result is reported in terms of a dictionary with the possible categories and their associated probabilities.

Warning

Only categories with non-zero probabilities are reported for each record and, being a dictionary, the order of the categories is undefined.

For example, ŷtrain[1] is Dict(2 => 0.0667, 3 => 0.867, 1 => 0.0667), indicating an overwhelming probability that that car model originates from Japan. To retrieve the predictions with the highest probabilities use mode(ŷ):

ŷtrain_top = mode(ŷtrain,rng=copy(AFIXEDRNG))
325-element Vector{Int64}:
 3
 1
 2
 1
 2
 1
 3
 1
 1
 1
 ⋮
 1
 3
 1
 1
 1
 2
 1
 3
 3
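Picking the top class from such a dictionary means finding the highest probability and, when several classes share it, choosing among them. The sketch below breaks such ties at random through an explicit RNG so that the choice is reproducible; this tie-breaking rule is a design assumption of the sketch, and top_class is a hypothetical helper, not BetaML's mode:

```julia
using Random

# Hypothetical helper: most probable class, with reproducible random tie-breaking
function top_class(probs::AbstractDict; rng = MersenneTwister(123))
    best = maximum(values(probs))
    candidates = [k for (k, p) in probs if p == best]   # all classes sharing the top probability
    return rand(rng, candidates)
end

top_class(Dict(2 => 0.0667, 3 => 0.8667, 1 => 0.0667))   # 3
```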

Why does mode (optionally) take an RNG? I leave the answer to you :-)

To obtain the predicted labels for the test set we simply run the predict function over the features of the test set:

ŷtest   = predict(rfm,xtest)
81-element Vector{Dict{Int64, Float64}}:
 Dict(2 => 0.6, 3 => 0.13333333333333333, 1 => 0.26666666666666666)
 Dict(2 => 0.6333333333333333, 3 => 0.03333333333333333, 1 => 0.3333333333333333)
 Dict(2 => 0.6499999999999999, 3 => 0.1, 1 => 0.24999999999999997)
 Dict(1 => 0.9999999999999999)
 Dict(2 => 0.1, 3 => 0.3333333333333333, 1 => 0.5666666666666667)
 Dict(3 => 0.03333333333333333, 1 => 0.9666666666666666)
 Dict(2 => 0.2833333333333333, 3 => 0.7166666666666666)
 Dict(2 => 0.47222222222222215, 3 => 0.13333333333333333, 1 => 0.3944444444444444)
 Dict(2 => 0.03333333333333333, 3 => 0.03333333333333333, 1 => 0.9333333333333332)
 Dict(1 => 0.9999999999999999)
 ⋮
 Dict(2 => 0.06666666666666667, 1 => 0.9333333333333332)
 Dict(1 => 0.9999999999999999)
 Dict(2 => 0.21666666666666665, 3 => 0.7166666666666666, 1 => 0.06666666666666667)
 Dict(1 => 0.9999999999999999)
 Dict(1 => 0.9999999999999999)
 Dict(2 => 0.18333333333333332, 3 => 0.5333333333333333, 1 => 0.2833333333333333)
 Dict(2 => 0.05, 3 => 0.9333333333333332, 1 => 0.016666666666666666)
 Dict(2 => 0.21666666666666665, 3 => 0.5333333333333333, 1 => 0.24999999999999997)
 Dict(1 => 0.9999999999999999)

Finally we can measure the accuracy of our predictions with the accuracy function. We don't need to explicitly use mode, as accuracy does it itself when it is passed predictions expressed as dictionaries:

trainAccuracy,testAccuracy  = accuracy.([ytrain,ytest],[ŷtrain,ŷtest],rng=copy(AFIXEDRNG))
2-element Vector{Float64}:
 1.0
 0.7283950617283951
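Once the top class is picked, accuracy is simply the share of records whose predicted label equals the true one. As a check on the numbers above, a test accuracy of about 0.7284 corresponds to 59 correctly classified cars out of the 81 in the test set (simple_accuracy is a hypothetical helper, not BetaML's accuracy):

```julia
# Hypothetical helper: share of exactly matching labels
simple_accuracy(y, ŷ) = count(y .== ŷ) / length(y)

simple_accuracy([1, 2, 3, 1], [1, 2, 1, 1])   # 3 of 4 correct, i.e. 0.75
59 / 81                                       # ≈ 0.7284, the test accuracy above
```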

We are now ready to store our first model accuracies in the results dataframe:

push!(results,["RF",trainAccuracy,testAccuracy]);

The predictions are quite good: on the training set the algorithm predicted all cars' origins correctly, while on the testing set (i.e. those records that have not been used to train the algorithm) the correct prediction level is still fairly high, at around 73% in this run (the exact value depends on the random seed).

While accuracy can sometimes suffice, we may often want to better understand which categories our model has trouble predicting correctly. We can investigate the output of a multi-class classifier in more depth with a ConfusionMatrix, where the true values (y) are given in rows and the predicted ones (ŷ) in columns, together with some per-class metrics like the precision (the share of records predicted in class i that truly belong to class i), the recall (the share of records truly in class i that are predicted as class i) and others.

We first build the ConfusionMatrix model, we train it with ŷ and y and then we print it (we do it here for the test subset):

cfm = ConfusionMatrix(categories_names=Dict(1=>"US",2=>"EU",3=>"Japan"),rng=copy(AFIXEDRNG))
fit!(cfm,ytest,ŷtest) # the output is by default the confusion matrix in relative terms
print(cfm)
A ConfusionMatrix BetaMLModel (fitted)

-----------------------------------------------------------------

*** CONFUSION MATRIX ***

Scores actual (rows) vs predicted (columns):

4×4 Matrix{Any}:
 "Labels"   "EU"    "US"   "Japan"
 "EU"      8       5      3
 "US"      2      44      3
 "Japan"   2       7      7
Normalised scores actual (rows) vs predicted (columns):

4×4 Matrix{Any}:
 "Labels"   "EU"       "US"      "Japan"
 "EU"      0.5        0.3125    0.1875
 "US"      0.0408163  0.897959  0.0612245
 "Japan"   0.125      0.4375    0.4375

 *** CONFUSION REPORT ***

- Accuracy:               0.7283950617283951
- Misclassification rate: 0.2716049382716049
- Number of classes:      3

  N Class   precision   recall  specificity  f1score  actual_count  predicted_count
                          TPR       TNR                 support

  1 EU          0.667    0.500        0.938    0.571           16              12
  2 US          0.786    0.898        0.625    0.838           49              56
  3 Japan       0.538    0.438        0.908    0.483           16              13

- Simple   avg.    0.664    0.612        0.824    0.631
- Weigthed avg.    0.713    0.728        0.743    0.715

-----------------------------------------------------------------
Output of `info(cm)`:
- mean_precision:	(0.6636141636141636, 0.7133586578031021)
- fitted_records:	81
- specificity:	[0.9384615384615385, 0.625, 0.9076923076923077]
- precision:	[0.6666666666666666, 0.7857142857142857, 0.5384615384615384]
- misclassification:	0.2716049382716049
- mean_recall:	(0.6118197278911565, 0.7283950617283951)
- n_categories:	3
- normalised_scores:	[0.5 0.3125 0.1875; 0.04081632653061224 0.8979591836734694 0.061224489795918366; 0.125 0.4375 0.4375]
- tn:	[61, 20, 59]
- mean_f1score:	(0.6307608100711549, 0.7152303918587445)
- actual_count:	[16, 49, 16]
- accuracy:	0.7283950617283951
- recall:	[0.5, 0.8979591836734694, 0.4375]
- f1score:	[0.5714285714285714, 0.8380952380952381, 0.4827586206896552]
- mean_specificity:	(0.8237179487179488, 0.7427587844254511)
- predicted_count:	[12, 56, 13]
- scores:	[8 5 3; 2 44 3; 2 7 7]
- tp:	[8, 44, 7]
- fn:	[8, 5, 9]
- categories:	["EU", "US", "Japan"]
- fp:	[4, 12, 6]
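The per-class metrics in the report can be re-derived from the raw scores matrix printed above: precision divides each diagonal entry by its column sum (predicted_count), recall by its row sum (actual_count), and the F1 score is their harmonic mean. A short sketch using the standard LinearAlgebra library:

```julia
using LinearAlgebra

scores = [8 5 3; 2 44 3; 2 7 7]   # actual (rows) vs predicted (columns): EU, US, Japan, as printed above
tp   = diag(scores)                          # correctly classified records per class
prec = tp ./ vec(sum(scores, dims = 1))      # divide by predicted_count (column sums)
rec  = tp ./ vec(sum(scores, dims = 2))      # divide by actual_count (row sums)
f1   = 2 .* prec .* rec ./ (prec .+ rec)     # harmonic mean of precision and recall
acc  = sum(tp) / sum(scores)                 # overall accuracy, 59/81
```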

From the report we can see that Japanese cars (and, to a lesser extent, EU ones) have more trouble being correctly classified, and in particular many Japanese cars are classified as US ones (7 of the 16 in the test set). This is likely a result of the class imbalance of the dataset, and could be mitigated by balancing the dataset with various sampling techniques before training the model.

If you prefer a more graphical approach, we can also plot the confusion matrix. In order to do so, we pick up information from the info(cfm) function. Indeed most BetaML models can be queried with info(model) to retrieve additional information, in terms of a dictionary, that is not necessary for the prediction but could still be relevant. Other functions that you can use with BetaML models are parameters(m) and hyperparameters(m).

res = info(cfm)
heatmap(string.(res["categories"]),string.(res["categories"]),res["normalised_scores"],seriescolor=cgrad([:white,:blue]),xlabel="Predicted",ylabel="Actual", title="Confusion Matrix (normalised scores)")
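[Figure: heatmap of the normalised confusion matrix of the random forest on the test set, actual (rows) vs predicted (columns)]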

Comparison with DecisionTree.jl

We now compare BetaML's RandomForestEstimator with the random forest estimator of the DecisionTree.jl package. DecisionTree.jl random forests are similar in usage: we first "build" (train) the forest and we then make predictions out of the trained model.

# We train the model...
model = DecisionTree.build_forest(ytrain, xtrain,rng=seed)
# ...and we generate predictions and measure their accuracy
(ŷtrain,ŷtest) = DecisionTree.apply_forest.([model],[xtrain,xtest]);
(trainAccuracy,testAccuracy) = accuracy.([ytrain,ytest],[ŷtrain,ŷtest])
push!(results,["RF (DecisionTrees.jl)",trainAccuracy,testAccuracy]);

While the accuracy on the training set is exactly the same as for BetaML random forests, DecisionTree.jl random forests are slightly less accurate on the testing sample. Where DecisionTree.jl excels, however, is efficiency: its forests are extremely fast and memory-thrifty, even if we should also consider the resources needed to impute the missing values beforehand, as they don't work with missing data.

Also, one of the reasons DecisionTree.jl forests are so efficient is that internally the data is sorted to avoid repeated comparisons, but in this way they work only with features that are sortable, while BetaML random forests accept virtually any kind of input without the need to process it.
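To see why sorting helps, consider how a tree looks for the best split on one numeric feature: once the values are sorted, every candidate threshold can be scored in a single pass by moving one record at a time from the right-hand to the left-hand class counts. The sketch below does this with the Gini impurity; it illustrates the general technique and is not DecisionTree.jl's actual code (gini and best_threshold are hypothetical helpers):

```julia
# Gini impurity of a vector of class counts
function gini(counts)
    n = sum(counts)
    n == 0 && return 0.0
    return 1.0 - sum((c / n)^2 for c in counts)
end

# Best threshold on one sortable feature x with integer labels 1..ncl
function best_threshold(x::AbstractVector{<:Real}, y::AbstractVector{<:Integer}, ncl::Integer)
    order = sortperm(x)                  # the only O(n log n) step
    xs, ys = x[order], y[order]
    left, right = zeros(Int, ncl), zeros(Int, ncl)
    for c in ys
        right[c] += 1                    # initially all records are on the right
    end
    n = length(ys)
    best_imp, best_thr = Inf, NaN
    for i in 1:n-1                       # single pass over the sorted records
        left[ys[i]] += 1
        right[ys[i]] -= 1
        xs[i] == xs[i+1] && continue     # a split must fall between distinct values
        imp = (i * gini(left) + (n - i) * gini(right)) / n   # weighted impurity of the two sides
        if imp < best_imp
            best_imp, best_thr = imp, (xs[i] + xs[i+1]) / 2
        end
    end
    return best_thr, best_imp
end

feat = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
lab  = [1, 1, 1, 2, 2, 2]
thr, imp = best_threshold(feat, lab, 2)   # perfect split at 6.5
```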

Neural network

Neural networks (NN) can be very powerful, but they have two "inconveniences" compared with random forests. First, they are a bit "picky": we need to do a bit of work to provide the data in a specific format. Note that this is not feature engineering; one of the advantages of neural networks is that, for the most part, feature engineering is not needed. However, we still need to "clean" the data. One issue is that NN don't like missing data, so we need to provide them with a feature matrix "clean" of missing data. Secondly, they work only with numerical data, so we need to use the one-hot encoding we saw earlier. Further, they work best if the features are scaled such that each feature has mean zero and standard deviation 1. This is why we scaled the data back at the beginning of this tutorial.

We first measure the dimensions of our data in input (i.e. the columns of the feature matrix) and the dimensions of our output, i.e. the number of categories or columns in our one-hot encoded y.

D               = size(xtrain,2)
classes         = unique(y)
nCl             = length(classes)
3

The second "inconvenience" of NN is that, while not requiring feature engineering, they still need a bit of practice in the way the structure of the network is built. It's not as simple as fit!(Model(),x,y) (although BetaML provides a "default" neural network structure that can be used, it often isn't adapted to the specific task). We need instead to specify how we want our layers, chain the layers together and then decide an overall loss function. Only when we have done these steps do we have the model ready for training. Here we define 2 DenseLayer where, for each of them, we specify the number of neurons in input (for the first layer, equal to the dimensions of the data), the number of neurons in output (for a classification task, the last layer's output size is equal to the number of classes) and an activation function for each layer (by default the identity function).

ls   = 50 # number of neurons in the inner (hidden) layer
l1   = DenseLayer(D,ls,f=relu,rng=copy(AFIXEDRNG))
l2   = DenseLayer(ls,nCl,f=relu,rng=copy(AFIXEDRNG))
DenseLayer{typeof(relu), typeof(drelu), Float64}([-0.2146463925907584 -0.3077087587320811 … -0.28256208289474877 0.21510681158042189; -0.08916953797649538 -0.041727530915651345 … -0.30444064706465346 -0.22349634154766507; 0.11376391271810127 -0.011244515923068743 … 0.12916068649773038 -0.2518581440082599], [0.2918467648814228, -0.004167534280141383, 0.29060333096888613], BetaML.Utils.relu, BetaML.Utils.drelu)
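What a dense layer computes is out = f.(W * x .+ b), with W of size (n_out, n_in), so chaining two of them maps the 7 features to 50 hidden neurons and then to 3 outputs. A minimal sketch of this forward pass; ToyDense, toy_relu and forward are hypothetical names, and the random initialisation scaled by 1/sqrt(n_in) is an assumption, not BetaML's initialisation scheme:

```julia
using Random

toy_relu(z) = max(z, zero(z))

# Hypothetical fully connected layer: output = f.(W * x .+ b)
struct ToyDense
    W::Matrix{Float64}   # weights, size (n_out, n_in)
    b::Vector{Float64}   # biases, length n_out
    f::Function          # element-wise activation function
end

# Random weights scaled by 1/sqrt(n_in) (an assumption), zero biases
ToyDense(nin::Int, nout::Int, f::Function; rng = MersenneTwister(123)) =
    ToyDense(randn(rng, nout, nin) ./ sqrt(nin), zeros(nout), f)

forward(l::ToyDense, x::AbstractVector) = l.f.(l.W * x .+ l.b)

n_in, n_hidden, n_out = 7, 50, 3                     # same sizes as in the tutorial
h1  = ToyDense(n_in, n_hidden, toy_relu)
h2  = ToyDense(n_hidden, n_out, toy_relu; rng = MersenneTwister(456))
out = forward(h2, forward(h1, randn(MersenneTwister(1), n_in)))   # a 3-element vector
```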

For a classification task, the last layer is a VectorFunctionLayer that has no learnable parameters but whose activation function is applied to all the neurons jointly, rather than individually to each neuron. In particular, for classification we pass the softmax function, whose output has the same size as the input (i.e. the number of classes to predict), but we can use the VectorFunctionLayer with any function, including the pool1d function to create a "pooling" layer (using maximum, mean or whatever other sub-function we pass to pool1d).

l3   = VectorFunctionLayer(nCl,f=softmax) ## Add a (parameterless) layer whose activation function (softmax in this case) is applied to all its nodes at once
VectorFunctionLayer{0, typeof(softmax), typeof(dsoftmax), Nothing, Float64}(fill(NaN), 3, 3, BetaML.Utils.softmax, BetaML.Utils.dsoftmax, nothing)
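The softmax is a good example of a function that acts on the whole vector rather than element by element: each output depends on all inputs, and the outputs are positive and sum to 1, so they can be read as class probabilities. A minimal sketch (toy_softmax is a hypothetical name; subtracting the maximum is a standard numerical-stability trick that does not change the result):

```julia
function toy_softmax(z::AbstractVector{<:Real})
    e = exp.(z .- maximum(z))   # subtract the max to avoid overflow
    return e ./ sum(e)          # normalise so that the outputs sum to 1
end

p = toy_softmax([1.0, 2.0, 3.0])   # largest input gets the largest probability
```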

Finally we chain the layers and assign a loss function and the number of epochs we want to train the model to the constructor of NeuralNetworkEstimator:

nn = NeuralNetworkEstimator(layers=[l1,l2,l3],loss=crossentropy,rng=copy(AFIXEDRNG),epochs=500)
NeuralNetworkEstimator - A Feed-forward neural network (unfitted)
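The cross-entropy loss compares the predicted probability vector with the one-hot encoded true class; for a single record it reduces to minus the log of the probability assigned to the true class, so confident wrong predictions are penalised heavily. A hypothetical sketch (toy_crossentropy; the small ϵ guarding against log(0) is an assumption of this sketch):

```julia
function toy_crossentropy(ŷ::AbstractVector{<:Real}, y::AbstractVector{<:Real}; ϵ = 1e-15)
    return -sum(y .* log.(ŷ .+ ϵ))   # only the true class term survives with one-hot y
end

toy_crossentropy([0.7, 0.2, 0.1], [1, 0, 0])   # ≈ -log(0.7)
toy_crossentropy([0.1, 0.2, 0.7], [1, 0, 0])   # ≈ -log(0.1), a much larger loss
```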

Aside from the layer structure and size and the number of epochs, other hyper-parameters you may want to try are the batch_size and the optimisation algorithm to employ (opt_alg).

Now we can train our network:

ŷtrain = fit!(nn, xtrain, ytrain_oh)
325×3 Matrix{Float64}:
 0.0405745  0.0392448    0.920181
 0.999942   2.87644e-5   2.87644e-5
 0.0107026  0.970008     0.0192891
 0.631465   0.268689     0.0998457
 0.0022812  0.947058     0.0506612
 1.0        1.47862e-9   1.47862e-9
 0.223431   0.162307     0.614262
 1.0        2.33936e-13  2.33936e-13
 0.999999   2.91209e-7   2.91209e-7
 0.519705   0.421513     0.0587829
 ⋮                       
 1.0        2.77161e-11  2.77161e-11
 0.0969489  0.359861     0.54319
 1.0        1.83722e-7   1.83722e-7
 0.998329   0.000555534  0.00111565
 1.0        1.09483e-10  1.09483e-10
 0.69673    0.289012     0.0142584
 1.0        4.9083e-14   4.9083e-14
 0.0409779  0.158884     0.800138
 0.0347578  0.0128461    0.952396

Predictions are in the form of an n_records by n_classes matrix of the probabilities of each record being in each class. To retrieve the classes with the highest probabilities we can again use the mode function:

ŷtrain_top = mode(ŷtrain)
325-element Vector{Int64}:
 3
 1
 2
 1
 2
 1
 3
 1
 1
 1
 ⋮
 1
 3
 1
 1
 1
 1
 1
 3
 3
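With a probability matrix, picking the top class is a row-wise argmax. Applied to the first three rows of the training output shown above, it returns the same classes as the first three entries of ŷtrain_top. Note that Julia's argmax returns the first maximum in case of ties, which is a simplification compared with a random tie-break:

```julia
# first three rows of the NN training predictions shown above
P = [0.0405745 0.0392448  0.920181;
     0.999942  2.87644e-5 2.87644e-5;
     0.0107026 0.970008   0.0192891]
top = [argmax(row) for row in eachrow(P)]   # column (class) with the highest probability per record
```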

Once trained, we can predict the test labels. As the training was based on the scaled feature matrix, so must be the predictions (this is already the case here, as xtest comes from the partition of the scaled x):

ŷtest  = predict(nn,xtest)
81×3 Matrix{Float64}:
 0.1073       0.882895     0.00980417
 4.4529e-5    0.999951     4.21883e-6
 0.0671191    0.0866024    0.846279
 1.0          7.19417e-11  7.19417e-11
 0.163052     0.0179888    0.818959
 1.0          1.18774e-8   1.18774e-8
 0.0716439    0.0727605    0.855596
 0.135011     0.568851     0.296137
 0.902347     0.0410031    0.0566501
 1.0          1.45955e-10  1.45955e-10
 ⋮                         
 0.966316     0.0214983    0.0121853
 1.0          3.28963e-12  3.28963e-12
 0.0424253    0.0869417    0.870633
 1.0          2.16082e-10  2.16082e-10
 1.0          7.30083e-10  7.30083e-10
 0.443724     0.0394151    0.516861
 0.019245     0.0684049    0.91235
 0.000209648  0.998533     0.00125736
 1.0          1.29266e-7   1.29266e-7

And finally we can measure the accuracies, store them in the results dataframe and inspect the confusion matrix on the test set:

trainAccuracy, testAccuracy   = accuracy.([ytrain,ytest],[ŷtrain,ŷtest],rng=copy(AFIXEDRNG))
push!(results,["NN",trainAccuracy,testAccuracy]);
cfm = ConfusionMatrix(categories_names=Dict(1=>"US",2=>"EU",3=>"Japan"),rng=copy(AFIXEDRNG))
fit!(cfm,ytest,ŷtest)
print(cfm)
res = info(cfm)
heatmap(string.(res["categories"]),string.(res["categories"]),res["normalised_scores"],seriescolor=cgrad([:white,:blue]),xlabel="Predicted",ylabel="Actual", title="Confusion Matrix (normalised scores)")
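[Figure: heatmap of the normalised confusion matrix of the neural network on the test set, actual (rows) vs predicted (columns)]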

While accuracies are a bit lower, the distribution of misclassifications is similar, with many Japanese cars misclassified as US ones (here we also have some EU cars misclassified as Japanese ones).

Comparisons with Flux

As we did for Random Forests, we compare BetaML neural networks with the leading package for deep learning in Julia, Flux.jl.

In Flux the input must be in the form (fields, observations), so we transpose our original matrices:

xtrainT, ytrain_ohT = transpose.([xtrain, ytrain_oh])
xtestT, ytest_ohT   = transpose.([xtest, ytest_oh])
2-element Vector{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}:
 [-0.9370258544446618 0.8308089913068687 … 1.6506744270177232 -0.5783347263211628; 0.30679255888470214 -0.8627640505724731 … -0.27798574584388547 0.30679255888470214; … ; 0.10010898192256414 2.2430393858184523 … 1.564444757918087 -0.007037538272230526; 0.5552223059557987 1.0893935292213297 … 1.0893935292213297 -1.31437697547356]
 [0.0 0.0 … 0.0 1.0; 1.0 1.0 … 1.0 0.0; 0.0 0.0 … 0.0 0.0]
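Note that transpose does not copy the data: it returns a lazy Transpose wrapper (as the output type above shows), whose columns are the rows of the original matrix. permutedims, by contrast, materialises a new matrix. A small sketch on a made-up matrix:

```julia
using LinearAlgebra

X  = reshape(collect(1.0:15.0), 5, 3)   # 5 records × 3 features (BetaML layout)
Xt = transpose(X)                       # 3 × 5 lazy wrapper (Flux layout), no data copied
Xp = permutedims(X)                     # 3 × 5 materialised copy
```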

We define the Flux neural network model in a similar way to BetaML, load it with data, train it, predict and measure the accuracies on the training and the test sets.

We fix the random seed for Flux, although you may still get different results depending on the number of threads used. This is a problem we solve in BetaML with generate_parallel_rngs.

Random.seed!(seed)

l1         = Flux.Dense(D,ls,Flux.relu)
l2         = Flux.Dense(ls,nCl,Flux.relu)
Flux_nn    = Flux.Chain(l1,l2)
fluxloss(x, y) = Flux.logitcrossentropy(Flux_nn(x), y)
ps         = Flux.params(Flux_nn)
nndata     = Flux.Data.DataLoader((xtrainT, ytrain_ohT),shuffle=true)
for epoch in 1:500
    Flux.train!(fluxloss, ps, nndata, Flux.ADAM())
end
ŷtrain     = Flux.onecold(Flux_nn(xtrainT),1:3)
ŷtest      = Flux.onecold(Flux_nn(xtestT),1:3)
trainAccuracy, testAccuracy   = accuracy.([ytrain,ytest],[ŷtrain,ŷtest])
2-element Vector{Float64}:
 0.9784615384615385
 0.7654320987654321
push!(results,["NN (Flux.jl)",trainAccuracy,testAccuracy]);

While the train accuracy is a little higher than BetaML's, the test accuracy remains comparable.

Perceptron-like classifiers

We finally test 3 "perceptron-like" classifiers: the "classical" Perceptron (PerceptronClassifier), one of the first ML algorithms (a linear classifier); a "kernelised" version of it (KernelPerceptronClassifier, using the radial kernel by default); and the Pegasos classifier (PegasosClassifier), another linear algorithm that introduces a gradient-based optimisation, although without the regularisation term as in Support Vector Machines (SVM).
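The classical perceptron rule for two classes fits in a few lines: whenever a record is misclassified, the weights are nudged towards it. A minimal binary sketch with labels in {-1, +1} on made-up, linearly separable data; toy_perceptron is hypothetical, and while the repeated "Training perceptron" logs further on suggest that BetaML builds its multi-class classifier from several binary runs, how it combines them is not shown here:

```julia
# Hypothetical binary perceptron: y must contain -1 / +1 labels
function toy_perceptron(X::AbstractMatrix, y::AbstractVector; epochs::Int = 100)
    θ  = zeros(size(X, 2))   # weights
    θ0 = 0.0                 # bias
    for _ in 1:epochs, i in axes(X, 1)
        xi = view(X, i, :)
        if y[i] * (θ' * xi + θ0) <= 0      # misclassified, or exactly on the boundary
            θ  .+= y[i] .* xi               # move the boundary towards the record
            θ0  += y[i]
        end
    end
    return θ, θ0
end

Xtoy = [2.0 1.0; 1.0 3.0; -1.0 -2.0; -2.0 -1.0]   # hypothetical, linearly separable data
ytoy = [1, 1, -1, -1]
θ, θ0 = toy_perceptron(Xtoy, ytoy)
pred = sign.(Xtoy * θ .+ θ0)                      # recovers ytoy
```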

As for the previous classifiers we construct the model object, we train and predict and we compute the train and test accuracies:

pm = PerceptronClassifier(rng=copy(AFIXEDRNG))
ŷtrain = fit!(pm, xtrain, ytrain)
ŷtest  = predict(pm, xtest)
(trainAccuracy,testAccuracy) = accuracy.([ytrain,ytest],[ŷtrain,ŷtest])
push!(results,["Perceptron",trainAccuracy,testAccuracy]);

kpm = KernelPerceptronClassifier(rng=copy(AFIXEDRNG))
ŷtrain = fit!(kpm, xtrain, ytrain)
ŷtest  = predict(kpm, xtest)
(trainAccuracy,testAccuracy) = accuracy.([ytrain,ytest],[ŷtrain,ŷtest])
push!(results,["KernelPerceptronClassifier",trainAccuracy,testAccuracy]);


pegm = PegasosClassifier(rng=copy(AFIXEDRNG))
ŷtrain = fit!(pegm, xtrain, ytrain)
ŷtest  = predict(pegm, xtest)
(trainAccuracy,testAccuracy) = accuracy.([ytrain,ytest],[ŷtrain,ŷtest])
push!(results,["Pegasaus",trainAccuracy,testAccuracy]);
***
*** Training perceptron for maximum 1000 iterations. Random shuffle: true
Avg. error after iteration 1 : 0.24307692307692308
Avg. error after iteration 100 : 0.2123076923076923
Avg. error after iteration 200 : 0.18461538461538463
Avg. error after iteration 300 : 0.19076923076923077
Avg. error after iteration 400 : 0.19692307692307692
Avg. error after iteration 500 : 0.18769230769230769
Avg. error after iteration 600 : 0.19692307692307692
Avg. error after iteration 700 : 0.17846153846153845
Avg. error after iteration 800 : 0.18153846153846154
Avg. error after iteration 900 : 0.2123076923076923
Avg. error after iteration 1000 : 0.2246153846153846
***
*** Training perceptron for maximum 1000 iterations. Random shuffle: true
Avg. error after iteration 1 : 0.27076923076923076
Avg. error after iteration 100 : 0.11076923076923077
Avg. error after iteration 200 : 0.16307692307692306
Avg. error after iteration 300 : 0.15384615384615385
Avg. error after iteration 400 : 0.16
Avg. error after iteration 500 : 0.13230769230769232
Avg. error after iteration 600 : 0.14461538461538462
Avg. error after iteration 700 : 0.13846153846153847
Avg. error after iteration 800 : 0.13846153846153847
Avg. error after iteration 900 : 0.13538461538461538
Avg. error after iteration 1000 : 0.1753846153846154
***
*** Training perceptron for maximum 1000 iterations. Random shuffle: true
Avg. error after iteration 1 : 0.21846153846153846
Avg. error after iteration 100 : 0.19076923076923077
Avg. error after iteration 200 : 0.1723076923076923
Avg. error after iteration 300 : 0.17846153846153845
Avg. error after iteration 400 : 0.1723076923076923
Avg. error after iteration 500 : 0.19384615384615383
Avg. error after iteration 600 : 0.15076923076923077
Avg. error after iteration 700 : 0.14461538461538462
Avg. error after iteration 800 : 0.1753846153846154
Avg. error after iteration 900 : 0.1723076923076923
Avg. error after iteration 1000 : 0.18461538461538463
***
*** Training kernel perceptron for maximum 100 iterations. Random shuffle: true
Avg. error after iteration 1 : 0.15671641791044777
Avg. error after iteration 10 : 0.055970149253731345
Avg. error after iteration 20 : 0.05970149253731343
Avg. error after iteration 30 : 0.03731343283582089
Avg. error after iteration 40 : 0.05970149253731343
Avg. error after iteration 50 : 0.041044776119402986
Avg. error after iteration 60 : 0.022388059701492536
Avg. error after iteration 70 : 0.033582089552238806
Avg. error after iteration 80 : 0.026119402985074626
Avg. error after iteration 90 : 0.033582089552238806
Avg. error after iteration 100 : 0.026119402985074626
***
*** Training kernel perceptron for maximum 100 iterations. Random shuffle: true
Avg. error after iteration 1 : 0.4166666666666667
Avg. error after iteration 10 : 0.13333333333333333
Avg. error after iteration 20 : 0.1
Avg. error after iteration 30 : 0.09166666666666666
Avg. error after iteration 40 : 0.08333333333333333
*** Avg. error after epoch 49 : 0.0 (all elements of the set has been correctly classified)
***
*** Training kernel perceptron for maximum 100 iterations. Random shuffle: true
Avg. error after iteration 1 : 0.16793893129770993
Avg. error after iteration 10 : 0.06870229007633588
Avg. error after iteration 20 : 0.04198473282442748
Avg. error after iteration 30 : 0.03816793893129771
*** Avg. error after epoch 40 : 0.0 (all elements of the set has been correctly classified)
***
*** Training pegasos for maximum 1000 iterations. Random shuffle: true
Avg. error after iteration 1 : 0.27076923076923076
Avg. error after iteration 100 : 0.2246153846153846
Avg. error after iteration 200 : 0.23692307692307693
Avg. error after iteration 300 : 0.23076923076923078
Avg. error after iteration 400 : 0.27384615384615385
Avg. error after iteration 500 : 0.23076923076923078
Avg. error after iteration 600 : 0.21846153846153846
Avg. error after iteration 700 : 0.21846153846153846
Avg. error after iteration 800 : 0.21846153846153846
Avg. error after iteration 900 : 0.24615384615384617
Avg. error after iteration 1000 : 0.24615384615384617
***
*** Training pegasos for maximum 1000 iterations. Random shuffle: true
Avg. error after iteration 1 : 0.3292307692307692
Avg. error after iteration 100 : 0.24615384615384617
Avg. error after iteration 200 : 0.27692307692307694
Avg. error after iteration 300 : 0.26461538461538464
Avg. error after iteration 400 : 0.23692307692307693
Avg. error after iteration 500 : 0.2676923076923077
Avg. error after iteration 600 : 0.24615384615384617
Avg. error after iteration 700 : 0.2523076923076923
Avg. error after iteration 800 : 0.25846153846153846
Avg. error after iteration 900 : 0.2523076923076923
Avg. error after iteration 1000 : 0.24923076923076923
***
*** Training pegasos for maximum 1000 iterations. Random shuffle: true
Avg. error after iteration 1 : 0.27076923076923076
Avg. error after iteration 100 : 0.23076923076923078
Avg. error after iteration 200 : 0.24615384615384617
Avg. error after iteration 300 : 0.2553846153846154
Avg. error after iteration 400 : 0.21846153846153846
Avg. error after iteration 500 : 0.26153846153846155
Avg. error after iteration 600 : 0.2523076923076923
Avg. error after iteration 700 : 0.2276923076923077
Avg. error after iteration 800 : 0.24307692307692308
Avg. error after iteration 900 : 0.2553846153846154
Avg. error after iteration 1000 : 0.24923076923076923
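Each classifier above printed three separate training logs, one per binary subproblem. The perceptron and Pegasos average errors are all multiples of 1/325 (the training-set size implied by the accuracies, e.g. 318/325 ≈ 0.978), which is consistent with a one-vs-rest scheme where every binary run sees all records; the three kernel perceptron runs instead use 268, 120 and 262 records (summing to twice 325), which is consistent with one-vs-one pairs of classes. This is an inference from the logs, not a documented statement about BetaML. A generic one-vs-rest wrapper can be sketched as follows; the least-squares binary scorer `fit_binary_scorer` is a hypothetical stand-in for the perceptron-like learners, chosen only to keep the example short and deterministic.

```julia
using LinearAlgebra

# Hypothetical binary learner: least-squares fit of ±1 targets, returning a scoring function.
function fit_binary_scorer(X, ybin)
    A = hcat(X, ones(size(X, 1)))   # append a bias column
    θ = A \ ybin                     # least-squares solution
    return x -> dot(θ, vcat(x, 1.0))
end

# One-vs-rest: train one binary scorer per class (class c vs all the others);
# the predicted class is the one whose scorer gives the highest score.
function ovr_fit(X, y; fitbin=fit_binary_scorer)
    classes = sort(unique(y))
    scorers = [fitbin(X, [yi == c ? 1.0 : -1.0 for yi in y]) for c in classes]
    return classes, scorers
end
function ovr_predict(model, X)
    classes, scorers = model
    return [classes[argmax([s(x) for s in scorers])] for x in eachrow(X)]
end

X = [0.0 0.0; 0.2 0.1; 5.0 0.0; 5.1 0.3; 0.0 5.0; 0.2 5.2]
y = [1, 1, 2, 2, 3, 3]
model = ovr_fit(X, y)
println(ovr_predict(model, X))   # [1, 1, 2, 2, 3, 3]
```

One-vs-one works similarly but trains one binary model per pair of classes on the records of those two classes only, and predicts by majority vote among the pairwise models.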

Summary

This is the summary of the results we obtained trying to predict the country of origin of the cars based on their technical characteristics:

println(results)
7×3 DataFrame
 Row │ model                       train_acc  test_acc
     │ String                      Float64    Float64
─────┼─────────────────────────────────────────────────
   1 │ RF                           1.0       0.728395
   2 │ RF (DecisionTrees.jl)        0.981538  0.716049
   3 │ NN                           0.935385  0.728395
   4 │ NN (Flux.jl)                 0.978462  0.765432
   5 │ Perceptron                   0.735385  0.691358
   6 │ KernelPerceptronClassifier   0.978462  0.703704
   7 │ Pegasaus                     0.670769  0.691358

If you clone the BetaML repository, you can compare your results with these model accuracies, obtained on my machine with seeds 123, 1000 and 10000 respectively:

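| model | train (seed 123) | test (seed 123) | train (seed 1000) | test (seed 1000) | train (seed 10000) | test (seed 10000) |
|---|---|---|---|---|---|---|
| RF | 0.996923 | 0.765432 | 1.000000 | 0.802469 | 1.000000 | 0.888889 |
| RF (DecisionTrees.jl) | 0.975385 | 0.765432 | 0.984615 | 0.777778 | 0.975385 | 0.864198 |
| NN | 0.886154 | 0.728395 | 0.916923 | 0.827160 | 0.895385 | 0.876543 |
| NN (Flux.jl) | 0.793846 | 0.654321 | 0.938462 | 0.790123 | 0.935385 | 0.851852 |
| Perceptron | 0.778462 | 0.703704 | 0.720000 | 0.753086 | 0.670769 | 0.654321 |
| KernelPerceptronClassifier | 0.987692 | 0.703704 | 0.978462 | 0.777778 | 0.944615 | 0.827160 |
| Pegasaus | 0.732308 | 0.703704 | 0.633846 | 0.753086 | 0.575385 | 0.654321 |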

We warn that this table just provides a rough idea of the performances of the various algorithms. Indeed there is a large amount of stochasticity both in the sampling of the data used for training/testing and in the initial settings of the algorithms' parameters. For a statistically significant comparison we would have to repeat the analysis with multiple samplings (e.g. by cross-validation, see the clustering tutorial for an example) and multiple initial random parameters.
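A generic k-fold cross-validation loop can be sketched in plain Julia as follows: the record indices are shuffled once with a seeded RNG and split into k folds, and each fold is used once as the test set while the model is trained on the remaining ones. `kfold_indices`, `cross_validate` and `majority_score` are hypothetical helpers written for this sketch, independent of the approach used in the clustering tutorial, and the label vector `ylab` is made-up data.

```julia
using Random, Statistics

# Split the indices 1:n into k roughly equal folds after a seeded shuffle.
function kfold_indices(n::Integer, k::Integer; rng=MersenneTwister(123))
    idx = shuffle(rng, collect(1:n))
    return [idx[i:k:n] for i in 1:k]   # fold i takes every k-th shuffled index
end

# `fit_and_score(train_idx, test_idx)` is a user-supplied function that trains a
# model on train_idx and returns its accuracy on test_idx.
function cross_validate(fit_and_score, n, k; rng=MersenneTwister(123))
    folds  = kfold_indices(n, k; rng=rng)
    scores = map(1:k) do i
        train_idx = reduce(vcat, folds[setdiff(1:k, i)])
        fit_and_score(train_idx, folds[i])
    end
    return mean(scores), std(scores)
end

# Usage: a majority-class baseline on a made-up, unbalanced label vector.
ylab = vcat(fill(1, 60), fill(2, 20), fill(3, 20))
function majority_score(train_idx, test_idx)
    counts = [count(==(c), ylab[train_idx]) for c in 1:3]
    guess  = argmax(counts)             # most frequent class in the training folds
    return mean(ylab[test_idx] .== guess)
end
μ, σ = cross_validate(majority_score, length(ylab), 5)
println("accuracy: ", μ, " ± ", σ)
```

The spread σ across folds gives a first idea of how much a single train/test split can move the accuracy. A majority-class baseline is also a useful reference for unbalanced data such as this dataset: with 254 of 406 cars made in the US, always predicting the US would already be right about 63% of the time.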

Nevertheless, the table above shows that, when we compare BetaML with the leading algorithm-specific packages, we find similar results in terms of accuracy, although the leading packages are often better optimised and run more efficiently (sometimes at the cost of being less versatile). Also, for this dataset, Random Forests seem to remain marginally more accurate than Neural Networks, although of course this depends on the hyper-parameters and, with a single run of the models, we don't know whether this difference is significant.
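To see why a single run cannot settle the Random Forest versus Neural Network question, a quick computation on the three-seed test accuracies of BetaML's RF and NN rows in the table above compares the gap between the two means with the spread across seeds. Three values are far too few for a formal test; this only illustrates the point.

```julia
using Statistics

# Test accuracies copied from the table above (seeds 123, 1000 and 10000).
rf_test = [0.765432, 0.802469, 0.888889]   # RF (BetaML)
nn_test = [0.728395, 0.827160, 0.876543]   # NN (BetaML)

Δ = rf_test .- nn_test                      # per-seed difference RF − NN
println("mean RF = ", round(mean(rf_test); digits=3), ", mean NN = ", round(mean(nn_test); digits=3))
println("per-seed differences: ", round.(Δ; digits=3))
println("std of RF across seeds: ", round(std(rf_test); digits=3))
```

The gap between the means (below 0.01) is much smaller than the seed-to-seed standard deviation of either model (about 0.06), and the sign of the difference changes from one seed to another.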



This page was generated using Literate.jl.