Non-linear effects

using BSplineKit, Unfold
using CairoMakie
using DataFrames
using Random
using Colors
using Missings

Generating a non-linear signal

We start by generating the data.

rng = MersenneTwister(2) # make repeatable
n = 20 # number of datapoints
evts = DataFrame(:x => rand(rng, n))
signal = -(3 * (evts.x .- 0.5)) .^ 2 .+ 0.5 .* rand(rng, n)

plot(evts.x, signal)

Looks perfectly non-linear. Great!

Compare linear & non-linear fit

First, we have to reshape the signal into a 3D array so that it matches the format Unfold expects: 1 channel x 1 timepoint x 20 datapoints.

signal = reshape(signal, length(signal), 1, 1)
signal = permutedims(signal, [3, 2, 1])
size(signal)
(1, 1, 20)
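
As a side note, the reshape/permutedims pair can also be written as a single reshape. The sketch below uses a stand-in vector `v` (a hypothetical name, not the tutorial's signal) and checks that both routes produce the same array.

```julia
# Stand-in data; `v` is a hypothetical vector, not the signal from above
v = collect(1.0:20.0)

# Route used in this tutorial: make it n x 1 x 1, then reverse the dimensions
a = permutedims(reshape(v, length(v), 1, 1), [3, 2, 1])

# One-step alternative: let reshape infer the size of the last dimension
b = reshape(v, 1, 1, :)

# Both are 1 x 1 x 20 and hold the same values in the same order
size(a), a == b
```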

Next we define three different models: a linear one, one with 4 splines, and one with 10 splines. Note the difference in the formulas: the first uses x, the others use spl(x, 4) and spl(x, 10).

design_linear = [Any => (@formula(0 ~ 1 + x), [0])];
design_spl3 = [Any => (@formula(0 ~ 1 + spl(x, 4)), [0])];
design_spl10 = [Any => (@formula(0 ~ 1 + spl(x, 10)), [0])];

Next, fit the parameters.

uf_linear = fit(UnfoldModel, design_linear, evts, signal);
uf_spl3 = fit(UnfoldModel, design_spl3, evts, signal);
uf_spl10 = fit(UnfoldModel, design_spl10, evts, signal);

Extract the fitted values using Unfold.effects.

p_linear = Unfold.effects(Dict(:x => range(0, stop = 1, length = 100)), uf_linear);
p_spl3 = Unfold.effects(Dict(:x => range(0, stop = 1, length = 100)), uf_spl3);
p_spl10 = Unfold.effects(Dict(:x => range(0, stop = 1, length = 100)), uf_spl10);
first(p_linear, 5)
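5×5 DataFrame
 Row │ yhat       channel  x         time   eventname
     │ Float64    Int64    Float64   Int64  DataType
─────┼────────────────────────────────────────────────
   1 │ -0.736686        1  0.0           0  Any
   2 │ -0.727303        1  0.010101      0  Any
   3 │ -0.71792         1  0.020202      0  Any
   4 │ -0.708537        1  0.030303      0  Any
   5 │ -0.699154        1  0.040404      0  Any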

Plot them.

pl = plot(evts.x, signal[1, 1, :])
lines!(p_linear.x, p_linear.yhat)
lines!(p_spl3.x, coalesce.(p_spl3.yhat, NaN))
lines!(p_spl10.x, coalesce.(p_spl10.yhat, NaN))
pl

We see that the linear effect (blue line) underfits the data and the yellow spl(x, 10) overfits it, whereas the green spl(x, 4) follows the underlying curve well.
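
To put a number on "underfits", one can compare residual sums of squares. The sketch below is a simplified stand-in that does not use Unfold: it swaps the spline for a plain quadratic polynomial and fits both models by ordinary least squares. All names (`X_lin`, `X_quad`, `rss`) are illustrative assumptions, and the simulated data only mimics the signal above.

```julia
using Random

rng = MersenneTwister(2)
n = 20
x = rand(rng, n)
# Inverted parabola plus uniform noise, similar in spirit to the signal above
y = 0.5 .* rand(rng, n) .- (3 .* (x .- 0.5)) .^ 2

# Design matrices: intercept + x, and intercept + x + x^2
X_lin = hcat(ones(n), x)
X_quad = hcat(ones(n), x, x .^ 2)

# Residual sum of squares of an ordinary least-squares fit
rss(X, y) = sum(abs2, y .- X * (X \ y))

rss_lin = rss(X_lin, y)
rss_quad = rss(X_quad, y)
```

Because the linear model is nested in the quadratic one, `rss_quad` can never exceed `rss_lin` on the training data. That is also why a low training error alone cannot flag the overfitting of spl(x, 10); for that, one would need held-out data.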

Looking under the hood

Let's take a brief look at how the splines do what they do.

The most important bit to understand is that we are replacing x by a set of basis functions spl(x). These basis functions tile the range of x (in our case, [0, 1]) in overlapping regions, and each one is fit by its own coefficient. Because the regions overlap, we get a smooth function.
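
The tiling idea can be sketched without any spline package. The example below uses degree-1 B-splines ("hat" functions) as a deliberately simplified stand-in for the smoother splines that spl(x, k) builds; `hat`, `centres`, and `B` are hypothetical names chosen for illustration.

```julia
# Degree-1 B-spline ("hat") centred at c with half-width h
hat(x, c, h) = max(0.0, 1 - abs(x - c) / h)

centres = range(0, 1; length = 5)   # 5 evenly spaced centres on [0, 1]
h = step(centres)
xs = range(0, 1; length = 101)

# Basis matrix: one row per x-value, one column per basis function
B = [hat(x, c, h) for x in xs, c in centres]

# The hats overlap, so at every x the basis values add up to one
row_sums = vec(sum(B, dims = 2))

# Each x is covered by at most two neighbouring hats
active = [count(>(1e-12), row) for row in eachrow(B)]
```

Each column of `B` is non-zero only on a small part of [0, 1], so its coefficient only influences the fit locally.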

Maybe this becomes clearer after looking at a basis function:

term_spl = Unfold.formulas(uf_spl10)[1].rhs.terms[2]
spl(x, 10)

This is the spline term. Note that this is a special type provided by the BSplineKit.jl extension of Unfold.jl. Its abstract type is AbstractSplineTerm, defined in Unfold.jl.

typeof(term_spl)
UnfoldBSplineKitExt.BSplineTerm{StatsModels.ContinuousTerm{Float64}, Int64}
const splFunction = Base.get_extension(Unfold, :UnfoldBSplineKitExt).splFunction
splFunction([0.2], term_spl)
1×10 Matrix{Float64}:
 0.0  0.254707  0.562326  0.182953  1.34739e-5  0.0  0.0  0.0  0.0  0.0

Each column of this 1-row matrix is the value of one basis function at x = 0.2, i.e. one predictor in our regression model.

lines(disallowmissing(splFunction([0.2], term_spl))[1, :])

Note: We have to use disallowmissing because the spline function returns missing whenever we ask it for a value outside its defined range, e.g.:

splFunction([-0.2], term_spl)
1×10 Matrix{Union{Missing, Float64}}:
 missing  missing  missing  missing  …  missing  missing  missing  missing

This is because the spline has never seen any data outside that range and cannot extrapolate!
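
The two ways of handling such missings that appear on this page (coalesce to NaN for plotting, or dropping them) can be sketched with Base Julia only. `bounded_square` is a hypothetical stand-in for a function that is only defined on a fixed range; it is not part of Unfold.

```julia
# Hypothetical function that is only defined on [lo, hi]
bounded_square(x; lo = 0.0, hi = 1.0) = lo <= x <= hi ? x^2 : missing

vals = bounded_square.([-0.2, 0.2, 0.5])   # first entry is `missing`

# Option 1: replace missings with NaN, e.g. so that a line plot leaves a gap
as_nan = coalesce.(vals, NaN)

# Option 2: drop the missing entries, which also narrows the element type
kept = collect(skipmissing(vals))
```

`disallowmissing` from Missings.jl goes one step further than option 2: it only converts the element type and throws an error if a missing is still present.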

Back to our main issue. Let's plot the whole basis set.

basisSet = splFunction(0.0:0.01:1, term_spl)
basisSet = disallowmissing(basisSet[.!any(ismissing.(basisSet), dims = 2)[:, 1], :]) # remove missings
ax = Axis(Figure()[1, 1])
[lines!(ax, basisSet[:, k]) for k = 1:size(basisSet, 2)]
current_figure()

Notice how we flipped the plot around: the x-axis no longer runs over the basis functions but over the x-values, and each line is now one basis function of the spline.

Unfold returns one coefficient per basis function:

β = coef(uf_spl10)[1, 1, :]
β = Float64.(disallowmissing(β))
10-element Vector{Float64}:
 -0.22087936502751157
 -1.3266992189242772
 -1.2141127953366175
 -0.3601682760198117
  0.5375041304110395
  0.8663477473709839
  0.5202833888393346
 -0.4122550011276661
 -0.021571324406719897
 -0.696114834251856

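But because the model includes an intercept, one basis function is left out of the design matrix (here column 6 of basisSet). We therefore rebuild the matrix as an intercept column followed by the remaining nine basis functions.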

X = hcat(ones(size(basisSet, 1)), basisSet[:, 1:5], basisSet[:, 7:end])
79×10 Matrix{Float64}:
 1.0  0.949272   0.0503154  0.000411356  …  0.0        0.0       0.0
 1.0  0.785544   0.206379   0.00800124      0.0        0.0       0.0
 1.0  0.641815   0.333447   0.0243138       0.0        0.0       0.0
 1.0  0.516783   0.433718   0.0482409       0.0        0.0       0.0
 1.0  0.409145   0.50939    0.0786746       0.0        0.0       0.0
 1.0  0.317599   0.562661   0.114507     …  0.0        0.0       0.0
 1.0  0.24084    0.59573    0.15463         0.0        0.0       0.0
 1.0  0.177567   0.610794   0.197935        0.0        0.0       0.0
 1.0  0.126477   0.610052   0.243314        0.0        0.0       0.0
 1.0  0.0862677  0.595703   0.28966         0.0        0.0       0.0
 ⋮                                       ⋱                       
 1.0  0.0        0.0        0.0          …  0.677104   0.123299  0.000400971
 1.0  0.0        0.0        0.0             0.640464   0.214438  0.00558484
 1.0  0.0        0.0        0.0             0.5705     0.314158  0.0222119
 1.0  0.0        0.0        0.0             0.477361   0.407302  0.0569693
 1.0  0.0        0.0        0.0             0.371193   0.478711  0.116544
 1.0  0.0        0.0        0.0          …  0.262145   0.513225  0.207623
 1.0  0.0        0.0        0.0             0.160365   0.495686  0.336894
 1.0  0.0        0.0        0.0             0.0759992  0.410936  0.511044
 1.0  0.0        0.0        0.0             0.0191963  0.243817  0.73676
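
Why leave out a basis function when an intercept is present? A plausible reading is collinearity: B-spline basis functions add up to one at every x, so together they already span the constant column. The sketch below illustrates this with simple hat functions as a stand-in basis (an assumption for illustration, not the basis Unfold builds); `hat`, `centres`, `X_full`, and `X_drop` are hypothetical names.

```julia
using LinearAlgebra

# Degree-1 B-spline ("hat") centred at c with half-width h
hat(x, c, h) = max(0.0, 1 - abs(x - c) / h)

centres = range(0, 1; length = 5)
xs = range(0, 1; length = 101)
B = [hat(x, c, step(centres)) for x in xs, c in centres]

# Intercept plus all 5 basis columns: 6 columns, but one is redundant
X_full = hcat(ones(length(xs)), B)

# Intercept plus 4 basis columns (middle one removed): no redundancy left
X_drop = hcat(ones(length(xs)), B[:, 1:2], B[:, 4:end])

rank(X_full), size(X_full, 2)
rank(X_drop), size(X_drop, 2)
```

With the redundant column removed, the least-squares problem has a unique solution again, at the cost that the remaining coefficients are relative to the intercept.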

Now we can weight each basis function by its coefficient.

weighted = (β .* X')
10×79 Matrix{Float64}:
 -0.220879     -0.220879    -0.220879     …  -0.220879    -0.220879
 -1.2594       -1.04218     -0.851496        -0.0         -0.0
 -0.0610886    -0.250568    -0.404843        -0.0         -0.0
 -0.000148157  -0.00288179  -0.00875705      -0.0         -0.0
  4.46531e-7    4.05203e-5   0.000227551      0.0          0.0
  0.0           0.0          0.0          …   0.0          0.0
  0.0           0.0          0.0              0.00105113   0.000118356
 -0.0          -0.0         -0.0             -0.031331    -0.00791377
 -0.0          -0.0         -0.0             -0.00886444  -0.00525945
 -0.0          -0.0         -0.0             -0.355745    -0.512869

Plotting them creates a nice-looking plot!

ax = Axis(Figure()[1, 1])
[lines!(ax, weighted[k, :]) for k = 1:10]
current_figure()

Now sum them up.

lines(sum(weighted, dims = 1)[1, :])
plot!(X * β, color = "gray") # same as computing the matrix product X * β directly
current_figure()

And this is how you can think about splines.


This page was generated using Literate.jl.