An introduction to machine learning with Keras in R


A guest post by @MaxMaPichler, MSc student in the Group for Theoretical Ecology / UR

Artificial neural networks, especially deep neural networks (DNNs) and (deep) convolutional neural networks ((D)CNNs), have become increasingly popular in recent years and have dominated most machine learning competitions since the early 2010s (for a review of DNNs and (D)CNNs, see LeCun, Bengio, & Hinton, 2015). In ecology, there are many potential applications for these methods, for example image recognition, analysis of acoustic signals, or other classification tasks for which large datasets are available.

Fig. 1 shows the principle of a DNN: a number of input features (predictor variables) are connected to one or several outputs through several hidden layers of “neurons”. Each neuron receives the values of the previous layer, weighted by the strength of the connections, so a large value in one layer produces correspondingly large (or small) values in the next. These connection strengths (weights) are learned during training: they are adjusted until the network fits the training data well.
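To make the idea of weighted connections concrete, here is a minimal base-R sketch of a single forward pass through a network with four inputs, one hidden layer of five neurons and three outputs. The weights are random placeholders (a trained network would have learned them), and the ReLU hidden activation and softmax output are assumptions chosen to match common practice; the helper names relu(), softmax() and dense() are hypothetical.

```r
# Hypothetical weights: random placeholders, in a real network they are learned
set.seed(1)
relu <- function(z) pmax(z, 0)
softmax <- function(z) {
  e <- exp(z - max(z))  # subtract the maximum for numerical stability
  e / sum(e)
}
# One dense layer: weighted sum of the previous layer's values, plus bias, then activation
dense <- function(x, W, b, activation) activation(W %*% x + b)

x  <- c(0.2, 0.7, 0.1, 0.9)            # 4 input features
W1 <- matrix(rnorm(5 * 4), nrow = 5)   # 5 hidden neurons x 4 inputs
b1 <- rnorm(5)
W2 <- matrix(rnorm(3 * 5), nrow = 3)   # 3 outputs x 5 hidden neurons
b2 <- rnorm(3)

h <- dense(x, W1, b1, relu)            # hidden layer activations
y <- drop(dense(h, W2, b2, softmax))   # output layer: one value per class
print(round(y, 3))
```

Training would consist of adjusting W1, b1, W2 and b2 so that y is close to the observed classes.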

Fig. 1: A visualization of a deep neural network, created with https://playground.tensorflow.org/

Building DNNs with Keras in R

So, how does one build this kind of model in R? A particularly convenient way is the Keras implementation for R, available since September 2017. Keras is essentially a high-level wrapper that makes other machine learning frameworks more convenient to use; TensorFlow, Theano, or CNTK can serve as the backend. As a result, we can create an ANN with n hidden layers in a few lines of code.

As an example, here is a deep neural network fitted to the iris data set (three iris species, each with 50 samples described by four features). We scale the input variables to the range [0, 1] and “one hot” encode the response variable (i.e. create one dummy variable per species). The output layer has three nodes, one for each class, and uses the softmax activation function, so that each output lies between 0 and 1 and the outputs sum to 1. To evaluate model quality, Keras splits the data into a training and a validation set.
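A minimal sketch of such a model (not the author's original code), assuming the keras R package with a TensorFlow backend is installed. The two hidden layers of 20 units, the Adam optimizer, 100 epochs, a batch size of 10 and a 20% validation split are illustrative choices, not values taken from the post. The rows are shuffled first because Keras takes the validation set from the last rows of the data, and iris is sorted by species.

```r
library(keras)  # assumption: keras and a TensorFlow backend are installed

# Predictors: scale each of the four features to the range [0, 1]
x <- as.matrix(iris[, 1:4])
x <- apply(x, 2, function(v) (v - min(v)) / (max(v) - min(v)))

# Response: one-hot encode the three species (coded 0, 1, 2)
y <- to_categorical(as.integer(iris$Species) - 1, num_classes = 3)

# Shuffle rows: validation_split uses the last 20% of rows, and iris is sorted
set.seed(42)
idx <- sample(nrow(x))
x <- x[idx, ]
y <- y[idx, ]

# Illustrative architecture: two hidden layers, softmax output with one node per class
model <- keras_model_sequential() %>%
  layer_dense(units = 20, activation = "relu", input_shape = c(4)) %>%
  layer_dense(units = 20, activation = "relu") %>%
  layer_dense(units = 3, activation = "softmax")

model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = optimizer_adam(),
  metrics = "accuracy"
)

history <- model %>% fit(
  x, y,
  epochs = 100, batch_size = 10,
  validation_split = 0.2,  # hold out 20% of the rows for validation
  verbose = 0
)
```

Calling plot(history) afterwards draws training and validation curves of the kind shown in Fig. 2.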

Fig. 2: Accuracy of the model on the training and validation data.

DNN with dropout

A common concern with networks of this type is overfitting (the error on test data is considerably higher than the training error). We want our model to generalize well (low test error). Common regularization methods include weight penalties (e.g. L1, L2), early stopping, and weight decay.
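Weight penalties simply add a term to the loss that grows with the size of the weights, so large weights are only kept if they substantially improve the fit. A base-R sketch of the L1 and L2 penalties; the weights, the data loss of 0.3 and the penalty strength lambda = 0.01 are hypothetical values for illustration.

```r
# Hypothetical weights and data loss, for illustration only
weights <- list(W1 = matrix(c(0.5, -1, 2, 0), nrow = 2), W2 = c(1, -1))
data_loss <- 0.3

# L1 ("lasso") penalty: lambda times the sum of absolute weights
l1_penalty <- function(weights, lambda) lambda * sum(abs(unlist(weights)))
# L2 ("ridge") penalty: lambda times the sum of squared weights
l2_penalty <- function(weights, lambda) lambda * sum(unlist(weights)^2)

loss_l1 <- data_loss + l1_penalty(weights, lambda = 0.01)  # 0.3 + 0.01 * 5.5
loss_l2 <- data_loss + l2_penalty(weights, lambda = 0.01)  # 0.3 + 0.01 * 7.25
c(L1 = loss_l1, L2 = loss_l2)
```

In the keras R package, an L2 penalty can be attached to a layer with kernel_regularizer = regularizer_l2(l = 0.01) in layer_dense(), and early stopping is available as callback_early_stopping(monitor = "val_loss", patience = 10), passed to fit() via its callbacks argument.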

The dropout method is a simple and efficient way to regularize our model. Dropout means that nodes and their connections are randomly dropped with probability p during training. In this way, an ensemble of thinned subnetworks is trained, and their predictions are averaged when the model is used for prediction (see Srivastava et al., 2014 for a detailed explanation).
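The mechanism can be illustrated in base R. The sketch below implements inverted dropout, the variant used by the Keras dropout layer: during training, each activation is set to zero with probability p and the surviving activations are scaled by 1/(1 - p), so the expected activation is unchanged and the full network can be used for prediction without rescaling. The dropout() helper and the constant activation vector are hypothetical.

```r
# Inverted dropout (hypothetical helper): zero each activation with probability p
# during training and scale the survivors by 1 / (1 - p)
dropout <- function(a, p, training = TRUE) {
  if (!training || p == 0) return(a)  # no dropout at prediction time
  keep <- runif(length(a)) >= p       # each unit survives with probability 1 - p
  a * keep / (1 - p)
}

set.seed(1)
a <- rep(1, 10000)             # hypothetical activations, all equal to 1
a_drop <- dropout(a, p = 0.5)
mean(a_drop == 0)              # close to 0.5: about half of the units are dropped
mean(a_drop)                   # close to 1: the expected activation is preserved
```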

Fig. 3: Accuracy of the model with dropout on the training and validation data. The training accuracy is lower because Keras computes it with dropout active, whereas the validation accuracy is computed with the full network; in this simple example, the final accuracy is identical to the previous case. Overfitting becomes more of a concern with larger datasets and more predictors.
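In Keras, dropout is added as a separate layer after the layer whose outputs should be dropped. A sketch of an iris network with dropout (not the author's original code), again assuming the keras R package with a TensorFlow backend; the dropout rate of 0.2 and the layer sizes are illustrative assumptions.

```r
library(keras)  # assumption: keras and a TensorFlow backend are installed

model_dropout <- keras_model_sequential() %>%
  layer_dense(units = 20, activation = "relu", input_shape = c(4)) %>%
  layer_dropout(rate = 0.2) %>%  # drop 20% of the previous layer's outputs in training
  layer_dense(units = 20, activation = "relu") %>%
  layer_dropout(rate = 0.2) %>%
  layer_dense(units = 3, activation = "softmax")  # one output node per species

model_dropout %>% compile(
  loss = "categorical_crossentropy",
  optimizer = optimizer_adam(),
  metrics = "accuracy"
)
```

The model is trained with fit() exactly like a network without dropout; Keras switches dropout off automatically when predicting and when computing validation metrics.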

Concluding remarks

There is no general rule for choosing the network architecture (depth and width of the layers). In general, optimization becomes harder as the network gets deeper. Network parameters can be tuned, but beware of overfitting (e.g. evaluate the tuned model in an outer cross-validation).
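A minimal base-R sketch of nested cross-validation: the inner loop chooses a tuning parameter using only the training part of each outer fold, and the outer loop estimates how well the whole tuning procedure generalizes. To keep it fast and self-contained, a polynomial regression on iris stands in for the network (an assumption for illustration), with the polynomial degree playing the role of a network parameter such as depth or width; the helpers folds() and cv_error() are hypothetical.

```r
set.seed(1)
dat <- iris[, c("Petal.Length", "Petal.Width")]

# Randomly assign n rows to k folds of (nearly) equal size
folds <- function(n, k) sample(rep(seq_len(k), length.out = n))

# Inner loop: k-fold CV error of a given degree, using the training data only
cv_error <- function(train, degree, k = 5) {
  f <- folds(nrow(train), k)
  mean(sapply(seq_len(k), function(i) {
    fit <- lm(Petal.Width ~ poly(Petal.Length, degree), data = train[f != i, ])
    pred <- predict(fit, newdata = train[f == i, ])
    mean((train$Petal.Width[f == i] - pred)^2)
  }))
}

degrees <- 1:4
outer_fold <- folds(nrow(dat), 5)
outer_results <- sapply(1:5, function(i) {
  train <- dat[outer_fold != i, ]
  test  <- dat[outer_fold == i, ]
  # tune on the training part only, never on the outer test fold
  best <- degrees[which.min(sapply(degrees, function(d) cv_error(train, d)))]
  fit <- lm(Petal.Width ~ poly(Petal.Length, best), data = train)
  c(degree = best, mse = mean((test$Petal.Width - predict(fit, newdata = test))^2))
})
outer_results
mean(outer_results["mse", ])  # estimated error of the tuning procedure as a whole
```

The chosen degree may differ between outer folds; this is expected, since the outer loop evaluates the procedure rather than one particular model.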

So, what have we gained? In this case, we have applied the methods to a very simple example only, so benefits are limited. In general, however, DNNs are particularly useful where we have large datasets, and complex dependencies that cannot be fit with simpler, traditional statistical models.

The disadvantage is that we end up with a “black box model” that can predict, but is hard to interpret for inference. This has often been named as one of the main problems of machine learning, and there is much research on new frameworks to address it (e.g. DALEX, lime; see also Staniak & Biecek, 2018).
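One of the simplest model-agnostic ideas behind such tools (DALEX, for example, offers permutation-based variable importance) is to shuffle one predictor and measure how much the prediction error increases. A base-R sketch, in which a linear model on iris stands in for a neural network (an assumption for illustration; any model with a predict() method could be plugged in) and the helper mse() is hypothetical.

```r
set.seed(1)
# Stand-in model; a fitted network could be used instead
fit <- lm(Petal.Width ~ Sepal.Length + Sepal.Width + Petal.Length, data = iris)
mse <- function(model, data) mean((data$Petal.Width - predict(model, newdata = data))^2)

base_error <- mse(fit, iris)
features <- c("Sepal.Length", "Sepal.Width", "Petal.Length")

# Importance = increase in error after randomly permuting one predictor
importance <- sapply(features, function(v) {
  shuffled <- iris
  shuffled[[v]] <- sample(shuffled[[v]])
  mse(fit, shuffled) - base_error
})
sort(importance, decreasing = TRUE)
```

For brevity the error is computed on the training data; in practice, held-out data gives a more honest picture of which predictors the model relies on.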

Useful links:

RStudio TensorFlow

RStudio Keras documentation
