Videos related to this segment (click the title to watch)
04 NN - 2A: Binary classification (15:54)
04 NN - 2B: Multinomial classification (15:01)
04 NN - 2C: Regression (6:03)
04 NN - 2D: Convolutional neural networks (13:19)

0402 - Implementation of neural network workflows

Some code to set up the environment...

julia> cd(@__DIR__)
julia> using Pkg
julia> Pkg.activate(".")
  Activating project at `~/work/SPMLJ/SPMLJ/buildedDoc/04_-_NN_-_Neural_Networks`
# If using a Julia version different than 1.10 please uncomment and run the following line (the guarantee of reproducibility will however be lost)
# Pkg.resolve()
julia> Pkg.instantiate()
julia> using Random, Plots
julia> Random.seed!(123)
Random.TaskLocalRNG()
julia> ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
"true"

We will not run cross-validation here to find the optimal hyper-parameters; the process would be no different from the one we saw in the lesson on the Perceptron. Instead, we focus on creating neural network models, training them on data and evaluating their predictions. For feed-forward neural networks (both for classification and for regression) we will use BetaML, while for the Convolutional Neural Network example we will use the Flux.jl package.
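For reference, a minimal sketch of what such a hyper-parameter search could look like, assuming BetaML's `cross_validation` utility with its `KFold` sampler and some feature matrix `X` and label vector `y` (the hyper-parameter under evaluation, `epochs`, is just a placeholder here):

julia> # A minimal sketch, assuming BetaML's cross_validation/KFold API:
       cv_out = cross_validation([X,y], KFold(nsplits=5)) do trainData, valData, rng
           (xtr, ytr) = trainData                                   # training fold
           (xvl, yvl) = valData                                     # validation fold
           m = NeuralNetworkEstimator(epochs=100, verbosity=NONE)   # candidate hyper-parameters
           fit!(m, xtr, ytr)
           ŷvl = predict(m, xvl) |> makecolvector .|> round .|> Int
           return accuracy(yvl, ŷvl)                                # score aggregated across folds
       end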

Feed-forward neural networks

Binary classification

Data loading...

julia> using BetaML, DelimitedFiles
julia> data = readdlm(joinpath(dirname(pathof(BetaML)),"..","test","data","binary2DData.csv"),'\t')
200×3 Matrix{Float64}:
 -1.0   1.76    0.4
 -1.0   0.979   2.24
 -1.0   1.87   -0.977
 -1.0   0.95   -0.151
 -1.0  -0.103   0.411
 -1.0   0.144   1.45
 -1.0   0.761   0.122
 -1.0   0.444   0.334
 -1.0   1.49   -0.205
 -1.0   0.313  -0.854
  ⋮
  1.0  -0.256   0.977
  1.0   2.04    0.343
  1.0   1.01    0.528
  1.0   3.65    2.16
  1.0   2.57    1.78
  1.0   1.65    0.384
  1.0   1.71    1.24
  1.0   2.86    3.14
  1.0   3.47    2.85
julia> nR = size(data,1)
200
julia> idx = shuffle(1:nR)
200-element Vector{Int64}:
 123
 131
  74
  23
  19
  78
  43
 130
  83
 186
   ⋮
 137
 175
 182
  37
  71
  89
 142
  82
 170
julia> data = data[idx,:]
200×3 Matrix{Float64}:
  1.0   1.69    0.324
  1.0   0.811   1.49
 -1.0  -0.913   1.12
 -1.0  -0.5097 -0.4381
 -1.0   1.23    1.2
 -1.0  -0.0985 -0.6635
 -1.0   1.49    1.9
  1.0   0.417   2.61
 -1.0  -1.23    0.844
  1.0   2.28    1.01
  ⋮
  1.0   3.96    2.39
  1.0   2.58    2.35
  1.0   2.93    2.34
 -1.0   1.14   -1.23
 -1.0  -1.49    0.439
 -1.0  -0.8034 -0.6895
  1.0   1.31    3.54
 -1.0   0.949   0.0876
  1.0   1.32    3.66
julia> X = copy(data[:,[2,3]])
200×2 Matrix{Float64}:
  1.69    0.324
  0.811   1.49
 -0.913   1.12
 -0.5097 -0.4381
  1.23    1.2
 -0.0985 -0.6635
  1.49    1.9
  0.417   2.61
 -1.23    0.844
  2.28    1.01
  ⋮
  3.96    2.39
  2.58    2.35
  2.93    2.34
  1.14   -1.23
 -1.49    0.439
 -0.8034 -0.6895
  1.31    3.54
  0.949   0.0876
  1.32    3.66
julia> y = max.(0,convert(Array{Int64,1},copy(data[:,1]))) # Converting labels from {-1,1} to {0,1}
200-element Vector{Int64}:
 1
 1
 0
 0
 0
 0
 0
 1
 0
 1
 ⋮
 1
 1
 1
 0
 0
 0
 1
 0
 1
julia> ((xtrain,xtest),(ytrain,ytest)) = partition([X,y],[0.7,0.3])
2-element Vector{Vector}:
 AbstractMatrix{Float64}[[1.23 1.2; 3.49 -0.07; … ; 2.69 2.69; 0.979 2.24], [-0.8708 -0.5788; 0.127 0.402; … ; -0.312 0.0562; -0.674 0.0318]]
 AbstractVector{Int64}[[0, 1, 0, 0, 0, 0, 1, 1, 1, 1  …  1, 1, 1, 0, 0, 0, 0, 1, 1, 0], [0, 0, 0, 1, 0, 1, 1, 1, 1, 1  …  0, 1, 1, 0, 0, 0, 1, 1, 0, 0]]
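As a quick sanity check of the split (a 70/30 partition of the 200 records), one could verify the resulting sizes:

julia> # Expected: ((140, 2), (60, 2)) and (140, 60) given the 70/30 split of 200 records
       size(xtrain), size(xtest)
       length(ytrain), length(ytest)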

Using defaults - hiding complexity

Model definition...

julia> mynn = NeuralNetworkEstimator()
NeuralNetworkEstimator - A Feed-forward neural network (unfitted)

Training...

julia> fit!(mynn,xtrain,ytrain)
***
*** Training  for 200 epochs with algorithm BetaML.Nn.ADAM.
Training.. 	 avg loss on epoch 1 (1): 	 0.18656267056947534
Training.. 	 avg loss on epoch 20 (20): 	 0.058163963177969666
Training.. 	 avg loss on epoch 40 (40): 	 0.04339082068941012
Training.. 	 avg loss on epoch 60 (60): 	 0.04180323331121693
Training.. 	 avg loss on epoch 80 (80): 	 0.04089541120675939

Training the Neural Network...  40%|████████▍            |  ETA: 0:00:02
Training.. 	 avg loss on epoch 100 (100): 	 0.040290207602930596
Training.. 	 avg loss on epoch 120 (120): 	 0.039671600941963626
Training.. 	 avg loss on epoch 140 (140): 	 0.037587573639194245
Training.. 	 avg loss on epoch 160 (160): 	 0.03422994951776505
Training.. 	 avg loss on epoch 180 (180): 	 0.03325984000634394
Training.. 	 avg loss on epoch 200 (200): 	 0.03258492443270572

Training the Neural Network... 100%|█████████████████████| Time: 0:00:01
Training of 200 epoch completed. Final epoch error: 0.03258492443270572.
140-element Vector{Float64}:
 0.6201206013534005
 0.8866438770892913
 0.0
 0.24963268094063873
 0.0
 0.0
 0.987000546441149
 0.8745866029237545
 1.1567629865389901
 0.8251936508889467
 ⋮
 1.1670052789733627
 0.6279771060381655
 0.0
 0.0
 0.06251959909968467
 0.0
 1.0037068233464561
 1.0425240486370215
 0.8069008386620741
julia> ŷtrain = predict(mynn, xtrain) |> makecolvector .|> round .|> Int
140-element Vector{Int64}:
 1
 1
 0
 0
 0
 0
 1
 1
 1
 1
 ⋮
 1
 1
 0
 0
 0
 0
 1
 1
 1
julia> ŷtest = predict(mynn, xtest) |> makecolvector .|> round .|> Int
60-element Vector{Int64}:
 0
 0
 0
 1
 0
 1
 1
 1
 1
 0
 ⋮
 1
 1
 0
 0
 1
 1
 1
 0
 0
julia> trainAccuracy = accuracy(ytrain,ŷtrain)
0.9285714285714286
julia> testAccuracy = accuracy(ytest,ŷtest)
0.95

Specifying all options

Creating a custom callback function to receive info during training...

julia> function myOwnTrainingInfo(nn,xbatch,ybatch,x,y;n,n_batches,epochs,epochs_ran,verbosity,n_epoch,n_batch)
           if verbosity == NONE
               return false # doesn't stop the training
           end
       
           nMsgDict = Dict(LOW => 0, STD => 10,HIGH => 100, FULL => n)
           nMsgs = nMsgDict[verbosity]
       
           if verbosity == FULL || ( n_batch == n_batches && ( n_epoch == 1  || n_epoch % ceil(epochs/nMsgs) == 0))
       
               ϵ = BetaML.Nn.loss(nn,x,y)
               println("Training.. \t avg loss on epoch $n_epoch ($(n_epoch+epochs_ran)): \t $(ϵ)")
           end
           return false
       end
myOwnTrainingInfo (generic function with 1 method)

Model definition...

julia> l1   = DenseLayer(2,5,f=tanh, df= dtanh,rng=copy(FIXEDRNG))
BetaML.Nn.DenseLayer{typeof(tanh), typeof(BetaML.Utils.dtanh), Float64}([-0.5906259287724187 -0.030940667488043583; -0.24536094247801754 0.48878436852866713; … ; -0.8466984663657858 0.027858152254290336; -0.11481842953412624 -0.009180955255778445], [0.37280922460295074, 0.35842073435858446, 0.696721837875378, 0.8225343299333828, 0.44293565195483586], tanh, BetaML.Utils.dtanh)
julia> l2 = DenseLayer(5,5,f=relu,df=drelu,rng=copy(FIXEDRNG))
BetaML.Nn.DenseLayer{typeof(BetaML.Utils.relu), typeof(BetaML.Utils.drelu), Float64}([-0.49415310523844497 -0.025886819681528617 … 0.09224620317870791 -0.2298725180847525; -0.20528369264408397 0.40894634274263597 … 0.01065168104625569 0.04243100148415224; … ; -0.7083987613359595 0.023307802404264888 … -0.3619336136169306 -0.4825559053138102; -0.09606399030062296 -0.0076813382679077336 … -0.3824319138128976 0.21378287379211514], [-0.11178452751598777, -0.2115472973462721, 0.12029035453379433, -0.06939594013486328, -0.5177091839997635], BetaML.Utils.relu, BetaML.Utils.drelu)
julia> l3 = DenseLayer(5,1,f=sigmoid,df=dsigmoid,rng=copy(FIXEDRNG))
BetaML.Nn.DenseLayer{typeof(BetaML.Utils.sigmoid), typeof(BetaML.Utils.dsigmoid), Float64}([-0.6379489156883928 -0.2650201076194998 … -0.9145388683760447 -0.12401807820151456], [-0.033419740504278206], BetaML.Utils.sigmoid, BetaML.Utils.dsigmoid)
julia> mynn = NeuralNetworkEstimator(layers=[l1,l2,l3],loss=squared_cost,dloss=dsquared_cost,descr="A classification task", cb=myOwnTrainingInfo,epochs=300,batch_size=6,opt_alg=ADAM(η=t -> 0.001, λ=1.0, β₁=0.9, β₂=0.999, ϵ=1e-8),rng=copy(FIXEDRNG),verbosity=STD)
NeuralNetworkEstimator - A Feed-forward neural network (unfitted)
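Note that the learning rate η of BetaML's ADAM is given as a function of the epoch t (above, the constant function t -> 0.001). This makes it easy to specify a decaying schedule; a hypothetical example:

julia> # Hypothetical learning-rate schedule: start around 0.001 and decay exponentially with the epoch t
       decaying_adam = ADAM(η = t -> 0.001 * exp(-t/200), λ=1.0, β₁=0.9, β₂=0.999, ϵ=1e-8)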

Training...

julia> fit!(mynn,xtrain,ytrain)
***
*** Training A classification task for 300 epochs with algorithm BetaML.Nn.ADAM.
Training.. 	 avg loss on epoch 1 (1): 	 0.1899818931956108
Training.. 	 avg loss on epoch 30 (30): 	 0.08696213719139437
Training.. 	 avg loss on epoch 60 (60): 	 0.03836146288656552
Training.. 	 avg loss on epoch 90 (90): 	 0.03282305762581847
Training.. 	 avg loss on epoch 120 (120): 	 0.031828414210249366
Training.. 	 avg loss on epoch 150 (150): 	 0.031356983882135134
Training.. 	 avg loss on epoch 180 (180): 	 0.031114984044989264

Training the Neural Network...  63%|█████████████▏       |  ETA: 0:00:01
Training.. 	 avg loss on epoch 210 (210): 	 0.0309532077728682
Training.. 	 avg loss on epoch 240 (240): 	 0.030839304285758238
Training.. 	 avg loss on epoch 270 (270): 	 0.030752146604710585
Training.. 	 avg loss on epoch 300 (300): 	 0.030688406769744634

Training the Neural Network... 100%|█████████████████████| Time: 0:00:01
Training of 300 epoch completed. Final epoch error: 0.030688406769744634.
140-element Vector{Float64}:
 0.8318671801770324
 0.9132972249663277
 0.0007283943326994203
 0.16886055554847956
 0.007886250839655266
 0.0007665909146831172
 0.9372841421253452
 0.9259895334959491
 0.945751620627899
 0.9122735503774152
 ⋮
 0.9469085452904192
 0.8374396051938036
 0.0012373417122658948
 0.005618761363500644
 0.03882052421323333
 0.0005785658276820501
 0.9254528955236776
 0.9430631815489707
 0.8958452838054439
julia> ŷtrain = predict(mynn, xtrain) |> makecolvector .|> round .|> Int
140-element Vector{Int64}:
 1
 1
 0
 0
 0
 0
 1
 1
 1
 1
 ⋮
 1
 1
 0
 0
 0
 0
 1
 1
 1
julia> ŷtest = predict(mynn, xtest) |> makecolvector .|> round .|> Int
60-element Vector{Int64}:
 0
 0
 0
 1
 0
 1
 1
 1
 1
 0
 ⋮
 1
 1
 0
 0
 1
 1
 1
 0
 0
julia> trainAccuracy = accuracy(ytrain,ŷtrain)
0.9214285714285714
julia> testAccuracy = accuracy(ytest,ŷtest)
0.95

Multinomial classification

We want to determine the plant species given some botanical measurements of the flower.

julia> iris     = readdlm(joinpath(dirname(Base.find_package("BetaML")),"..","test","data","iris.csv"),',',skipstart=1)
150×5 Matrix{Any}:
 5.1  3.5  1.4  0.2  "setosa"
 4.9  3.0  1.4  0.2  "setosa"
 4.7  3.2  1.3  0.2  "setosa"
 4.6  3.1  1.5  0.2  "setosa"
 5.0  3.6  1.4  0.2  "setosa"
 5.4  3.9  1.7  0.4  "setosa"
 4.6  3.4  1.4  0.3  "setosa"
 5.0  3.4  1.5  0.2  "setosa"
 4.4  2.9  1.4  0.2  "setosa"
 4.9  3.1  1.5  0.1  "setosa"
 ⋮
 6.9  3.1  5.1  2.3  "virginica"
 5.8  2.7  5.1  1.9  "virginica"
 6.8  3.2  5.9  2.3  "virginica"
 6.7  3.3  5.7  2.5  "virginica"
 6.7  3.0  5.2  2.3  "virginica"
 6.3  2.5  5.0  1.9  "virginica"
 6.5  3.0  5.2  2.0  "virginica"
 6.2  3.4  5.4  2.3  "virginica"
 5.9  3.0  5.1  1.8  "virginica"
julia> iris = iris[shuffle(axes(iris, 1)), :] # Shuffle the records, as they aren't by default
150×5 Matrix{Any}:
 6.4  2.8  5.6  2.2  "virginica"
 5.8  2.7  5.1  1.9  "virginica"
 5.0  2.3  3.3  1.0  "versicolor"
 6.3  2.7  4.9  1.8  "virginica"
 6.0  2.9  4.5  1.5  "versicolor"
 5.4  3.9  1.7  0.4  "setosa"
 6.2  2.9  4.3  1.3  "versicolor"
 5.9  3.2  4.8  1.8  "versicolor"
 6.3  2.5  5.0  1.9  "virginica"
 5.1  3.8  1.6  0.2  "setosa"
 ⋮
 5.4  3.0  4.5  1.5  "versicolor"
 5.8  2.7  3.9  1.2  "versicolor"
 5.6  3.0  4.5  1.5  "versicolor"
 7.9  3.8  6.4  2.0  "virginica"
 6.3  3.3  4.7  1.6  "versicolor"
 4.9  3.0  1.4  0.2  "setosa"
 5.8  2.7  5.1  1.9  "virginica"
 5.7  2.5  5.0  2.0  "virginica"
 5.6  2.5  3.9  1.1  "versicolor"
julia> x = convert(Array{Float64,2}, iris[:,1:4])
150×4 Matrix{Float64}:
 6.4  2.8  5.6  2.2
 5.8  2.7  5.1  1.9
 5.0  2.3  3.3  1.0
 6.3  2.7  4.9  1.8
 6.0  2.9  4.5  1.5
 5.4  3.9  1.7  0.4
 6.2  2.9  4.3  1.3
 5.9  3.2  4.8  1.8
 6.3  2.5  5.0  1.9
 5.1  3.8  1.6  0.2
 ⋮
 5.4  3.0  4.5  1.5
 5.8  2.7  3.9  1.2
 5.6  3.0  4.5  1.5
 7.9  3.8  6.4  2.0
 6.3  3.3  4.7  1.6
 4.9  3.0  1.4  0.2
 5.8  2.7  5.1  1.9
 5.7  2.5  5.0  2.0
 5.6  2.5  3.9  1.1
julia> ystring = String.(iris[:, 5])
150-element Vector{String}:
 "virginica"
 "virginica"
 "versicolor"
 "virginica"
 "versicolor"
 "setosa"
 "versicolor"
 "versicolor"
 "virginica"
 "setosa"
 ⋮
 "versicolor"
 "versicolor"
 "versicolor"
 "virginica"
 "versicolor"
 "setosa"
 "virginica"
 "virginica"
 "versicolor"
julia> iemod = OrdinalEncoder()
A BetaML.Utils.OrdinalEncoder BetaMLModel (unfitted)
julia> y = fit!(iemod,ystring)
150-element Vector{Int64}:
 1
 1
 2
 1
 2
 3
 2
 2
 1
 3
 ⋮
 2
 2
 2
 1
 2
 3
 1
 1
 2
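The fitted encoder can also map the integer codes back to the original labels (we will use this further below to report the confusion matrix); for example, given that the codes 1, 2 and 3 were assigned above in order of appearance:

julia> # Map integer codes back to the original labels
       inverse_predict(iemod, [1, 2, 3])   # expected: ["virginica", "versicolor", "setosa"]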
julia> ((xtrain,xtest),(ytrain,ytest)) = partition([x,y],[0.8,0.2],shuffle=false)
2-element Vector{Vector}:
 AbstractMatrix{Float64}[[6.4 2.8 5.6 2.2; 5.8 2.7 5.1 1.9; … ; 7.7 3.0 6.1 2.3; 4.7 3.2 1.3 0.2], [4.5 2.3 1.3 0.3; 6.4 3.2 5.3 2.3; … ; 5.7 2.5 5.0 2.0; 5.6 2.5 3.9 1.1]]
 AbstractVector{Int64}[[1, 1, 2, 1, 2, 3, 2, 2, 1, 3  …  2, 1, 2, 3, 1, 1, 1, 1, 1, 3], [3, 1, 3, 3, 1, 1, 3, 1, 2, 3  …  3, 2, 2, 2, 1, 2, 3, 1, 1, 2]]
julia> ohmod = OneHotEncoder()
A BetaML.Utils.OneHotEncoder BetaMLModel (unfitted)
julia> ytrain_oh = fit!(ohmod,ytrain) # Convert to One-hot representation (e.g. 2 => [0 1 0], 3 => [0 0 1])
120×3 Matrix{Bool}:
 1  0  0
 1  0  0
 0  1  0
 1  0  0
 0  1  0
 0  0  1
 0  1  0
 0  1  0
 1  0  0
 0  0  1
 ⋮
 1  0  0
 0  1  0
 0  0  1
 1  0  0
 1  0  0
 1  0  0
 1  0  0
 1  0  0
 0  0  1

Define the Artificial Neural Network model

julia> l1   = DenseLayer(4,10,f=relu) # Activation function is ReLU
BetaML.Nn.DenseLayer{typeof(BetaML.Utils.relu), typeof(BetaML.Utils.drelu), Float64}([-0.4274884114527997 -0.4214703393580118 -0.5315034621944922 -0.23724522805376036; 0.34400830455556886 -0.22045437570990367 0.1298513933364338 -0.4037155951090742; … ; -0.5558799973908989 0.30330988277947524 -0.11989701284956722 -0.10825508231193492; 0.34195419400849636 -0.1777851547267581 -0.15973384764413973 -0.5141211611752307], [0.11054802385430795, -0.6411162534358368, 0.47216134349498584, -0.43889926863378065, -0.6189037032291641, -0.6421720066850346, -0.3699487166168019, -0.2891070665545849, 0.4512367486464861, 0.141362499058042], BetaML.Utils.relu, BetaML.Utils.drelu)
julia> l2 = DenseLayer(10,3) # Activation function is identity by default
BetaML.Nn.DenseLayer{typeof(identity), typeof(BetaML.Utils.didentity), Float64}([0.07415567087540198 -0.6125064209083598 … 0.5742810851895541 0.22500975978193438; -0.14975048960189652 0.5620247702259027 … -0.5607057628148215 -0.5461793323927855; 0.09078219414477962 -0.6713624078616475 … -0.10182567920704744 -0.3947993070660693], [-0.15767169389929858, -0.40285918160362344, -0.5458008932610446], identity, BetaML.Utils.didentity)
julia> l3 = VectorFunctionLayer(3,f=softmax) # Add a (parameterless) layer whose activation function (softmax in this case) applies to all its nodes at once
BetaML.Nn.VectorFunctionLayer{0, typeof(BetaML.Utils.softmax), typeof(BetaML.Utils.dsoftmax), Nothing, Float64}(fill(NaN), 3, 3, BetaML.Utils.softmax, BetaML.Utils.dsoftmax, nothing)
julia> mynn = NeuralNetworkEstimator(layers=[l1,l2,l3],loss=crossentropy,batch_size=6,descr="Multinomial logistic regression Model Sepal") # Build the NN using cross-entropy as the error function (squared_cost, aka MSE, could also be used)
NeuralNetworkEstimator - A Feed-forward neural network (unfitted)

Training it (defaulting to ADAM)

julia> fit!(mynn,fit!(Scaler(),xtrain),ytrain_oh) # Use opt_alg=SGD() to use Stochastic Gradient Descent instead
***
*** Training Multinomial logistic regression Model Sepal for 200 epochs with algorithm BetaML.Nn.ADAM.
*** Training Multinomial logistic regression Model Sepal for 200 epochs with algorithm BetaML.Nn.ADAM.
Training.. 	 avg loss on epoch 1 (1): 	 1.1426494186447964
Training.. 	 avg loss on epoch 20 (20): 	 0.47669911666655074
Training.. 	 avg loss on epoch 40 (40): 	 0.33743450982757367
Training.. 	 avg loss on epoch 60 (60): 	 0.2666009751090637
Training.. 	 avg loss on epoch 80 (80): 	 0.21293898370904527
Training.. 	 avg loss on epoch 100 (100): 	 0.16900942186689594
Training.. 	 avg loss on epoch 120 (120): 	 0.1362966496821417
Training.. 	 avg loss on epoch 140 (140): 	 0.11354356534023682
Training.. 	 avg loss on epoch 160 (160): 	 0.09773356773785064
Training.. 	 avg loss on epoch 180 (180): 	 0.08611110721701842
Training.. 	 avg loss on epoch 200 (200): 	 0.07740905281435076
Training of 200 epoch completed. Final epoch error: 0.07740905281435076.
120×3 Matrix{Float64}:
 0.996207    0.00378446   8.64678e-6
 0.961135    0.0387404    0.000124983
 0.00117351  0.997409     0.0014171
 0.830128    0.169711     0.000160804
 0.19175     0.80599      0.0022605
 4.24288e-6  0.000201879  0.999794
 0.0200654   0.978487     0.00144769
 0.723247    0.271466     0.0052871
 0.942235    0.0577389    2.59147e-5
 1.11053e-6  9.7819e-5    0.999901
 ⋮
 0.931859    0.0679827    0.000157993
 0.00660383  0.992495     0.000901162
 3.3439e-6   0.00261161   0.997385
 0.999552    0.000448172  3.03277e-8
 0.99763     0.0023576    1.24139e-5
 0.963468    0.0362871    0.000244754
 0.998186    0.00179402   2.00707e-5
 0.996461    0.00353701   1.50809e-6
 1.3113e-6   0.000836472  0.999162

Test it

julia> ŷtrain        = predict(mynn,fit!(Scaler(),xtrain))   # Note the scaling model
120×3 Matrix{Float64}:
 0.996207    0.00378446   8.64678e-6
 0.961135    0.0387404    0.000124983
 0.00117351  0.997409     0.0014171
 0.830128    0.169711     0.000160804
 0.19175     0.80599      0.0022605
 4.24288e-6  0.000201879  0.999794
 0.0200654   0.978487     0.00144769
 0.723247    0.271466     0.0052871
 0.942235    0.0577389    2.59147e-5
 1.11053e-6  9.7819e-5    0.999901
 ⋮
 0.931859    0.0679827    0.000157993
 0.00660383  0.992495     0.000901162
 3.3439e-6   0.00261161   0.997385
 0.999552    0.000448172  3.03277e-8
 0.99763     0.0023576    1.24139e-5
 0.963468    0.0362871    0.000244754
 0.998186    0.00179402   2.00707e-5
 0.996461    0.00353701   1.50809e-6
 1.3113e-6   0.000836472  0.999162
julia> ŷtest = predict(mynn,fit!(Scaler(),xtest))
30×3 Matrix{Float64}:
 2.12805e-6   0.670318     0.32968
 0.960689     0.0385029    0.000808501
 1.96953e-7   0.000474394  0.999525
 1.02506e-7   1.35004e-5   0.999986
 0.974845     0.0251528    2.60157e-6
 0.985162     0.0146765    0.000161203
 1.6645e-7    1.65224e-5   0.999983
 0.950981     0.0487237    0.000295364
 0.0113644    0.980515     0.00812018
 1.6474e-7    9.33256e-5   0.999907
 ⋮
 0.0656364    0.908984     0.0253798
 0.00216533   0.996376     0.00145897
 0.049637     0.935882     0.0144808
 0.796518     0.196925     0.0065564
 0.090972     0.872246     0.036782
 2.6614e-7    0.00113664   0.998863
 0.865683     0.133994     0.000323265
 0.947716     0.0522308    5.27143e-5
 0.000683639  0.998784     0.000532276
julia> trainAccuracy = accuracy(ytrain,ŷtrain)
0.975
julia> testAccuracy = accuracy(ytest,ŷtest,tol=1,ignorelabels=false) # tol=1: a record counts as correct only if the true label is the single most probable class
0.9
julia> cm = ConfusionMatrix()
A BetaML.Utils.ConfusionMatrix BetaMLModel (unfitted)
julia> fit!(cm,inverse_predict(iemod,ytrain),inverse_predict(iemod,mode(ŷtrain)))
3×3 Matrix{Float64}:
 1.0    0.0    0.0
 0.075  0.925  0.0
 0.0    0.0    1.0
julia> println(cm)
A BetaML.Utils.ConfusionMatrix BetaMLModel (fitted)

-----------------------------------------------------------------
*** CONFUSION MATRIX ***
Scores actual (rows) vs predicted (columns):
4×4 Matrix{Any}:
 "Labels"      "virginica"  "versicolor"  "setosa"
 "virginica"  38             0             0
 "versicolor"  3            37             0
 "setosa"      0             0            42
Normalised scores actual (rows) vs predicted (columns):
4×4 Matrix{Any}:
 "Labels"      "virginica"  "versicolor"  "setosa"
 "virginica"  1.0           0.0           0.0
 "versicolor" 0.075         0.925         0.0
 "setosa"     0.0           0.0           1.0

*** CONFUSION REPORT ***
- Accuracy:               0.975
- Misclassification rate: 0.025000000000000022
- Number of classes:      3

  N Class       precision  recall  specificity  f1score  actual_count  predicted_count
                  TPR       TNR                             support

  1 virginica       0.927   1.000        0.963    0.962            38               41
  2 versicolor      1.000   0.925        1.000    0.961            40               37
  3 setosa          1.000   1.000        1.000    1.000            42               42

- Simple   avg.     0.976   0.975        0.988    0.974
- Weighted avg.     0.977   0.975        0.988    0.975

-----------------------------------------------------------------
Output of `info(cm)`:
- mean_precision:    (0.975609756097561, 0.9768292682926829)
- fitted_records:    120
- specificity:       [0.9634146341463414, 1.0, 1.0]
- precision:         [0.926829268292683, 1.0, 1.0]
- misclassification: 0.025000000000000022
- mean_recall:       (0.975, 0.975)
- n_categories:      3
- normalised_scores: [1.0 0.0 0.0; 0.075 0.925 0.0; 0.0 0.0 1.0]
- tn:                [79, 80, 78]
- mean_f1score:      (0.9743547591648857, 0.9749876705572909)
- actual_count:      [38, 40, 42]
- accuracy:          0.975
- recall:            [1.0, 0.925, 1.0]
- f1score:           [0.9620253164556962, 0.961038961038961, 1.0]
- mean_specificity:  (0.9878048780487805, 0.9884146341463415)
- predicted_count:   [41, 37, 42]
- scores:            [38 0 0; 3 37 0; 0 0 42]
- tp:                [38, 37, 42]
- fn:                [0, 3, 0]
- categories:        ["virginica", "versicolor", "setosa"]
- fp:                [3, 0, 0]
julia> res = info(cm)
Dict{String, Any} with 21 entries:
  "mean_precision"     => (0.97561, 0.976829)
  "fitted_records"     => 120
  "specificity"        => [0.963415, 1.0, 1.0]
  "precision"          => [0.926829, 1.0, 1.0]
  "misclassification"  => 0.025
  "mean_recall"        => (0.975, 0.975)
  "n_categories"       => 3
  "normalised_scores"  => [1.0 0.0 0.0; 0.075 0.925 0.0; 0.0 0.0 1.0]
  "tn"                 => [79, 80, 78]
  "mean_f1score"       => (0.974355, 0.974988)
  "actual_count"       => [38, 40, 42]
  "accuracy"           => 0.975
  "recall"             => [1.0, 0.925, 1.0]
  "f1score"            => [0.962025, 0.961039, 1.0]
  "mean_specificity"   => (0.987805, 0.988415)
  "predicted_count"    => [41, 37, 42]
  "scores"             => [38 0 0; 3 37 0; 0 0 42]
  "tp"                 => [38, 37, 42]
  "fn"                 => [0, 3, 0]
  ⋮                    => ⋮
julia> heatmap(string.(res["categories"]),string.(res["categories"]),res["normalised_scores"],seriescolor=cgrad([:white,:blue]),xlabel="Predicted",ylabel="Actual", title="Confusion Matrix (normalised scores)")
Plot{Plots.GRBackend() n=1}
julia> savefig("cm_iris.svg");

Regression

Data loading and processing...

julia> using Pipe, HTTP, CSV, DataFrames
julia> urlData = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"
"https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"
julia> data = @pipe HTTP.get(urlData).body |> CSV.File(_, delim='\t') |> DataFrame
442×11 DataFrame
 Row │ AGE    SEX    BMI      BP       S1     S2       S3       S4       S5    ⋯
     │ Int64  Int64  Float64  Float64  Int64  Float64  Float64  Float64  Float ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │    59      2     32.1   101.0     157     93.2     38.0     4.0    4.85 ⋯
   2 │    48      1     21.6    87.0     183    103.2     70.0     3.0    3.89
   3 │    72      2     30.5    93.0     156     93.6     41.0     4.0    4.67
   4 │    24      1     25.3    84.0     198    131.4     40.0     5.0    4.89
   5 │    50      1     23.0   101.0     192    125.4     52.0     4.0    4.29 ⋯
   6 │    23      1     22.6    89.0     139     64.8     61.0     2.0    4.18
   7 │    36      2     22.0    90.0     160     99.6     50.0     3.0    3.95
   8 │    66      2     26.2   114.0     255    185.0     56.0     4.55   4.24
  ⋮  │   ⋮      ⋮       ⋮        ⋮       ⋮       ⋮        ⋮        ⋮        ⋮  ⋱
 436 │    45      1     24.2    83.0     177    118.4     45.0     4.0    4.21 ⋯
 437 │    33      1     19.5    80.0     171     85.4     75.0     2.0    3.97
 438 │    60      2     28.2   112.0     185    113.8     42.0     4.0    4.98
 439 │    47      2     24.9    75.0     225    166.0     42.0     5.0    4.44
 440 │    60      2     24.9    99.67    162    106.6     43.0     3.77   4.12 ⋯
 441 │    36      1     30.0    95.0     201    125.2     42.0     4.79   5.12
 442 │    36      1     19.6    71.0     250    133.2     97.0     3.0    4.59
                                                  3 columns and 427 rows omitted
julia> ohmod = OneHotEncoder()
A BetaML.Utils.OneHotEncoder BetaMLModel (unfitted)
julia> sex_oh = fit!(ohmod,data.SEX)
442×2 Matrix{Bool}:
 0  1
 1  0
 0  1
 1  0
 1  0
 1  0
 0  1
 0  1
 0  1
 1  0
 ⋮
 1  0
 1  0
 1  0
 1  0
 0  1
 0  1
 0  1
 1  0
 1  0
julia> X = hcat(data.AGE, Matrix(data[:,3:10]),sex_oh)
442×11 Matrix{Float64}:
 59.0  32.1  101.0   157.0   93.2  38.0  4.0   4.8598   87.0  0.0  1.0
 48.0  21.6   87.0   183.0  103.2  70.0  3.0   3.8918   69.0  1.0  0.0
 72.0  30.5   93.0   156.0   93.6  41.0  4.0   4.6728   85.0  0.0  1.0
 24.0  25.3   84.0   198.0  131.4  40.0  5.0   4.8903   89.0  1.0  0.0
 50.0  23.0  101.0   192.0  125.4  52.0  4.0   4.2905   80.0  1.0  0.0
 23.0  22.6   89.0   139.0   64.8  61.0  2.0   4.1897   68.0  1.0  0.0
 36.0  22.0   90.0   160.0   99.6  50.0  3.0   3.9512   82.0  0.0  1.0
 66.0  26.2  114.0   255.0  185.0  56.0  4.55  4.2485   92.0  0.0  1.0
 60.0  32.1   83.0   179.0  119.4  42.0  4.0   4.4773   94.0  0.0  1.0
 29.0  30.0   85.0   180.0   93.4  43.0  4.0   5.3845   88.0  1.0  0.0
  ⋮
 41.0  20.8   86.0   223.0  128.2  83.0  3.0   4.0775   89.0  1.0  0.0
 53.0  26.5   97.0   193.0  122.4  58.0  3.0   4.1431   99.0  1.0  0.0
 45.0  24.2   83.0   177.0  118.4  45.0  4.0   4.2195   82.0  1.0  0.0
 33.0  19.5   80.0   171.0   85.4  75.0  2.0   3.9703   80.0  1.0  0.0
 60.0  28.2  112.0   185.0  113.8  42.0  4.0   4.9836   93.0  0.0  1.0
 47.0  24.9   75.0   225.0  166.0  42.0  5.0   4.4427  102.0  0.0  1.0
 60.0  24.9   99.67  162.0  106.6  43.0  3.77  4.1271   95.0  0.0  1.0
 36.0  30.0   95.0   201.0  125.2  42.0  4.79  5.1299   85.0  1.0  0.0
 36.0  19.6   71.0   250.0  133.2  97.0  3.0   4.5951   92.0  1.0  0.0
julia> y = data.Y
442-element Vector{Int64}:
 151
  75
 141
 206
 135
  97
 138
  63
 110
 310
   ⋮
  72
  49
  64
  48
 178
 104
 132
 220
  57
julia> (xtrain,xval),(ytrain,yval) = partition([X,y],[0.8,0.2])
2-element Vector{Vector}:
 AbstractMatrix{Float64}[[48.0 24.1 … 0.0 1.0; 67.0 23.5 … 0.0 1.0; … ; 67.0 22.5 … 0.0 1.0; 54.0 27.3 … 0.0 1.0], [47.0 26.5 … 0.0 1.0; 54.0 24.2 … 1.0 0.0; … ; 25.0 26.0 … 1.0 0.0; 36.0 22.0 … 0.0 1.0]]
 AbstractVector{Int64}[[65, 172, 274, 72, 93, 317, 171, 187, 68, 186  …  279, 111, 230, 128, 71, 178, 31, 272, 75, 235], [51, 92, 216, 276, 71, 281, 182, 242, 293, 65  …  49, 243, 189, 178, 42, 264, 229, 244, 68, 138]]

Model definition...

julia> l1   = DenseLayer(11,20,f=relu)
BetaML.Nn.DenseLayer{typeof(BetaML.Utils.relu), typeof(BetaML.Utils.drelu), Float64}([0.05186746611701665 0.10038465866909702 … 0.06442070381894544 -0.4216689104148587; 0.1266579040822326 -0.43271947738078226 … 0.11796027179454788 0.24843077314231216; … ; 0.3787652910883074 -0.20941141255949852 … 0.05853407716166453 -0.3559830102463321; -0.3322871532558658 0.010956838126639945 … -0.0397603313093719 -0.3690543850756196], [0.21179488994452117, -0.26308235869248064, -0.31050510370255446, -0.0013429707388629697, 0.302039812517866, 0.07042908382885543, 0.09387336893440462, -0.15459585641884832, 0.20210801344227153, -0.021141347670115884, -0.3270614652550431, 0.26236477071019554, -0.10176443894892845, -0.09559187370397054, -0.06184232890209135, 0.13584595587559106, 0.04644622008097893, 0.3816836071273813, 0.017325958713112488, -0.3895499686994527], BetaML.Utils.relu, BetaML.Utils.drelu)
julia> l2 = DenseLayer(20,20,f=relu)
BetaML.Nn.DenseLayer{typeof(BetaML.Utils.relu), typeof(BetaML.Utils.drelu), Float64}([0.11528174247457718 0.15467891724960753 … -0.21498638269897663 0.3342856828575497; 0.1341841769686562 0.2900536727833864 … 0.06386991902442785 -0.10132571376107358; … ; -0.22072201771469038 0.3160172582728677 … 0.2817047318969575 -0.022749144125577325; -0.30848137438320844 -0.002871184962475626 … -0.0302479826549637 -0.28385297611020044], [-0.3475167897752178, -0.06427533045599143, 0.1144111267099352, -0.14574279105687218, -0.34980583062014725, 0.3369412547357659, 0.2917615809115682, -0.3118578859708546, 0.14811834280513053, -0.07514683694587143, 0.06427199383488846, -0.3727428912397876, -0.037577957118294536, -0.19790283329934125, 0.20083575441686846, -0.17296483308188398, -0.08783287496796671, -0.3645014751844286, 0.020501987420858625, -0.1253739053546974], BetaML.Utils.relu, BetaML.Utils.drelu)
julia> l3 = DenseLayer(20,1,f=relu) # y is positive
BetaML.Nn.DenseLayer{typeof(BetaML.Utils.relu), typeof(BetaML.Utils.drelu), Float64}([-0.5139582870961711 0.2515816841107448 … -0.24183664318944853 0.5160023473559416], [-0.0383979334333186], BetaML.Utils.relu, BetaML.Utils.drelu)
julia> mynn = NeuralNetworkEstimator(layers=[l1,l2,l3],loss=squared_cost, batch_size=6,epochs=600)
NeuralNetworkEstimator - A Feed-forward neural network (unfitted)

Training...

julia> fit!(mynn,fit!(Scaler(),xtrain),ytrain)
***
*** Training  for 600 epochs with algorithm BetaML.Nn.ADAM.
Training.. 	 avg loss on epoch 1 (1): 	 13242.153668056468
Training.. 	 avg loss on epoch 60 (60): 	 1181.6928466321904

Training the Neural Network...  11%|██▎                  |  ETA: 0:00:08
Training.. 	 avg loss on epoch 120 (120): 	 1088.3494941178749

Training the Neural Network...  22%|████▋                |  ETA: 0:00:07
Training.. 	 avg loss on epoch 180 (180): 	 1013.3049635384854

Training the Neural Network...  33%|███████              |  ETA: 0:00:06
Training.. 	 avg loss on epoch 240 (240): 	 951.3686606709203

Training the Neural Network...  44%|█████████▍           |  ETA: 0:00:05
Training.. 	 avg loss on epoch 300 (300): 	 896.3859468269703

Training the Neural Network...  56%|███████████▊         |  ETA: 0:00:04
Training.. 	 avg loss on epoch 360 (360): 	 848.1831730368821

Training the Neural Network...  67%|██████████████       |  ETA: 0:00:03
Training.. 	 avg loss on epoch 420 (420): 	 803.0116879738476

Training the Neural Network...  78%|████████████████▍    |  ETA: 0:00:02
Training.. 	 avg loss on epoch 480 (480): 	 758.826111015261

Training the Neural Network...  89%|██████████████████▊  |  ETA: 0:00:01
Training.. 	 avg loss on epoch 540 (540): 	 715.0438800829188
Training.. 	 avg loss on epoch 600 (600): 	 670.2162748919541

Training the Neural Network... 100%|█████████████████████| Time: 0:00:09
Training of 600 epoch completed. Final epoch error: 670.2162748919541.
354-element Vector{Float64}:
  81.07993883269886
 169.52126425092166
 250.79108097938965
  78.16895211050192
  89.16639780996218
 331.90901693553394
 173.84432553247282
 133.39834850071503
 102.83781661481322
 191.34828652268124
   ⋮
 116.23471219074186
 243.4532609806199
 120.09366004175695
  72.32392584531328
 190.70308709894002
  73.43313498447
 208.6428432763513
  80.71031166193289
 190.46283876877766
julia> ŷtrain = predict(mynn, fit!(Scaler(),xtrain))
354-element Vector{Float64}:
  81.07993883269886
 169.52126425092166
 250.79108097938965
  78.16895211050192
  89.16639780996218
 331.90901693553394
 173.84432553247282
 133.39834850071503
 102.83781661481322
 191.34828652268124
   ⋮
 116.23471219074186
 243.4532609806199
 120.09366004175695
  72.32392584531328
 190.70308709894002
  73.43313498447
 208.6428432763513
  80.71031166193289
 190.46283876877766
julia> ŷval = predict(mynn, fit!(Scaler(),xval))
88-element Vector{Float64}:
  71.82752547531973
 100.52065565664535
 169.81942366980454
  99.10627565466231
  87.07088916677233
 205.4237879514058
  93.0490540236944
 129.87643420109129
 163.7188570059829
  56.560107713886154
   ⋮
 269.3732741074357
 163.3133668954572
 131.64974801265325
  98.42809688769717
 260.518375060938
 186.67582407651392
 156.93316344977902
 124.98890200819672
  73.919331621937
julia> trainRME = relative_mean_error(ytrain,ŷtrain)
0.18797770788831789
julia> valRME = relative_mean_error(yval,ŷval)
0.32925719847819684
julia> plot(info(mynn)["loss_per_epoch"][10:end])
Plot{Plots.GRBackend() n=1}
julia> savefig("loss_per_epoch.svg");

julia> scatter(yval,ŷval,xlabel="obs",ylabel="est",legend=nothing)
Plot{Plots.GRBackend() n=1}
julia> savefig("obs_vs_est.svg");

Convolutional neural networks

julia> using LinearAlgebra, Statistics, Flux, MLDatasets, Plots
WARNING: using Flux.ADAM in module Main conflicts with an existing identifier.
julia> x_train, y_train = MLDatasets.MNIST(split=:train)[:]
(features = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; … ], targets = [5, 0, 4, 1, 9, 2, 1, 3, 1, 4  …  9, 2, 9, 5, 1, 8, 3, 5, 6, 8])
julia> x_train = permutedims(x_train,(2,1,3)); # For correct img axis
#x_train = convert(Array{Float32,3},x_train);
julia> x_train = reshape(x_train,(28,28,1,60000));
julia> y_train = Flux.onehotbatch(y_train, 0:9)
10×60000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  1  ⋅  …  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  …  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1
 ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  …  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
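`Flux.onecold` is the inverse operation of `onehotbatch`: it maps each one-hot column back to its label. For example (the first five training targets are 5, 0, 4, 1 and 9, as shown above):

julia> # Recover the labels of the first five training images from the one-hot matrix
       Flux.onecold(y_train[:, 1:5], 0:9)   # expected: [5, 0, 4, 1, 9]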
julia> train_data = Flux.Data.DataLoader((x_train, y_train), batchsize=128) #x_test, y_test = MLDatasets.MNIST.testdata(dir = "data/MNIST")
469-element DataLoader(::Tuple{Array{Float32, 4}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=128) with first element:
(28×28×1×128 Array{Float32, 4}, 10×128 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)
julia> x_test, y_test = MLDatasets.MNIST(split=:test)[:]
(features = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; … ], targets = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9  …  7, 8, 9, 0, 1, 2, 3, 4, 5, 6])
julia> x_test = permutedims(x_test,(2,1,3)); # For correct img axis
#x_test = convert(Array{Float32,3},x_test);
julia> x_test = reshape(x_test,(28,28,1,10000));
julia> y_test = Flux.onehotbatch(y_test, 0:9)
10×10000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  1  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1
 1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  1  …  ⋅  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
julia> model = Chain(
           # 28x28 => 14x14
           Conv((5, 5), 1=>8, pad=2, stride=2, relu),
           # 14x14 => 7x7
           Conv((3, 3), 8=>16, pad=1, stride=2, relu),
           # 7x7 => 4x4
           Conv((3, 3), 16=>32, pad=1, stride=2, relu),
           # 4x4 => 2x2
           Conv((3, 3), 32=>32, pad=1, stride=2, relu),
           # Average pooling on each width x height feature map
           GlobalMeanPool(),
           Flux.flatten,
           Dense(32, 10),
           Flux.softmax
       )
Chain(
  Conv((5, 5), 1 => 8, relu, pad=2, stride=2),   # 208 parameters
  Conv((3, 3), 8 => 16, relu, pad=1, stride=2),  # 1_168 parameters
  Conv((3, 3), 16 => 32, relu, pad=1, stride=2), # 4_640 parameters
  Conv((3, 3), 32 => 32, relu, pad=1, stride=2), # 9_248 parameters
  GlobalMeanPool(),
  Flux.flatten,
  Dense(32 => 10),                               # 330 parameters
  NNlib.softmax,
)                   # Total: 10 arrays, 15_594 parameters, 62.445 KiB.
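A quick way to verify the architecture is to pass a small dummy batch through the (untrained) model and check the output shape; a minimal sketch:

julia> # 3 fake 28×28 grayscale images (width × height × channels × batch)
       xdummy = rand(Float32, 28, 28, 1, 3)
julia> size(model(xdummy))   # (10, 3): one softmax-normalised score per digit class, per image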
julia> myaccuracy(y,ŷ) = (mean(Flux.onecold(ŷ) .== Flux.onecold(y)))
myaccuracy (generic function with 1 method)
julia> myloss(x, y) = Flux.crossentropy(model(x), y) # the model output already goes through softmax, so plain crossentropy (not logitcrossentropy) is the appropriate loss
myloss (generic function with 1 method)
julia> opt = Flux.ADAM()
Flux.Optimise.Adam(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())
julia> ps = Flux.params(model)
Params([…])  (display of the 10 randomly initialised weight and bias arrays omitted)
julia> number_epochs = 4
4
julia> [(println(e); Flux.train!(myloss, ps, train_data, opt)) for e in 1:number_epochs]
1
2
3
4
4-element Vector{Nothing}:
 nothing
 nothing
 nothing
 nothing
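The comprehension above is just a compact way of running the training loop while printing the epoch number; an equivalent, more explicit version (with the same myloss, ps, train_data and opt) would be:

julia> for epoch in 1:number_epochs
           println(epoch)                            # show progress
           Flux.train!(myloss, ps, train_data, opt)  # one pass over all mini-batches
       end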
julia> ŷtrain = model(x_train)
10×60000 Matrix{Float32}:
 8.55891f-6   0.999971     5.24617f-7  …  0.0249809    0.00468651
 1.09591f-6   1.56651f-10  6.93928f-7  …  9.30811f-7   7.3349f-5
 0.000147928  2.78294f-5   4.22042f-5  …  0.00147321   0.00522841
 0.217504     1.77461f-8   0.00080586  …  6.1435f-6    5.46034f-5
 5.89417f-12  2.46721f-11  0.988247    …  0.000130771  9.6599f-7
 0.782199     2.43885f-7   4.9961f-7   …  0.00123264   0.000823258
 5.16558f-11  1.09386f-8   4.06079f-8  …  0.972072     4.7233f-7
 0.000125487  6.37032f-8   0.00709748  …  1.53421f-5   0.00156628
 4.72028f-6   3.64378f-7   8.85473f-6  …  8.33064f-5   0.986323
 9.11343f-6   1.1581f-7    0.00379671  …  4.43461f-6   0.00124269
julia> ŷtest = model(x_test)
10×10000 Matrix{Float32}:
 5.31574f-6   5.40955f-5   6.58353f-5   …  1.41547f-6  2.3038f-5
 1.74953f-7   2.66305f-5   0.996384     …  1.57631f-7  3.50559f-13
 4.21994f-5   0.990211     0.00040111   …  8.0285f-8   3.86899f-7
 7.06879f-5   0.00956708   0.000314679  …  5.89519f-6  2.09542f-10
 6.73878f-9   2.93683f-14  0.000500045  …  1.71393f-8  2.15927f-9
 2.53342f-6   0.000139129  0.000193413  …  0.999953    2.49671f-6
 3.63248f-15  8.44003f-9   0.000274142  …  2.47801f-7  0.999974
 0.99978      1.44824f-9   0.000937896  …  1.49648f-8  2.17516f-12
 4.31369f-6   1.89597f-6   0.000907781  …  3.89678f-5  2.35955f-9
 9.47277f-5   2.03592f-12  2.16021f-5   …  6.92616f-7  2.02211f-10
julia> myaccuracy(y_train,ŷtrain)
0.9676833333333333
julia> myaccuracy(y_test,ŷtest)
0.9674
julia> plot(Gray.(x_train[:,:,1,2]))
Plot{Plots.GRBackend() n=1}
julia> cm = ConfusionMatrix()
A BetaML.Utils.ConfusionMatrix BetaMLModel (unfitted)
julia> fit!(cm,Flux.onecold(y_test) .-1, Flux.onecold(ŷtest) .-1 )
10×10 Matrix{Float64}:
 0.97179      0.0116732   0.0077821    …  0.000972763  0.000972763
 0.00775194   0.959302    0.00387597   …  0.0116279    0.00484496
 0.0          0.00176211  0.994714     …  0.000881057  0.000881057
 0.00102041   0.00204082  0.0          …  0.0          0.00102041
 0.00509165   0.00305499  0.00101833   …  0.0          0.00203666
 0.0109019    0.0         0.00792864   …  0.00396432   0.0
 0.00112108   0.00112108  0.0          …  0.0044843    0.00224215
 0.0          0.00104384  0.00417537   …  0.0          0.0
 0.00693069   0.0039604   0.000990099  …  0.972277     0.0
 0.00616016   0.00821355  0.0102669    …  0.0154004    0.919918
julia> println(cm)
A BetaML.Utils.ConfusionMatrix BetaMLModel (fitted)

-----------------------------------------------------------------
*** CONFUSION MATRIX ***
Scores actual (rows) vs predicted (columns):
11×11 Matrix{Any}:
 "Labels"   "7"   "2"    "1"   "0"   "4"   "9"   "5"   "6"   "3"   "8"
 "7"       999    12      8     1     1     5     0     0     1     1
 "2"         8   990      4     6     0     0     4     3    12     5
 "1"         0     2   1129     0     0     0     1     1     1     1
 "0"         1     2      0   968     0     0     4     4     0     1
 "4"         5     3      1     3   943    17     1     7     0     2
 "9"        11     0      8     7     8   969     2     0     4     0
 "5"         1     1      0     1     0     6   877     0     4     2
 "6"         0     1      4    12     3     0    17   921     0     0
 "3"         7     4      1     0     0     8     8     0   982     0
 "8"         6     8     10     8     4     7    15     5    15   896
Normalised scores actual (rows) vs predicted (columns):
11×11 Matrix{Any}:
 "Labels"  "7"         "2"         "1"          "0"          "4"          "9"         "5"          "6"          "3"          "8"
 "7"       0.97179     0.0116732   0.0077821    0.000972763  0.000972763  0.00486381  0.0          0.0          0.000972763  0.000972763
 "2"       0.00775194  0.959302    0.00387597   0.00581395   0.0          0.0         0.00387597   0.00290698   0.0116279    0.00484496
 "1"       0.0         0.00176211  0.994714     0.0          0.0          0.0         0.000881057  0.000881057  0.000881057  0.000881057
 "0"       0.00102041  0.00204082  0.0          0.987755     0.0          0.0         0.00408163   0.00408163   0.0          0.00102041
 "4"       0.00509165  0.00305499  0.00101833   0.00305499   0.960285     0.0173116   0.00101833   0.00712831   0.0          0.00203666
 "9"       0.0109019   0.0         0.00792864   0.00693756   0.00792864   0.960357    0.00198216   0.0          0.00396432   0.0
 "5"       0.00112108  0.00112108  0.0          0.00112108   0.0          0.00672646  0.983184     0.0          0.0044843    0.00224215
 "6"       0.0         0.00104384  0.00417537   0.0125261    0.00313152   0.0         0.0177453    0.961378     0.0          0.0
 "3"       0.00693069  0.0039604   0.000990099  0.0          0.0          0.00792079  0.00792079   0.0          0.972277     0.0
 "8"       0.00616016  0.00821355  0.0102669    0.00821355   0.00410678   0.00718686  0.0154004    0.00513347   0.0154004    0.919918

*** CONFUSION REPORT ***
- Accuracy:               0.9674
- Misclassification rate: 0.03259999999999996
- Number of classes:      10

  N Class  precision  recall  specificity  f1score  actual_count  predicted_count
              TPR      TNR                             support

  1 7          0.962   0.972        0.996    0.967          1028             1038
  2 2          0.968   0.959        0.996    0.964          1032             1023
  3 1          0.969   0.995        0.996    0.982          1135             1165
  4 0          0.962   0.988        0.996    0.975           980             1006
  5 4          0.983   0.960        0.998    0.972           982              959
  6 9          0.958   0.960        0.995    0.959          1009             1012
  7 5          0.944   0.983        0.994    0.963           892              929
  8 6          0.979   0.961        0.998    0.970           958              941
  9 3          0.964   0.972        0.996    0.968          1010             1019
 10 8          0.987   0.920        0.999    0.952           974              908

- Simple   avg.  0.968   0.967      0.996    0.967
- Weighted avg.  0.968   0.967      0.996    0.967

-----------------------------------------------------------------
Output of `info(cm)`:
- mean_precision:    (0.9675566751424146, 0.9676902553907551)
- fitted_records:    10000
- specificity:       [0.9956531431119037, 0.9963202497769849, 0.9959390862944163, 0.9957871396895787, 0.9982257706808605, 0.9952174396618841, 0.9942907334211682, 0.997788099977881, 0.9958843159065629, 0.9986705074230002]
- precision:         [0.9624277456647399, 0.967741935483871, 0.9690987124463519, 0.9622266401590457, 0.9833159541188738, 0.9575098814229249, 0.9440258342303552, 0.9787460148777896, 0.9636898920510304, 0.986784140969163]
- misclassification: 0.03259999999999996
- mean_recall:       (0.9670959707826225, 0.9674)
- n_categories:      10
- normalised_scores: [0.9717898832684825 0.011673151750972763 … 0.0009727626459143969 0.0009727626459143969; 0.007751937984496124 0.9593023255813954 … 0.011627906976744186 0.0048449612403100775; … ; 0.006930693069306931 0.0039603960396039604 … 0.9722772277227723 0.0; 0.006160164271047228 0.008213552361396304 … 0.015400410677618069 0.919917864476386]
- tn:                [8933, 8935, 8829, 8982, 9002, 8948, 9056, 9022, 8953, 9014]
- mean_f1score:      (0.9671082295432092, 0.9673328866899091)
- actual_count:      [1028, 1032, 1135, 980, 982, 1009, 892, 958, 1010, 974]
- accuracy:          0.9674
- recall:            [0.9717898832684825, 0.9593023255813954, 0.9947136563876652, 0.9877551020408163, 0.960285132382892, 0.9603567888999008, 0.9831838565022422, 0.9613778705636743, 0.9722772277227723, 0.919917864476386]
- f1score:           [0.9670861568247822, 0.9635036496350365, 0.9817391304347826, 0.9748237663645518, 0.9716640906749099, 0.9589312221672439, 0.9632070291048874, 0.9699842022116903, 0.9679645145391819, 0.9521785334750266]
- mean_specificity:  (0.996377648594424, 0.9963764859442402)
- predicted_count:   [1038, 1023, 1165, 1006, 959, 1012, 929, 941, 1019, 908]
- scores:            [999 12 8 1 1 5 0 0 1 1; 8 990 4 6 0 0 4 3 12 5; 0 2 1129 0 0 0 1 1 1 1; 1 2 0 968 0 0 4 4 0 1; 5 3 1 3 943 17 1 7 0 2; 11 0 8 7 8 969 2 0 4 0; 1 1 0 1 0 6 877 0 4 2; 0 1 4 12 3 0 17 921 0 0; 7 4 1 0 0 8 8 0 982 0; 6 8 10 8 4 7 15 5 15 896]
- tp:                [999, 990, 1129, 968, 943, 969, 877, 921, 982, 896]
- fn:                [29, 42, 6, 12, 39, 40, 15, 37, 28, 78]
- categories:        [7, 2, 1, 0, 4, 9, 5, 6, 3, 8]
- fp:                [39, 33, 36, 38, 16, 43, 52, 20, 37, 12]
julia> res = info(cm)
Dict{String, Any} with 21 entries:
  "mean_precision"     => (0.967557, 0.96769)
  "fitted_records"     => 10000
  "specificity"        => [0.995653, 0.99632, 0.995939, 0.995787, 0.998226, 0.99…
  "precision"          => [0.962428, 0.967742, 0.969099, 0.962227, 0.983316, 0.9…
  "misclassification"  => 0.0326
  "mean_recall"        => (0.967096, 0.9674)
  "n_categories"       => 10
  "normalised_scores"  => [0.97179 0.0116732 … 0.000972763 0.000972763; 0.007751…
  "tn"                 => [8933, 8935, 8829, 8982, 9002, 8948, 9056, 9022, 8953,…
  "mean_f1score"       => (0.967108, 0.967333)
  "actual_count"       => [1028, 1032, 1135, 980, 982, 1009, 892, 958, 1010, 974]
  "accuracy"           => 0.9674
  "recall"             => [0.97179, 0.959302, 0.994714, 0.987755, 0.960285, 0.96…
  "f1score"            => [0.967086, 0.963504, 0.981739, 0.974824, 0.971664, 0.9…
  "mean_specificity"   => (0.996378, 0.996376)
  "predicted_count"    => [1038, 1023, 1165, 1006, 959, 1012, 929, 941, 1019, 90…
  "scores"             => [999 12 … 1 1; 8 990 … 12 5; … ; 7 4 … 982 0; 6 8 … 15…
  "tp"                 => [999, 990, 1129, 968, 943, 969, 877, 921, 982, 896]
  "fn"                 => [29, 42, 6, 12, 39, 40, 15, 37, 28, 78]
  ⋮                    => ⋮
julia> heatmap(string.(res["categories"]),string.(res["categories"]),res["normalised_scores"],seriescolor=cgrad([:white,:blue]),xlabel="Predicted",ylabel="Actual", title="Confusion Matrix (normalised scores)")
Plot{Plots.GRBackend() n=1}
julia> savefig("cm_digits.svg")
"/home/runner/work/SPMLJ/SPMLJ/buildedDoc/04_-_NN_-_Neural_Networks/cm_digits.svg"

View this file on Github.


This page was generated using Literate.jl.