Food web prediction is non-trivial

Feature creation, engineering, and neural networks

In the previous entry, I wrote up a very bad example of how easy it is to do a bad job of predicting species interactions, and to be misled by performance measures into thinking we have done it right. In this entry, I will use Flux.jl to train a neural network to predict a metaweb, based on the same data. By the end of the post, we will have a much better model, not because it achieves an impressive accuracy, but because we understand it and can decide whether we should trust it. This is going to be a long one (both in reading and in run time!).

As a side note, running the code in this entry might take a little while. I have been running it on my venerably aged laptop, which doesn’t help. I have also not made any effort to make it fast, because this is intended as a discussion of how we can use ecology to inform some of the choices we make, and not as something I intend to use directly. So, caveat emptor: you are about to peek into work-in-progress code that is not idiomatic, not optimized, and just plain not good. But it gets the job done; it just needs more time to do it.

Setting up the environment

We will need a few things to make this project work. There is going to be the usual series of plotting packages; we will need the packages for network manipulation; we will also need a few statistics packages, and finally Flux.jl to do the actual learning.

using Plots
using EcologicalNetworks
using Statistics, LinearAlgebra, Random, StatsBase
using MultivariateStats
using Flux

Random.seed!(13711)

Once this is all installed and loaded, we can start.

Getting the data

Species interactions

We will once again look at the data from the rodent-ectoparasite networks in Eurasia, and aggregate them into a metaweb. One thing we know about a metaweb is that it is incomplete, because some (host, parasite) pairs have never been observed. Let’s keep this in mind.

ids = map(i -> i.ID, filter(i -> contains(i.Reference, "Hadfield"), web_of_life()))
B = convert.(BipartiteNetwork, web_of_life.(ids))
M = reduce(∪, B)
206×121 bipartite  ecological network (Bool, String) (L: 2131)

In preparation for the next step, and specifically to avoid making the results of the next steps “all fucky” (it’s a technical term, look it up), we will convert the bipartite network into a unipartite network, and also mirror it; this implies that an interaction i → j will also result in an interaction j → i. I will not go into the specifics of why this is important, but it just makes the matrix a little less sparse.

N = convert(UnipartiteNetwork, M)
K = EcologicalNetworks.mirror(N)
327×327 unipartite  ecological network (Bool, String) (L: 4262)
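If the mirroring step feels opaque, here is a minimal sketch in plain Julia of what I understand the two conversions to do, on a made-up 4 × 3 host-parasite matrix. It assumes (this is my reading, not the EcologicalNetworks.jl source) that the unipartite version stacks hosts and then parasites, with the interactions in the upper-right block, and that mirroring is the union of the matrix and its transpose:

```julia
using LinearAlgebra: issymmetric

# Made-up host × parasite incidence matrix (4 hosts, 3 parasites)
Bm = Bool[1 0 1; 0 1 0; 1 1 0; 0 0 1]
nt, nb = size(Bm)

# Assumed unipartite layout: hosts first, then parasites; host → parasite in the upper-right block
A = zeros(Bool, nt + nb, nt + nb)
A[1:nt, (nt+1):end] .= Bm

# Mirroring: every interaction i → j also gives j → i
Amirror = A .| permutedims(A)
symmetric = issymmetric(Amirror)
```

The number of interactions doubles (here from 6 to 12), which is the same doubling as the jump from 2131 to 4262 in the output above.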

We have a dataset K to use in extracting some features, and the actual dataset M on which we would like to make some predictions. To make the predictions, we will extract some features from K.

Pseudo-traits we can make up because reality doesn’t matter

In the interest of speed and familiarity, we will do so using a Principal Component Analysis, and we will assume that the values of a species on each PC are akin to traits: the top n PCs will be the most important “traits” for this species. One of the advantages of this technique is that we will have traits in the same space for both hosts and parasites, but of course in real life we would explore better methods of projecting the graph into a new space.

# Columns of K.A are treated as observations (species), rows as variables
pc = fit(PCA, Float64.(K.A))
# Row k of pr holds the position of every species on PC k
pr = MultivariateStats.transform(pc, Float64.(K.A))
plot(
  scatter(pr[1,:], pr[2,:], frame=:zerolines, lab="", msw=0.0, c=:purple, alpha=0.5, xlim=(-1,1), ylim=(-1,1), aspectratio=1),
  scatter(pr[2,:], pr[3,:], frame=:zerolines, lab="", msw=0.0, c=:purple, alpha=0.5, xlim=(-1,1), ylim=(-1,1), aspectratio=1)
  )

This figure shows the projection of the 300+ species in the new space defined by the PCA (axis 1 and 2, and 2 and 3) - the first three axes do not explain a lot of variance, but we will use more than that anyways.
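For readers who want to see what the PCA step does under the hood, here is a sketch that computes principal component scores with an SVD instead of MultivariateStats, on a made-up symmetric matrix for 5 species. It follows the same orientation as the code above (variables in rows, species in columns, row k of the result is PC k). MultivariateStats may also drop trailing components based on its pratio setting, which this sketch does not do:

```julia
using LinearAlgebra, Statistics

# Made-up symmetric adjacency matrix for 5 species; columns are species
A = Float64.(Bool[0 0 1 1 0;
                  0 0 0 1 1;
                  1 0 0 0 0;
                  1 1 0 0 0;
                  0 1 0 0 0])

Xc = A .- mean(A; dims=2)             # center every variable (row)
F = svd(Xc)                           # Xc = U * Diagonal(S) * Vt
scores = F.U' * Xc                    # row k = position of every species on PC k
explained = F.S .^ 2 ./ sum(F.S .^ 2) # share of variance carried by each PC

nf = 2
traits = scores[1:nf, :]              # pseudo-traits: the first nf PCs of each species
```

The explained vector is the quantity behind the remark that the first three axes do not explain a lot of variance.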

The main point of this step is that we can now define a feature vector for every interaction (or absence of interaction) that will contain the position of the two species in the PCA space. This is one way of (hopefully) capturing information about the species based on their interactions. One possible way to summarize interactions at this point would be to use the distance between the two species as a predictor, but this is already making a strong hypothesis, and we do not need to make it - instead, we will just throw a bunch of coordinates in the PCA space at our learner, and see what happens.
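To make the difference between the two options concrete, here is a sketch on a made-up set of pseudo-traits (3 PCs for 4 species, with hypothetical indices for a host and a parasite): the first feature vector is what the post feeds to the learner, the second is the single distance-based predictor it deliberately avoids:

```julia
using LinearAlgebra: norm

# Made-up pseudo-traits: nf PCs (rows) for 4 species (columns)
nf = 3
traits = [ 0.1 -0.4  0.3 0.0;
           0.2  0.1 -0.2 0.5;
          -0.3  0.0  0.1 0.2]
i, j = 1, 3   # hypothetical host and parasite indices

# What the post does: stack both positions, 2nf values, no assumption about how they relate
pair_features = vcat(traits[1:nf, i], traits[1:nf, j])

# The stronger hypothesis it avoids: only the distance between the two species matters
pair_distance = norm(traits[:, i] - traits[:, j])
```

The distance throws away the direction of the difference, so two very different pairs can end up with the same value; the stacked coordinates let the network learn that relationship, or a different one, from the data.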

Removing the data that we think we have but we actually don’t

As it stands, it looks like we have a relatively large number of data points, namely the 24926 (206 × 121) entries in the metaweb. This is, of course, not correct, as a lot of these entries are pseudo-absences: the two species were never observed in the same local network. If we were to just drop everything into a model, there would be a lot of things we think are zeroes that are actually things we have no idea about, and therefore actual targets for prediction. So we need to filter our dataset to keep only the pairs for which we actually have information.

One tricky part is that a lot of the remaining zeroes can indeed be false negatives, but this is life. I wish I had a better answer to that; the closest I will get comes in a minute, when we show that the trained model suggests a few of them. But this is the nature of ecological data.

Anyway, the following monstrosity creates an array with labels (true or false) for interactions, a matrix of features (the first nf positions of the two species in the PCA space), and an array with true or false for co-occurrence:

nf = 15
n_pairs = prod(size(M))
cooc = zeros(Bool, n_pairs)
labels = zeros(Bool, n_pairs)
features = zeros(Float64, (2*nf, n_pairs))
cursor = 0
# Hosts in the outer loop, parasites in the inner loop
for i in species(M; dims=1)
    for j in species(M; dims=2)
        global cursor += 1
        # Interaction in the metaweb?
        labels[cursor] = M[i,j]
        # Positions of both species in the PCA space (columns of pr)
        p_i = findfirst(isequal(i), species(N))
        p_j = findfirst(isequal(j), species(N))
        features[1:nf, cursor] .= pr[1:nf, p_i]
        features[(nf+1):end, cursor] .= pr[1:nf, p_j]
        # Co-occurrence: do both species appear in at least one local network?
        cooc[cursor] = any(b -> (i in species(b)) && (j in species(b)), B)
    end
end
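The co-occurrence test inside that loop boils down to a single rule: a pair co-occurs if at least one local network contains both species. Here is a self-contained sketch of that rule on three made-up local networks, each reduced to a Set of species names (hypothetical stand-ins for the networks in B), producing the vector in the same host-outer, parasite-inner order as the loop:

```julia
# Made-up local networks, each reduced to the set of species observed in it
local_sites = [Set(["h1", "h2", "p1"]), Set(["h2", "p2"]), Set(["h3", "p1", "p3"])]
hosts = ["h1", "h2", "h3"]
parasites = ["p1", "p2", "p3"]

# A pair co-occurs if at least one local network contains both species
cooccurs(i, j, sites) = any(s -> (i in s) && (j in s), sites)

# Same ordering as the feature loop: hosts in the outer loop, parasites in the inner loop
cooc = [cooccurs(i, j, local_sites) for i in hosts for j in parasites]
```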

At the end of this step, we have the complete feature vectors (we will use them for prediction later), the labels (which contain the observed interactions as well as the potential false negatives due to no co-occurrence), and the vector with co-occurrence for every entry in the metaweb. There are 6920 co-occurring pairs in total, 2131 of which interact. The actual prevalence is therefore 0.31 (2131/6920), as opposed to the 0.09 (2131/24926) we would have inferred had we not removed the non-co-occurring pairs. This is a substantial difference, as our dataset is a lot more balanced!
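The prevalence arithmetic is simple enough to write down; here is a sketch with a hypothetical prevalence helper, checked on a made-up toy vector and then applied to the counts reported in the post:

```julia
# Prevalence = interactions / pairs we can actually say something about (hypothetical helper)
prevalence(labels, mask) = count(labels .& mask) / count(mask)

# Toy check: 6 pairs, 4 co-occur, 2 interactions (interactions only happen between co-occurring pairs)
labels = [true, false, false, true, false, false]
cooc   = [true, true, false, true, true, false]
p_kept = prevalence(labels, cooc)                   # 2 / 4
p_all  = prevalence(labels, trues(length(labels)))  # 2 / 6

# The same arithmetic with the counts from the post
p_metaweb = 2131 / 24926   # all 206 × 121 entries
p_cooc    = 2131 / 6920    # co-occurring pairs only
```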

Putting it all together

All that is left to do before we move on to the actual training is to remove all pairs that we cannot discuss because they do not co-occur, and then massage the data into a shape that Flux will like:

kept = findall(cooc)
x = Float32.(copy(features[:, kept]))
y = Flux.onehotbatch(labels[kept], [false, true])
2×6920 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 0  0  1  0  0  1  0  0  1  1  0  1  1  …  1  1  1  1  1  1  1  0  0  1  1  1
 1  1  0  1  1  0  1  1  0  0  1  0  0     0  0  0  0  0  0  0  1  1  0  0  0

Note that we are converting the features to Float32: there is no need to keep full precision here, Flux layers use Float32 weights by default, and this halves the memory use and makes the code run a bit faster.
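To see what onehotbatch and onecold do with the class order [false, true], here’s a dependency-free sketch (plain comprehensions standing in for the Flux functions, not their implementation). Row 1 flags false and row 2 flags true, which is why a column reading 0 over 1 in the printed matrix above is an interaction. It also checks the memory claim about Float32:

```julia
# Labels for a handful of hypothetical pairs
labels = [true, false, false, true, true]
classes = [false, true]                       # same class order as in the post

# One row per class, one column per pair: row 1 is false, row 2 is true
Y = [lab == c for c in classes, lab in labels]

# Decoding picks the row holding the 1 in each column, like onecold
decoded = [classes[argmax(col)] for col in eachcol(Y)]

# Float32 storage takes half the bytes of Float64
x64 = rand(30, 6920)
x32 = Float32.(x64)
```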

And we are done with part 1. Showtime!

Preparing the neural network

In this section, we will create a few utility functions, define the neural network structure, and split the data into a training and evaluation set.

A few helper functions

To really get a sense of what is going on in the model, we will mostly be looking at the confusion matrix, which is a way of representing the true/false positive/negative entries. There are more elegant ways to do this, but this version is serviceable, so it will do:

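function confusion_matrix(model, f, l)
    # Predicted and observed classes as Bool vectors
    pred = Flux.onecold(model(f), [false, true])
    obs = Flux.onecold(l, [false, true])
    # Rows are predictions, columns are observations: [TP FP; FN TN]
    M = zeros(Int64, (2,2))
    M[1,1] = sum(pred .* obs)        # true positives
    M[2,2] = sum(.!pred .* .!obs)    # true negatives
    M[1,2] = sum(pred .> obs)        # false positives
    M[2,1] = sum(pred .< obs)        # false negatives
    return M
end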
confusion_matrix (generic function with 1 method)
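To check the layout without training anything, here is a sketch of the same computation as a hypothetical helper that takes Bool vectors directly, on five made-up predictions:

```julia
# Same layout as confusion_matrix, [TP FP; FN TN], but from Bool vectors (hypothetical helper)
function confusion_from_bools(pred::AbstractVector{Bool}, obs::AbstractVector{Bool})
    C = zeros(Int, 2, 2)
    C[1, 1] = count(pred .& obs)      # true positives
    C[1, 2] = count(pred .& .!obs)    # false positives
    C[2, 1] = count(.!pred .& obs)    # false negatives
    C[2, 2] = count(.!pred .& .!obs)  # true negatives
    return C
end

pred = [true, true, false, false, true]
obs  = [true, false, true, false, true]
C = confusion_from_bools(pred, obs)
```

For Bool values, pred .> obs is only true when pred is true and obs is false, which is why it counts false positives in the original function.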

Based on the confusion matrix, we will want to look at four key measures of performance. Accuracy is the overall fraction of correct predictions. Sensitivity is the ability to classify positives, and specificity is the ability to classify negatives. Finally, the true skill statistic (TSS) is a measure of overall predictive accuracy for which 0 means no better than random, -1 means always wrong, and 1 means always right.

accuracy = (M) -> sum(diag(M)) / sum(M)
sensitivity = (M) -> M[1,1]/(M[1,1]+M[2,1])
specificity = (M) -> M[2,2]/(M[1,2]+M[2,2])
tss = (M) -> (M[1,1]*M[2,2]-M[1,2]*M[2,1])/((M[1,1]+M[2,1])*(M[1,2]+M[2,2]))
#11 (generic function with 1 method)
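The tss one-liner looks cryptic, but it is sensitivity plus specificity minus one, put over a common denominator. With the layout used in confusion_matrix (TP = M[1,1], FP = M[1,2], FN = M[2,1], TN = M[2,2]):

$$
\mathrm{TSS} = \frac{TP}{TP+FN} + \frac{TN}{FP+TN} - 1
= \frac{TP\,(FP+TN) + TN\,(TP+FN) - (TP+FN)(FP+TN)}{(TP+FN)(FP+TN)}
= \frac{TP \cdot TN - FP \cdot FN}{(TP+FN)(FP+TN)}
$$

The first fraction is sensitivity and the second is specificity, which is why a model that always predicts the same class gets a TSS of 0.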

We have these four helper functions, so we are ready to start doing the actual work. Note that these functions don’t play any role in the training - they will be used to check how our model is learning.

Training and testing sets

We will use 70% of the data for training, and 30% for validation. This leads to a relatively small sample size, but this is all we have, and so this is what we will use.

training_size = floor(Int64, size(x, 2)*0.7)
train = sort(sample(1:size(x,2), training_size, replace=false))
# Everything that is not in the training set goes to validation
test = setdiff(1:size(x,2), train)

data = (x[:,train], y[:,train])
data_test = (x[:,test], y[:,test])
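Here is a dependency-free sketch of the same split that checks the two sets are disjoint and cover everything. It uses randperm from the Random standard library instead of StatsBase’s sample, which is a swap for the sake of the example:

```julia
using Random

n = 6920                                     # number of co-occurring pairs in the post
Random.seed!(1)
training_size = floor(Int, 0.7n)
train = sort(randperm(n)[1:training_size])   # random subset, drawn without replacement
test = setdiff(1:n, train)                   # everything left over, in order
```

With the post’s numbers this gives 4844 training pairs and 2076 validation pairs; 2076 is also the total of the validation confusion matrix shown further down.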

The neural network itself

Writing code with Flux is a joy, and the actual network looks like this:

m = Chain(
    Dense(2nf, nf, relu),
    Dropout(0.35),
    Dense(nf, 8, σ),
    Dropout(0.35),
    Dense(8, 2),
    softmax
)
Chain(Dense(30, 15, relu), Dropout(0.35), Dense(15, 8, σ), Dropout(0.35), Dense(8, 2), softmax)

Let’s go through the steps: we will use three densely connected layers. The input dimension is 2nf, because every species has nf features, and an interaction involves 2 species. The first layer reduces this to nf features and uses a rectified linear unit (ReLU) activation function. The second layer uses a sigmoid activation and brings the data from nf to 8 dimensions; the third layer has no activation function and goes from 8 to 2 dimensions (one for false, one for true). Finally, softmax turns these two outputs into probabilities that sum to one.
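To make the shapes concrete, here is a sketch of a single forward pass through the same architecture in plain Julia, with random weights standing in for a trained model and local stand-ins for relu, σ and softmax (hypothetical helpers, not Flux’s). It relies on Flux’s Dense(in, out, f) computing f.(W * x .+ b), and leaves out the Dropout layers, which as far as I know are inactive outside of training:

```julia
using Random

relu(z) = max(z, zero(z))
sigmoid(z) = 1 / (1 + exp(-z))
softmax_cols(Z) = (E = exp.(Z .- maximum(Z; dims=1)); E ./ sum(E; dims=1))

nf = 15
Random.seed!(3)
# Random weights and zero biases standing in for a trained model
W1, b1 = randn(Float32, nf, 2nf), zeros(Float32, nf)
W2, b2 = randn(Float32, 8, nf), zeros(Float32, 8)
W3, b3 = randn(Float32, 2, 8), zeros(Float32, 2)

X = randn(Float32, 2nf, 5)        # 5 hypothetical pairs, one per column
H1 = relu.(W1 * X .+ b1)          # 2nf → nf, ReLU
H2 = sigmoid.(W2 * H1 .+ b2)      # nf → 8, sigmoid
Z = W3 * H2 .+ b3                 # 8 → 2, no activation
P = softmax_cols(Z)               # rows are (false, true); each column sums to 1

nparams = sum(length, (W1, b1, W2, b2, W3, b3))
```

The whole network has 611 trainable parameters, which is small compared to the 4844 training pairs.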

What about the dropout layers? During training, each Dropout(0.35) layer randomly sets 35% of the outputs of the previous layer to zero at every update, so the network never relies on all of its connections at once. Using dropout is a good way to limit overfitting. Learning will be slower, but the model should generalize better.
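Flux’s Dropout is, as far as I know, “inverted” dropout: the surviving values are scaled by 1/(1 - p) so that their expected value does not change. Here is a minimal sketch of that idea with a hypothetical dropout_sketch function (not Flux’s implementation), checking both the fraction of zeroes and the mean:

```julia
using Random, Statistics

# Inverted dropout sketch (hypothetical helper, not Flux's implementation)
function dropout_sketch(x, p)
    keep = rand(size(x)...) .>= p   # each value survives with probability 1 - p
    return x .* keep ./ (1 - p)     # rescale survivors so the expected value is unchanged
end

Random.seed!(7)
x = ones(1000, 100)
y = dropout_sketch(x, 0.35)
dropped = count(iszero, y) / length(y)
```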

To finish, we pick a loss function (logit cross-entropy), extract all of the parameters (Flux lets you update only some of them, which is useful when working on a pre-trained network), and pick an optimizer (ADAM with its default learning rate of 0.001).

loss(x, y) = Flux.logitcrossentropy(m(x), y)
ps = Flux.params(m)
opt = ADAM()
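One thing worth knowing about this pairing: logitcrossentropy expects raw scores and applies its own softmax, so when the chain already ends with softmax, the softmax ends up applied twice. Here is a dependency-free sketch of the consequence, with hypothetical stand-ins for the two losses (not Flux’s code) on two made-up samples:

```julia
# Dependency-free stand-ins (hypothetical helpers, not Flux's own code)
softmax_cols(z) = (e = exp.(z .- maximum(z; dims=1)); e ./ sum(e; dims=1))
xent(p, y) = -sum(y .* log.(p)) / size(y, 2)   # expects probabilities
logit_xent(z, y) = xent(softmax_cols(z), y)    # expects raw scores (logits)

z = [2.0 -1.0; 0.0 1.5]   # raw scores: 2 classes × 2 samples
y = [1.0 0.0; 0.0 1.0]    # one-hot labels
p = softmax_cols(z)       # what a final softmax layer outputs

l_right  = logit_xent(z, y)   # logits into the logit loss
l_same   = xent(p, y)         # probabilities into the plain cross-entropy
l_double = logit_xent(p, y)   # probabilities into the logit loss: softmax applied twice

# With 2 classes and inputs in [0, 1], the second softmax can never push the
# probability of the right class above e / (1 + e), which puts a floor on the loss
floor_loss = -log(exp(1) / (1 + exp(1)))
```

The conventional pairings are either dropping the final softmax and keeping logitcrossentropy, or keeping softmax and using plain cross-entropy. The model trained here still learns, but its loss cannot go below about 0.31, which is worth keeping in mind when reading the loss values reported further down.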

Let’s roll!

Training the neural network

To save time during the training, we will use batches of 32 metaweb entries. Why 32? No idea. It just looks nice. It’s faster than 64, and leads to less instability than 16, and numbers that are not prime or multiples of 8 are a disgrace anyways, so 32 is the only reasonable choice.

So, our training routine is going to look as follows: pick 32 metaweb entries, update the model, and record the confusion matrices on the training and validation sets. This is the part that might take time (in practice, this would be done using Flux’s batching system, and would also take place on the GPU, but remember, the most important thing is to have fun).

We will store the confusion matrices in a 3D array, where the slice [:,:,i] will be the matrix after the ith batch.

n_batches, batch_size = 25000, 32
matrices_train = zeros(Int64, (2,2,n_batches))
matrices_test = zeros(Int64, (2,2,n_batches))

We can get to work. Beware, this might be the long part (it took about 10 minutes on a single core of my laptop, but we’re talking i3 cores clocking in at an impressive 1.7 GHz, so hopefully you will have the good sense to do this on a less craptacular machine). The run time comes from the fact that we are using dropout (so we need more cycles to train the full network), and that we use small batches (so we need more cycles to go through the entire dataset multiple times).

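for i in 1:n_batches
    # Draw a random batch of training pairs, without replacement within the batch
    ord = sample(train, batch_size, replace=false)
    data_batch = (x[:,ord], y[:,ord])
    # One gradient update on this batch
    Flux.train!(loss, ps, [data_batch], opt)
    # Record performance on the full validation and training sets
    matrices_test[:,:,i] = confusion_matrix(m, data_test...)
    matrices_train[:,:,i] = confusion_matrix(m, data...)
end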
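The loop draws each batch independently, so over 25000 batches some pairs come up more often than others. Here’s a sketch of the epoch-based alternative, in plain Julia with Iterators.partition (standing in for what a batching utility such as Flux’s DataLoader would do; this is not its code), along with the arithmetic on how many passes over the training set the loop amounts to:

```julia
using Random

# Sizes from the post: 4844 training pairs (70% of 6920, rounded down), batches of 32
train = collect(1:4844)
batch_size, n_batches = 32, 25000

# Epoch-based alternative: shuffle once per epoch, then cut into consecutive batches,
# so every training pair is seen exactly once per epoch
Random.seed!(5)
epoch = collect(Iterators.partition(shuffle(train), batch_size))

# Average number of times the loop in the post sees each training pair
passes = n_batches * batch_size / length(train)
```

One epoch is 152 batches (the last one holding the 12 leftover pairs), and the loop in the post amounts to about 165 passes over the training set.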

At the end of this process, our model is trained. But is it trained well? We’re about to find out.

Assessing the network behavior

In this section, before we can move on to prediction, we will pay a bit more attention to some measures of performance of our model. There is a whole laundry list of measures one can look at, but we will focus on accuracy, specificity, sensitivity, and TSS.

Do we overfit?

One good way to see if the model is overfitting, i.e. learning too much about the training set, including its little quirks, is to compare the loss on the training and validation sets over the course of training. We can do the same with accuracy: if the model is overfitting, accuracy on the training data will keep increasing, while accuracy on the validation data will plateau. Neural networks (and deep neural networks in particular) are prone to overfitting, so this is very important to check.

plot(
    vec(mapslices(accuracy, matrices_train, dims=[1,2])),
    lab = "Training", ylab="Accuracy", c=:orange, 
    legend=:bottomright
    )
plot!(
    vec(mapslices(accuracy, matrices_test, dims=[1,2])),
    lab="Validation", c=:teal, lw=2.0
    )
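The accuracy curves are just accuracy applied slice by slice to the 3D arrays. Here’s a sketch with two made-up slices per set, computing the same per-batch accuracy with a comprehension (equivalent to the mapslices call) and the gap between training and validation:

```julia
using LinearAlgebra: diag

accuracy(C) = sum(diag(C)) / sum(C)

# Made-up stacks of confusion matrices, one 2×2 slice per batch
train_stack = cat([40 5; 10 45], [45 3; 5 47]; dims=3)
test_stack  = cat([38 6; 12 44], [40 5; 9 46]; dims=3)

acc_train = [accuracy(train_stack[:, :, i]) for i in axes(train_stack, 3)]
acc_test  = [accuracy(test_stack[:, :, i]) for i in axes(test_stack, 3)]

# A gap that keeps growing over training is the overfitting signal discussed below
gap = acc_train .- acc_test
```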

Accuracy sometimes takes a little while to start changing, and this is due to a series of factors (ReLU units can “die off”, the dropout doesn’t help, maybe the batches are a bit too full of zeroes, etc.), but it then increases sharply. Remember that accuracy is very simple to interpret, as it is the proportion of correct predictions: we want this to go towards one.

It is worth paying attention to the shape of this curve: in the early phases, the accuracies on both datasets change in the same way, but then the accuracy on the validation set increases more slowly than that on the training set. This is a sign that we are starting to overfit, as the model is getting better at predicting data it has already seen, but not much better at predicting data it has never seen. Note that the difference in accuracy is not that large: it is about 2%. We can also compare the loss on both datasets:

loss(data...)
0.45606053f0
loss(data_test...)
0.47878224f0

The loss on the validation data has gotten larger than the loss on the training data, which is a Not Good thing. But we can also look at the graph, and decide that for the purpose of writing a blog post, we have not overfitted too much, and this is usable as is.

Specificity and sensitivity

Accuracy is one thing, but networks are notoriously full of zeroes, so we might want to check how the model works on zeroes and ones. We will first look at specificity, i.e. the ability to tell that a negative is indeed negative:

plot(
    vec(mapslices(specificity, matrices_train, dims=[1,2])),
    lab = "Training", ylab="Specificity", c=:orange, 
    legend=:bottomleft
    )
plot!(
    vec(mapslices(specificity, matrices_test, dims=[1,2])),
    lab="Validation", c=:teal, lw=2.0
    )

This is very close to 1, which means that the model is doing very well on the identification of “no interactions”. The next task, more challenging of course, is to look at the sensitivity, which does the same thing but with positives:

plot(
    vec(mapslices(sensitivity, matrices_train, dims=[1,2])),
    lab = "Training", ylab="Sensitivity", c=:orange, 
    legend=:bottomright
    )
plot!(
    vec(mapslices(sensitivity, matrices_test, dims=[1,2])),
    lab="Validation", c=:teal, lw=2.0
    )

It’s a little less glorious: sensitivity stays well below specificity (the final validation confusion matrix, shown further down, works out to 384/(384+275), or about 0.58). This is not great, but remember that this is a difficult problem, with a low volume of data, and that the predictive variables we use are not optimal. Even though the sensitivity is not 0.9, I’m still very happy that it’s not 0! It’s also interesting to notice that our model gets better at identifying interactions in the training dataset faster than in the validation dataset, which is again a clue that we might be overfitting.

Does it predict?

Let’s finally look at an overall measure of predictive accuracy, the TSS:

plot(
    vec(mapslices(tss, matrices_train, dims=[1,2])),
    lab = "Training", ylab="True skill", c=:orange, 
    legend=:bottomright
    )
plot!(
    vec(mapslices(tss, matrices_test, dims=[1,2])),
    lab="Validation", c=:teal, lw=2.0
    )

We are clearly moving towards positive values, so the model is doing better than random. These are not impressive values, per se, but given all of the giant leaps of faith we had to take to get here, it’s also not as disappointing as it could have been.

Let’s look at the confusion matrix at the end of the training period:

matrices_test[:,:,end]
2×2 Array{Int64,2}:
 384    87
 275  1330

The most frequent error that the model makes is false negatives, i.e. it will tend to undercount the number of interactions. Depending on why we are making a prediction, this might be a good, or a bad thing.
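Plugging this matrix into the four helpers makes the numbers explicit; the functions below are copied from earlier in the post (renaming M to C), so the block runs on its own:

```julia
using LinearAlgebra: diag

# The four helpers from earlier in the post, with C = [TP FP; FN TN]
accuracy = (C) -> sum(diag(C)) / sum(C)
sensitivity = (C) -> C[1,1] / (C[1,1] + C[2,1])
specificity = (C) -> C[2,2] / (C[1,2] + C[2,2])
tss = (C) -> (C[1,1]*C[2,2] - C[1,2]*C[2,1]) / ((C[1,1] + C[2,1]) * (C[1,2] + C[2,2]))

# Validation confusion matrix at the end of training, as printed above
C = [384 87; 275 1330]

observed_positives = C[1,1] + C[2,1]    # interactions in the validation set
predicted_positives = C[1,1] + C[1,2]   # interactions predicted by the model
```

That works out to an accuracy of about 0.83, a sensitivity of about 0.58, a specificity of about 0.94, and a TSS of about 0.52 on the validation set. The model predicts 471 interactions where 659 were observed, which is the undercounting described above.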

Conclusion: let’s make a prediction!

At this point, we might decide to use the model for prediction. This is very easy to do!

predictions = Flux.onecold(m(Float32.(features)), [false, true])

This returns an array of Boolean values with the predicted interaction status for every pair of species. Because our model makes relatively few mistakes when predicting the absence of interactions, but tends to underestimate interactions, I would feel safe adding the interactions it predicts both between co-occurring species and between non-co-occurring ones.

Let’s do just that.

P = copy(M)
cursor = 0
for i in species(P; dims=1)
    for j in species(P; dims=2)
        global cursor += 1
        if !(P[i,j])
            P[i,j] = predictions[cursor]
        end
    end
end
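The merging loop relies on the predictions being in the same order as the features (hosts in the outer loop, parasites in the inner loop). Here’s a sketch of a loop-free version on a made-up 2 × 3 metaweb: because Julia fills matrices column by column, the row-major vector has to be reshaped with the dimensions swapped and then transposed. It also computes connectance, and checks the reported count against the 206 × 121 entries:

```julia
# Made-up 2 hosts × 3 parasites metaweb, and predictions for every pair
M = Bool[1 0 0; 0 1 0]
predictions = Bool[0, 1, 0, 0, 1, 1]   # host 1 × parasites 1..3, then host 2 × parasites 1..3
nh, np = size(M)

# The vector is in row-major order, but reshape fills column-major:
# reshape to (np, nh), then transpose to get back to hosts × parasites
Pred = permutedims(reshape(predictions, np, nh))
P = M .| Pred                          # keep observed interactions, add predicted ones
connectance = count(P) / length(P)

# The post's result: 2516 interactions over 206 × 121 possible pairs
reported = 2516 / (206 * 121)
```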

The new metaweb has 2516 interactions, with a connectance of 0.1.

Of course, this outcome leaves a lot of room for improvement. The structure of the neural network can be improved, the hyperparameters can be tweaked, and there are additional constraints we can apply to the batches we use for training. Still, this is a demonstration of how involved the process of learning a network can be; there are thankfully many ways to embed the network into a new subspace, and many ways to tweak the performance of the neural network (and no, we’re not giving away our cool secrets right now). But the take-home message is: with a surprisingly small amount of data, it is indeed possible to make guesses as to how species that have never been observed together might interact.

And I just think that’s neat.