[Non-linear effects](@id nonlinear)
import BSplineKit
using Unfold
using CairoMakie
using DataFrames
using Random
using Colors
using Missings
Generating a non-linear signal
We start by generating the data.
rng = MersenneTwister(2) # make repeatable
n = 20 # number of datapoints
evts = DataFrame(:x => rand(rng, n))
signal = -(3 * (evts.x .- 0.5)) .^ 2 .+ 0.5 .* rand(rng, n)
plot(evts.x, signal)

Looks perfectly non-linear. Great!
Compare linear & non-linear fit
First, we have to reshape the signal data into a 3D array so that it fits the Unfold format: 1 channel x 1 timepoint x 20 datapoints.
signal = reshape(signal, length(signal), 1, 1)
signal = permutedims(signal, [3, 2, 1])
size(signal)
(1, 1, 20)
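As a side note, the same layout can probably be reached in one step. The following is a minimal sketch using only Base Julia, with a random vector standing in for the simulated signal; it checks that letting reshape infer the last dimension gives the same array as the reshape plus permutedims combination above.

```julia
using Random

rng = MersenneTwister(2)
signal = rand(rng, 20)   # stand-in for the 20 simulated datapoints

# Two-step version from the text: shape (20, 1, 1), then swap dimensions 1 and 3
a = permutedims(reshape(signal, length(signal), 1, 1), [3, 2, 1])

# One-step alternative: `:` lets reshape infer the size of the last dimension
b = reshape(signal, 1, 1, :)

size(b)   # (1, 1, 20)
a == b    # true
```

Both versions keep the datapoints in their original order along the third dimension.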
Next, we define three different models: a linear one, one with 4 splines, and one with 10 splines. Note the difference in the formulas: one uses x, the others use spl(x, 4) and spl(x, 10).
design_linear = [Any => (@formula(0 ~ 1 + x), [0])];
design_spl3 = [Any => (@formula(0 ~ 1 + spl(x, 4)), [0])];
design_spl10 = [Any => (@formula(0 ~ 1 + spl(x, 10)), [0])];
Next, fit the parameters.
uf_linear = fit(UnfoldModel, design_linear, evts, signal);
uf_spl3 = fit(UnfoldModel, design_spl3, evts, signal);
uf_spl10 = fit(UnfoldModel, design_spl10, evts, signal);
Extract the fitted values using Unfold.effects.
p_linear = Unfold.effects(Dict(:x => range(0, stop = 1, length = 100)), uf_linear);
p_spl3 = Unfold.effects(Dict(:x => range(0, stop = 1, length = 100)), uf_spl3);
p_spl10 = Unfold.effects(Dict(:x => range(0, stop = 1, length = 100)), uf_spl10);
first(p_linear, 5)
Row | yhat (Float64) | channel (Int64) | x (Float64) | time (Int64) | eventname (DataType) |
---|---|---|---|---|---|
1 | 0.0328538 | 1 | 0.0 | 0 | Any |
2 | 0.0273313 | 1 | 0.010101 | 0 | Any |
3 | 0.0218088 | 1 | 0.020202 | 0 | Any |
4 | 0.0162863 | 1 | 0.030303 | 0 | Any |
5 | 0.0107638 | 1 | 0.040404 | 0 | Any |
Plot them.
pl = plot(evts.x, signal[1, 1, :])
lines!(p_linear.x, p_linear.yhat)
lines!(p_spl3.x, coalesce.(p_spl3.yhat, NaN))
lines!(p_spl10.x, coalesce.(p_spl10.yhat, NaN))
pl

Here we see that the linear effect (blue line) underfits the data and the yellow spl(x, 10) overfits it, whereas the green spl(x, 4) fits it well.
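Underfitting can also be quantified rather than judged by eye. The sketch below is not part of the Unfold workflow: it swaps the splines for plain polynomial design matrices and uses ordinary least squares from Base Julia, and the helper rss is a hypothetical name introduced only for this illustration.

```julia
using LinearAlgebra
using Random

rng = MersenneTwister(1)
x = rand(rng, 50)
y = -(3 .* (x .- 0.5)) .^ 2 .+ 0.5 .* rand(rng, 50)   # same kind of signal as above

# Hypothetical helper: residual sum of squares of a least-squares fit
rss(X, y) = sum(abs2, y .- X * (X \ y))

X_linear = hcat(ones(length(x)), x)           # intercept + slope
X_quad = hcat(ones(length(x)), x, x .^ 2)     # intercept + slope + curvature

rss_linear = rss(X_linear, y)
rss_quad = rss(X_quad, y)

rss_quad < rss_linear   # the more flexible model leaves a smaller residual
```

Keep in mind that the residual on the training data can only shrink as a model becomes more flexible, so this comparison reveals underfitting but not overfitting; for the latter, held-out data would be needed.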
Looking under the hood
Let's take a brief look at how splines achieve this.
The most important point to understand is that we replace x by a set of basis functions spl(x). These basis functions tile the range of x (in our case, [0, 1]) in overlapping regions, and each of them is fit by its own coefficient. Because the regions overlap, we get a smooth function.
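The tiling idea can be sketched without any spline package. The example below uses simple "hat" functions (degree-1 B-splines on equally spaced knots) instead of the smoother basis that spl produces; hat, knots and width are hypothetical names chosen for this illustration.

```julia
# Hypothetical helper: a triangular "hat" function centred on one knot
hat(x, center, width) = max(0.0, 1 - abs(x - center) / width)

knots = range(0, stop = 1, length = 5)   # centres of the 5 basis functions
width = step(knots)                      # each hat reaches its neighbouring knots
xs = range(0, stop = 1, length = 101)

# Rows: values of x, columns: basis functions (same layout as a design matrix)
B = [hat(x, c, width) for x in xs, c in knots]

size(B)                        # (101, 5)
all(sum(B, dims = 2) .≈ 1)     # true: the overlapping hats add up to one everywhere
count(>(0), B[21, :])          # number of basis functions that are active at x = 0.2
```

Each value of x activates only a few neighbouring basis functions, which is what makes the fit local.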
Maybe this becomes clearer after looking at a basis function:
term_spl = Unfold.formulas(uf_spl10)[1].rhs.terms[2]
spl(x, 10)
This is the spline term. Note that this is a special type provided by the BSplineKit.jl extension of Unfold.jl. Its abstract type is AbstractSplineTerm, defined in Unfold.jl.
typeof(term_spl)
UnfoldBSplineKitExt.BSplineTerm{StatsModels.ContinuousTerm{Float64}, Int64}
const splFunction = Base.get_extension(Unfold, :UnfoldBSplineKitExt).splFunction
splFunction([0.2], term_spl)
1×10 Matrix{Float64}:
0.492619 0.438047 0.0670761 0.00225775 0.0 0.0 0.0 0.0 0.0 0.0
Each column of this 1-row matrix is the value of one basis function at x = 0.2, i.e. one predictor in our regression model.
lines(disallowmissing(splFunction([0.2], term_spl))[1, :])

Note: We have to use disallowmissing, because our splines return missing whenever we ask for a value outside their defined range, e.g.:
splFunction([-0.2], term_spl)
1×10 Matrix{Union{Missing, Float64}}:
missing missing missing missing … missing missing missing missing
This is because the spline has never seen any data outside this range and cannot extrapolate!
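The handling of out-of-range values can be mimicked in plain Julia. In the sketch below, bounded_basis is a hypothetical stand-in for a basis that is only defined on [0, 1]; it is not the splFunction from the extension. Rows containing missing are dropped before converting to a plain Float64 matrix.

```julia
# Hypothetical stand-in for a basis that is only defined on [lo, hi]
function bounded_basis(x; lo = 0.0, hi = 1.0)
    (x < lo || x > hi) && return Union{Missing, Float64}[missing, missing]
    t = (x - lo) / (hi - lo)
    return Union{Missing, Float64}[1 - t, t]   # two linear basis functions
end

xs = [-0.2, 0.0, 0.5, 1.2]
B = permutedims(reduce(hcat, [bounded_basis(x) for x in xs]))   # 4×2, may contain missing

keep = .!any(ismissing.(B), dims = 2)[:, 1]   # true for rows without any missing
B_clean = Float64.(B[keep, :])                # plain Float64 matrix
```

Only the two in-range values of x survive the filtering.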
Back to our main issue. Let's plot the whole basis set:
basisSet = splFunction(0.0:0.01:1, term_spl)
basisSet = disallowmissing(basisSet[.!any(ismissing.(basisSet), dims = 2)[:, 1], :]) # remove missings
ax = Axis(Figure()[1, 1])
[lines!(ax, basisSet[:, k]) for k = 1:size(basisSet, 2)]
current_figure()

Notice how we flipped the plot around: the x-axis no longer shows the index of the basis function, but the x-values. Each line is now one basis function of the spline.
Unfold returns one coefficient per basis function:
β = coef(uf_spl10)[1, 1, :]
β = Float64.(disallowmissing(β))
10-element Vector{Float64}:
-0.43629459354570777
-0.3477308181329035
0.4538614574712754
-0.4065207930158754
0.7346158579524353
0.9252913204020701
0.27167896791779556
-0.046335871160806175
-0.5822988416277803
-0.6202082891592833
But because the model includes an intercept, we have to rebuild the design matrix from basisSet: a column of ones is added and one basis function (here the sixth) is left out.
X = hcat(ones(size(basisSet, 1)), basisSet[:, 1:5], basisSet[:, 7:end])
71×10 Matrix{Float64}:
1.0 0.972634 0.027217 0.000148709 … 0.0 0.0 0.0
1.0 0.705645 0.274339 0.0196948 0.0 0.0 0.0
1.0 0.492619 0.438047 0.0670761 0.0 0.0 0.0
1.0 0.327462 0.530129 0.135118 0.0 0.0 0.0
1.0 0.204084 0.56237 0.216645 0.0 0.0 0.0
1.0 0.116392 0.546557 0.304483 … 0.0 0.0 0.0
1.0 0.0582934 0.494476 0.391456 0.0 0.0 0.0
1.0 0.0236969 0.417914 0.470391 0.0 0.0 0.0
1.0 0.00651001 0.328658 0.534112 0.0 0.0 0.0
1.0 0.000640778 0.238493 0.575444 0.0 0.0 0.0
⋮ ⋱
1.0 0.0 0.0 0.0 0.500742 0.335973 0.0
1.0 0.0 0.0 0.0 0.472071 0.412937 0.0
1.0 0.0 0.0 0.0 0.423651 0.500839 0.0
1.0 0.0 0.0 0.0 … 0.354694 0.596875 0.00252938
1.0 0.0 0.0 0.0 0.272257 0.677546 0.0249338
1.0 0.0 0.0 0.0 0.186241 0.711843 0.0899373
1.0 0.0 0.0 0.0 0.106554 0.668749 0.220272
1.0 0.0 0.0 0.0 0.0431012 0.517244 0.438668
1.0 0.0 0.0 0.0 … 0.00579053 0.226308 0.767859
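Why one column is left out can be checked numerically. Assuming the basis sums to one at every x (the first rows of the output above are consistent with this), the intercept column is a linear combination of the basis columns, so keeping all of them would make the design matrix rank deficient. The sketch below illustrates this with hypothetical hat functions rather than the actual spline basis.

```julia
using LinearAlgebra

# Hypothetical hat-function basis on 5 equally spaced knots, rows sum to one
hat(x, center, width) = max(0.0, 1 - abs(x - center) / width)
knots = range(0, stop = 1, length = 5)
xs = range(0, stop = 1, length = 101)
B = [hat(x, c, step(knots)) for x in xs, c in knots]        # 101×5

X_full = hcat(ones(length(xs)), B)                          # intercept + all 5 columns
X_drop = hcat(ones(length(xs)), B[:, 1:2], B[:, 4:end])     # intercept + 4 columns

rank(X_full)   # smaller than its 6 columns: coefficients are not identifiable
rank(X_drop)   # equal to its 5 columns: coefficients are identifiable
```

Which column is dropped does not change the fitted curve, only the meaning of the individual coefficients.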
Now we can weight each basis function by its coefficient.
weighted = (β .* X')
10×71 Matrix{Float64}:
-0.436295 -0.436295 -0.436295 … -0.436295 -0.436295
-0.338215 -0.245375 -0.171299 -0.0 -0.0
0.0123527 0.124512 0.198813 0.0 0.0
-6.04533e-5 -0.00800637 -0.0272678 -0.0 -0.0
1.39303e-7 0.00023577 0.00165858 0.0 0.0
0.0 0.0 0.0 … 0.0 0.0
0.0 0.0 0.0 0.000268113 1.15869e-5
-0.0 -0.0 -0.0 -0.00199713 -0.000268309
-0.0 -0.0 -0.0 -0.30119 -0.131779
-0.0 -0.0 -0.0 -0.272066 -0.476232
Plotting them gives a nice-looking plot!
ax = Axis(Figure()[1, 1])
[lines!(weighted[k, :]) for k = 1:10]
current_figure()

Now sum them up.
lines(sum(weighted, dims = 1)[1, :])
plot!(X * β, color = "gray") # same as computing the matrix product X * β directly
current_figure()
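The equivalence between summing the weighted basis functions and the matrix product can be verified on a small example. This is a minimal sketch with a random matrix standing in for the spline design matrix and made-up coefficients.

```julia
using Random

rng = MersenneTwister(3)
X = rand(rng, 6, 3)      # 6 rows (values of x), 3 predictors
β = [0.5, -1.0, 2.0]     # one made-up coefficient per predictor

weighted = β .* X'                        # 3×6: row k is predictor k scaled by β[k]
summed = sum(weighted, dims = 1)[1, :]    # add up the weighted predictors

summed ≈ X * β    # true
```

The broadcast keeps the individual contributions visible, which is useful for plotting, whereas X * β only returns their sum.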

And this is how you can think about splines.
This page was generated using Literate.jl.