# Cross-validated Unfold models
This tutorial shows how to run an Unfold model with k-fold cross-validation.
Note: cross-validation is not yet implemented for deconvolution models, not because it is hard, but simply because it has not been done yet.
## Setup
```julia
using Unfold
using UnfoldSim
using Random
using Statistics
using UnfoldMakie, CairoMakie
```
```julia
eeg, evts = UnfoldSim.predef_eeg(; return_epoched = true, noiselevel = 15)
```

```
([10.688811360704506 -21.962460781837876 … 13.155407280464756 7.004265290077822; -1.43459355551717 -19.02048695957436 … 7.846848082403746 -5.000311556800763; … ; -6.553487928276634 -23.8871311282609 … -2.5730011130260984 -21.772988738765104; -4.074053769206997 -15.0753294807487 … 3.937492801380999 -25.673145843995762], 2000×3 DataFrame
  Row │ continuous  condition  latency
      │ Float64     String     Int64
──────┼────────────────────────────────
    1 │  2.77778    car             62
    2 │ -5.0        face           132
    3 │ -1.66667    car            196
    4 │ -5.0        car            249
    5 │  5.0        car            303
    6 │ -0.555556   car            366
    7 │ -2.77778    car            432
    8 │ -2.77778    face           483
  ⋮   │     ⋮          ⋮         ⋮
 1994 │  1.66667    car         119798
 1995 │  2.77778    car         119856
 1996 │ -5.0        car         119925
 1997 │  3.88889    car         119978
 1998 │ -5.0        face        120030
 1999 │ -0.555556   face        120096
 2000 │ -2.77778    car         120154
                      1985 rows omitted)
```

Define the formula for a mass-univariate model.
```julia
f = @formula 0 ~ 1 + condition
```

```
FormulaTerm
Response:
  0
Predictors:
  1
  condition(unknown)
```

## Cross-validation solver
`solver_cv` wraps the default solver and runs it once for each cross-validation fold.

```julia
cv_solver = solver_cv(n_folds = 5, shuffle = true) # shuffle is true by default
```

```
(::Unfold.var"#cv_kernel#182"{Random.MersenneTwister, Int64, Bool, Unfold.var"#solver_cv##2#solver_cv##3", Bool}) (generic function with 1 method)
```

Now we can fit the model with the CV solver.
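Conceptually, the CV solver has to split the epochs into k train/test index sets before fitting each fold. A simplified sketch of such a split (this is not Unfold's actual internal code, just an illustration of the idea):

```julia
using Random

# Illustrative k-fold index splitting; NOT Unfold's internal implementation.
function kfold_indices(n, k; shuffle = true, rng = MersenneTwister(1))
    idx = shuffle ? randperm(rng, n) : collect(1:n)
    test_sets = [idx[j:k:n] for j in 1:k]  # k roughly equal-sized test sets
    return [(train = setdiff(idx, t), test = t) for t in test_sets]
end

folds = kfold_indices(2000, 5)
length(folds)           # 5 folds
length(folds[1].test)   # 400 epochs held out per fold
```

Each epoch lands in exactly one test set, so every data point is predicted once by a model that never saw it during training.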
```julia
m_cv = fit(UnfoldModel, f, evts, eeg, 1:size(eeg, 1); solver = cv_solver)
```

```
Unfold-Type: ::UnfoldLinearModel{Float64}
 Any => 1 + condition
 ✔ model is fit. size(coefs) (1, 44, 2)
Useful functions: `design(uf)`, `designmatrix(uf)`, `coef(uf)`, `coeftable(uf)`
```

The 4th dimension of the estimates contains the CV fold: (channel, time, coefficient, fold).

```julia
size(modelfit(m_cv).estimate)
```

```
(1, 44, 2, 5)
```

You also get the train/test indices for each fold. We have to index once into the first "event", since you could run multiple events / formulas in one model.

```julia
length(m_cv.modelfit.folds[1])
```

```
5
```

We can also access, e.g., the third fold and check its train/test indices. Let's display only the first 6 indices.
```julia
first(m_cv.modelfit.folds[1][3].train, 6)
```

```
6-element Vector{Int64}:
 1860
  362
 1216
  514
  879
 1850
```

## `coef` and `coeftable`
For `LinearModelFitCV`, `coef(m_cv)` returns the mean over folds. This means `coeftable(m_cv)` reports the fold-averaged estimates.
```julia
first(coeftable(m_cv), 6)
```

| Row | channel | coefname    | estimate | eventname | group   | stderror | time  |
|-----|---------|-------------|----------|-----------|---------|----------|-------|
|     | Int64   | String      | Float64  | DataType  | Nothing | Nothing  | Int64 |
| 1   | 1       | (Intercept) | 1.30117  | Any       |         |          | 1     |
| 2   | 1       | (Intercept) | 1.49691  | Any       |         |          | 2     |
| 3   | 1       | (Intercept) | 1.55851  | Any       |         |          | 3     |
| 4   | 1       | (Intercept) | 1.22707  | Any       |         |          | 4     |
| 5   | 1       | (Intercept) | 1.03858  | Any       |         |          | 5     |
| 6   | 1       | (Intercept) | 1.92573  | Any       |         |          | 6     |
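The fold-averaging behavior of `coef` can be checked directly (a sketch; it assumes, as described above, that `coef(m_cv)` averages `modelfit(m_cv).estimate` over its 4th, i.e. fold, dimension):

```julia
using Statistics

# coef(m_cv) should match the per-fold estimates averaged over the fold dimension
fold_mean = dropdims(mean(modelfit(m_cv).estimate; dims = 4); dims = 4)
coef(m_cv) ≈ fold_mean
```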
You can access fold-specific estimates directly from `modelfit(m_cv).estimate`.

```julia
fold_1_estimate = modelfit(m_cv).estimate[:, :, :, 1]
size(fold_1_estimate)
```

```
(1, 44, 2)
```

Finally, let's plot our estimates.
```julia
f, ax, h = series(modelfit(m_cv).estimate[1, :, 1, :]')
lines!(coef(m_cv)[1, :, 1], color = :black, linestyle = :dash)
ax.xlabel = "Time (samples)"
f
```
The colored lines are the fold-specific estimates; the dashed black line is the mean across folds.
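To quantify the fold-to-fold variability visible in the plot, one could (as a sketch) take the standard deviation across the fold dimension:

```julia
using Statistics

# per-channel, per-time, per-coefficient spread across the 5 folds
fold_sd = dropdims(std(modelfit(m_cv).estimate; dims = 4); dims = 4)
size(fold_sd)  # (1, 44, 2)
```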
This page was generated using Literate.jl.