Cross-validated Unfold models

This tutorial shows how to run an Unfold model with k-fold cross-validation.

Important

Cross-validation is not yet implemented for deconvolution models - not because it is hard, but because I didn't do it yet.

Setup

using Unfold
using UnfoldSim
using Random
using Statistics
using UnfoldMakie, CairoMakie
eeg, evts = UnfoldSim.predef_eeg(; return_epoched = true, noiselevel = 15)
([10.688811360704506 -21.962460781837876 … 13.155407280464756 7.004265290077822; -1.43459355551717 -19.02048695957436 … 7.846848082403746 -5.000311556800763; … ; -6.553487928276634 -23.8871311282609 … -2.5730011130260984 -21.772988738765104; -4.074053769206997 -15.0753294807487 … 3.937492801380999 -25.673145843995762], 2000×3 DataFrame
  Row │ continuous  condition  latency
      │ Float64     String     Int64
──────┼────────────────────────────────
    1 │   2.77778   car             62
    2 │  -5.0       face           132
    3 │  -1.66667   car            196
    4 │  -5.0       car            249
    5 │   5.0       car            303
    6 │  -0.555556  car            366
    7 │  -2.77778   car            432
    8 │  -2.77778   face           483
  ⋮   │     ⋮           ⋮         ⋮
 1994 │   1.66667   car         119798
 1995 │   2.77778   car         119856
 1996 │  -5.0       car         119925
 1997 │   3.88889   car         119978
 1998 │  -5.0       face        120030
 1999 │  -0.555556  face        120096
 2000 │  -2.77778   car         120154
                      1985 rows omitted)

Define the formula for a mass-univariate model.

f = @formula 0 ~ 1 + condition
FormulaTerm Response: 0 Predictors: 1 condition(unknown)

Cross-validation solver

solver_cv wraps the default solver, but runs it once per cross-validation fold.

cv_solver = solver_cv(n_folds = 5, shuffle = true) # shuffle is true by default
(::Unfold.var"#cv_kernel#182"{Random.MersenneTwister, Int64, Bool, Unfold.var"#solver_cv##2#solver_cv##3", Bool}) (generic function with 1 method)
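Conceptually, the CV solver partitions the trials into folds and refits the model on each training split. A minimal sketch of that splitting logic (illustration only, not Unfold's actual implementation; the round-robin assignment and n_trials are made up for the example):

```julia
using Random

# Toy illustration: split shuffled trial indices into k folds.
n_trials, n_folds = 20, 5
idx = shuffle(MersenneTwister(1), collect(1:n_trials))
folds = [idx[k:n_folds:end] for k = 1:n_folds]   # round-robin assignment

for (k, test) in enumerate(folds)
    train = setdiff(idx, test)                   # all remaining trials
    # ... fit the model on `train`, evaluate on `test` ...
end
```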

Now we can fit the model with the CV solver.

m_cv = fit(UnfoldModel, f, evts, eeg, 1:size(eeg, 1); solver = cv_solver)
Unfold-Type: ::UnfoldLinearModel{Float64}  Any => 1 + condition ✔ model is fit. size(coefs) (1, 44, 2) Useful functions: `design(uf)`, `designmatrix(uf)`, `coef(uf)`, `coeftable(uf)`

The 4th dimension contains the CV-fold (channel, time, coefficient, fold).

size(modelfit(m_cv).estimate)
(1, 44, 2, 5)
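One use of the extra fold dimension is to quantify how stable the estimates are across folds, e.g. via the standard deviation along dims = 4. A sketch on a stand-in array of the same shape (substitute modelfit(m_cv).estimate for the real model output):

```julia
using Statistics

est = randn(1, 44, 2, 5)                          # stand-in for modelfit(m_cv).estimate
fold_sd = dropdims(std(est; dims = 4); dims = 4)  # spread across the 5 folds
size(fold_sd)                                     # back to (channel, time, coefficient)
```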

You also get the train/test indices for each fold. We first have to index into the first "event", because a single model can contain multiple events / formulas.

length(m_cv.modelfit.folds[1])
5

We can also access a specific fold, e.g. the third one, and inspect its train/test indices. Let's display only the first 6 training indices.

first(m_cv.modelfit.folds[1][3].train, 6)
6-element Vector{Int64}:
 1860
  362
 1216
  514
  879
 1850
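Each fold's train and test sets should be disjoint and together cover all trials. A sketch of that sanity check on a hypothetical fold (a stand-in built here for illustration; on the real model you would check m_cv.modelfit.folds[1][3], assumed to carry train/test index vectors):

```julia
# Hypothetical fold with the assumed (train = ..., test = ...) layout
fold = (train = setdiff(1:2000, 1:5:2000), test = collect(1:5:2000))

isempty(intersect(fold.train, fold.test))    # train and test are disjoint
sort(vcat(fold.train, fold.test)) == 1:2000  # together they cover all trials
```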

coef and coeftable

For LinearModelFitCV, coef(m) returns the mean over folds. This means coeftable(m_cv) reports the fold-averaged estimates.

first(coeftable(m_cv), 6)
6×7 DataFrame
 Row │ channel  coefname     estimate  eventname  group    stderror  time
     │ Int64    String       Float64   DataType   Nothing  Nothing   Int64
─────┼────────────────────────────────────────────────────────────────────
   1 │       1  (Intercept)   1.30117  Any                               1
   2 │       1  (Intercept)   1.49691  Any                               2
   3 │       1  (Intercept)   1.55851  Any                               3
   4 │       1  (Intercept)   1.22707  Any                               4
   5 │       1  (Intercept)   1.03858  Any                               5
   6 │       1  (Intercept)   1.92573  Any                               6

You can access fold-specific estimates directly from modelfit.estimate.

fold_1_estimate = modelfit(m_cv).estimate[:, :, :, 1]
size(fold_1_estimate)
(1, 44, 2)
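Since coef returns the fold average, a quick consistency check is to average the per-fold estimates yourself (a sketch; it assumes coef simply drops the fold dimension after averaging):

```julia
using Statistics

# mean over the 4th (fold) dimension, then drop it to match coef's shape
fold_mean = dropdims(mean(modelfit(m_cv).estimate; dims = 4); dims = 4)
coef(m_cv) ≈ fold_mean
```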

Finally, let's plot our estimates.

f, ax, h = series(modelfit(m_cv).estimate[1, :, 1, :]')
lines!(coef(m_cv)[1, :, 1], color = :black, linestyle = :dash)
ax.xlabel = "Time (samples)"
f

The colored lines are the fold-specific estimates; the dashed black line is the mean across folds.


This page was generated using Literate.jl.