Cross-validated Unfold models

This tutorial shows how to run an Unfold model with k-fold cross-validation.

Important

Cross-validation is not yet implemented for deconvolution models. This is not because it is hard, but because I haven't done it yet.

Setup

using Unfold
using UnfoldSim
using Random
using Statistics
using UnfoldMakie, CairoMakie
eeg, evts = UnfoldSim.predef_eeg(; return_epoched = true, noiselevel = 15)
([10.688811360704506 -21.962460781837876 … 13.155407280464756 7.004265290077822; -1.43459355551717 -19.02048695957436 … 7.846848082403746 -5.000311556800763; … ; -6.553487928276634 -23.8871311282609 … -2.5730011130260984 -21.772988738765104; -4.074053769206997 -15.0753294807487 … 3.937492801380999 -25.673145843995762], 2000×3 DataFrame
  Row │ continuous  condition  latency 
      │ Float64     String     Int64   
──────┼────────────────────────────────
    1 │   2.77778   car             62
    2 │  -5.0       face           132
    3 │  -1.66667   car            196
    4 │  -5.0       car            249
    5 │   5.0       car            303
    6 │  -0.555556  car            366
    7 │  -2.77778   car            432
    8 │  -2.77778   face           483
  ⋮   │     ⋮           ⋮         ⋮
 1994 │   1.66667   car         119798
 1995 │   2.77778   car         119856
 1996 │  -5.0       car         119925
 1997 │   3.88889   car         119978
 1998 │  -5.0       face        120030
 1999 │  -0.555556  face        120096
 2000 │  -2.77778   car         120154
                      1985 rows omitted)

Define the formula for a mass-univariate model. Because we fit epoched data, no basis function is needed. Instead, the time points are passed directly to fit.

f = @formula 0 ~ 1 + condition
FormulaTerm
Response:
  0
Predictors:
  1
  condition(unknown)

Cross-validation solver

solver_cv wraps the default solver and runs it once per cross-validation fold.
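To make the idea concrete, here is a minimal sketch in plain Julia showing how shuffled k-fold indices could be produced and a solver run once per fold. This is not Unfold's internal implementation. The helper names kfold_indices and run_per_fold are hypothetical, and an ordinary least-squares solve (X \ y) stands in for Unfold's default solver.

```julia
using Random

# Hypothetical helper (not part of Unfold): split trials 1:n into k folds.
# Every trial lands in exactly one test set; train is the complement.
function kfold_indices(n::Int, k::Int; shuffle::Bool = true, rng = MersenneTwister(1))
    order = shuffle ? randperm(rng, n) : collect(1:n)
    # take every k-th trial of the (shuffled) order, so fold sizes differ by at most one
    test_sets = [order[i:k:n] for i in 1:k]
    return [(train = setdiff(order, test), test = test) for test in test_sets]
end

# Hypothetical stand-in for "run the solver once per fold":
# ordinary least squares fit on the training trials only.
function run_per_fold(X::AbstractMatrix, y::AbstractVector, folds)
    return [X[f.train, :] \ y[f.train] for f in folds]
end

# Toy design: intercept + a 0/1 condition dummy, 20 trials, noise-free data
X = hcat(ones(20), repeat([0.0, 1.0], 10))
y = X * [2.0, 0.5]
folds = kfold_indices(20, 5)
betas = run_per_fold(X, y, folds)  # one coefficient vector per fold
```

Because the toy data are noise-free, every fold should recover roughly the same coefficients. With real EEG, the per-fold estimates differ, and that spread is what cross-validation exposes.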

cv_solver = solver_cv(n_folds = 5, shuffle = true) # shuffle is true by default
(::Unfold.var"#cv_kernel#182"{Random.MersenneTwister, Int64, Bool, Unfold.var"#solver_cv##2#solver_cv##3", Bool}) (generic function with 1 method)

Now we can fit the model with the CV solver.

m_cv = fit(UnfoldModel, f, evts, eeg, 1:size(eeg, 1); solver = cv_solver) # times: one entry per epoch sample
Unfold-Type: ::UnfoldLinearModel{{Float64}}
 Any => 1 + condition
✔ model is fit. size(coefs) (1, 44, 2)
Useful functions: `design(uf)`, `designmatrix(uf)`, `coef(uf)`, `coeftable(uf)`

The estimate gains a 4th dimension for the CV fold, so its layout is (channel, time, coefficient, fold).
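This layout can be illustrated with plain arrays. If each fold yields a (channel, time, coefficient) array, concatenating along dims = 4 produces the (channel, time, coefficient, fold) shape, here (1, 44, 2, 5). The per-fold arrays in this sketch are random placeholders, not real estimates, and the sketch does not claim to show how Unfold assembles the array internally.

```julia
using Random

# Placeholder per-fold results with the shape of one fold in this tutorial:
# 1 channel × 44 time points × 2 coefficients (intercept, condition)
rng = MersenneTwister(42)
per_fold = [randn(rng, 1, 44, 2) for _ in 1:5]

# Stack along a new 4th dimension: (channel, time, coefficient, fold)
estimate = cat(per_fold...; dims = 4)
println(size(estimate))  # (1, 44, 2, 5)

# Indexing the last dimension recovers a single fold's 3-D array
fold_3 = estimate[:, :, :, 3]
```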

size(modelfit(m_cv).estimate)
(1, 44, 2, 5)

You also get the train/test indices for each fold. Note that we first have to index into the first "event" (folds[1]), because a single model can contain multiple events/formulas.

length(m_cv.modelfit.folds[1])
5

We can also access a specific fold, e.g. the third, and check its train/test indices. Here we display only the first 6 training indices.

first(m_cv.modelfit.folds[1][3].train, 6)
6-element Vector{Int64}:
  378
   19
  526
   78
 1883
  255
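The training indices above are not sorted, presumably because shuffle = true. Whatever the order, a valid k-fold split should satisfy a few properties. Within each fold, train and test must be disjoint and together cover all trials. Across folds, every trial must be tested exactly once. The checker below is a hypothetical helper, written against a plain vector of (train, test) NamedTuples. It could in principle be applied to m_cv.modelfit.folds[1], assuming each fold exposes a test field alongside train, which this sketch does not verify.

```julia
# Hypothetical checker (not part of Unfold). `folds` is a vector of
# NamedTuples with fields `train` and `test`; `n` is the number of trials.
function is_valid_kfold(folds, n::Int)
    # within each fold, no trial is both trained on and tested on
    disjoint = all(isempty(intersect(f.train, f.test)) for f in folds)
    # within each fold, train and test together cover all trials exactly once
    complete = all(sort(vcat(f.train, f.test)) == 1:n for f in folds)
    # across folds, each trial appears in exactly one test set
    tested_once = sort(vcat((f.test for f in folds)...)) == 1:n
    return disjoint && complete && tested_once
end

# Toy example: 6 trials, 3 folds
good = [(train = [3, 4, 5, 6], test = [1, 2]),
        (train = [1, 2, 5, 6], test = [3, 4]),
        (train = [1, 2, 3, 4], test = [5, 6])]
bad  = [(train = [2, 3, 4], test = [1, 2]),   # trial 2 is in both sets
        (train = [1, 2, 5, 6], test = [3, 4]),
        (train = [1, 2, 3, 4], test = [5, 6])]

is_valid_kfold(good, 6)  # true
is_valid_kfold(bad, 6)   # false
```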

coef and coeftable

For LinearModelFitCV, coef(m) returns the mean over folds. This means coeftable(m_cv) reports the fold-averaged estimates.
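Conceptually, the fold average is the mean of the estimate array over its 4th dimension. The sketch below shows this on a placeholder array in which fold f is filled with the value f. It assumes, based on the sentence above, that coef(m_cv) is equivalent to this computation. The sketch does not check that against Unfold.

```julia
using Statistics

# Placeholder (channel, time, coefficient, fold) array; fold f holds the constant f
estimate = cat([fill(Float64(f), 1, 44, 2) for f in 1:5]...; dims = 4)

# Average over folds, then drop the now-singleton 4th dimension
fold_mean = dropdims(mean(estimate; dims = 4); dims = 4)

size(fold_mean)     # (1, 44, 2)
fold_mean[1, 1, 1]  # 3.0, the mean of 1, 2, 3, 4, 5
```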

first(coeftable(m_cv), 6)
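6×7 DataFrame
 Row │ channel  coefname     estimate  eventname  group    stderror  time
     │ Int64    String       Float64   DataType   Nothing  Nothing   Int64
─────┼────────────────────────────────────────────────────────────────────
   1 │       1  (Intercept)   1.302    Any                               1
   2 │       1  (Intercept)   1.49621  Any                               2
   3 │       1  (Intercept)   1.55901  Any                               3
   4 │       1  (Intercept)   1.22913  Any                               4
   5 │       1  (Intercept)   1.03961  Any                               5
   6 │       1  (Intercept)   1.92724  Any                               6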

You can access fold-specific estimates directly from modelfit.estimate.

fold_1_estimate = modelfit(m_cv).estimate[:, :, :, 1]
size(fold_1_estimate)
(1, 44, 2)
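To loop over all folds instead of picking one, eachslice with dims = 4 yields one (channel, time, coefficient) view per fold. The sketch below uses this to measure how far each fold's intercept departs from the fold mean. It runs on a placeholder array; with the tutorial's model you would pass modelfit(m_cv).estimate instead. The max_dev summary is an illustrative choice, not something Unfold provides.

```julia
using Statistics

# Placeholder (channel, time, coefficient, fold) array standing in for
# modelfit(m_cv).estimate; fold f holds the constant value f
estimate = cat([fill(Float64(f), 1, 44, 2) for f in 1:5]...; dims = 4)
fold_mean = dropdims(mean(estimate; dims = 4); dims = 4)

# eachslice(...; dims = 4) yields one 1×44×2 view per fold;
# for each fold, take the largest absolute intercept deviation from the mean
max_dev = [maximum(abs.(fold[1, :, 1] .- fold_mean[1, :, 1]))
           for fold in eachslice(estimate; dims = 4)]
# max_dev == [2.0, 1.0, 0.0, 1.0, 2.0]
```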

Finally, let's plot our estimates.

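# one line per fold: intercept (coefficient 1) of channel 1 over time
f, ax, h = series(modelfit(m_cv).estimate[1, :, 1, :]')
# dashed black line: fold-averaged intercept returned by coef
lines!(coef(m_cv)[1, :, 1], color = :black, linestyle = :dash)
ax.xlabel = "Time (samples)"
f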
[Figure: fold-specific intercept estimates and their mean; x-axis: Time (samples)]

The colored lines are the fold-specific intercept estimates; the dashed black line is their mean across folds.


This page was generated using Literate.jl.