@@ -85,34 +85,39 @@ vector_rsa_model_mat <- function(design) {
85
85
# ' one of \code{"pearson"} or \code{"spearman"}.
86
86
# ' @param nperm Integer, number of permutations for statistical testing (default: 0).
87
87
# ' @param save_distributions Logical, whether to save full permutation distributions (default: FALSE).
88
+ # ' @param return_predictions Logical, whether to return per-observation similarity scores (default: FALSE).
88
89
# '
89
90
# ' @return A \code{vector_rsa_model} object (S3 class) containing references to the dataset, design, and function parameters.
90
91
# '
91
92
# ' @details
92
93
# ' The model references the already-precomputed cross-block data from the design.
94
+ # ' If `return_predictions` is TRUE, the output of `run_regional` or `run_searchlight`
95
+ # ' will include a `prediction_table` tibble containing the observation-level RSA scores.
93
96
# '
94
97
# ' @export
95
98
vector_rsa_model <- function(dataset, design,
                             distfun = cordist(),
                             rsa_simfun = c("pearson", "spearman"),
                             nperm = 0,
                             save_distributions = FALSE,
                             return_predictions = FALSE) {
  # Resolve the similarity-function choice; errors on anything other than
  # "pearson" or "spearman" (string literals must carry no stray whitespace,
  # otherwise match.arg can never match a user-supplied "pearson").
  rsa_simfun <- match.arg(rsa_simfun)

  # Validate inputs up front with informative messages.
  assertthat::assert_that(inherits(dataset, "mvpa_dataset"))
  assertthat::assert_that(inherits(design, "vector_rsa_design"),
                          msg = "Input must be a 'vector_rsa_design' object.")

  # Create the model spec, passing permutation and prediction parameters.
  create_model_spec(
    "vector_rsa_model",
    dataset = dataset,
    design = design,
    distfun = distfun,
    rsa_simfun = rsa_simfun,
    nperm = nperm,
    compute_performance = TRUE,                 # performance is always computed
    save_distributions = save_distributions,
    return_predictions = return_predictions     # pass the per-observation flag
  )
}
118
123
@@ -447,9 +452,9 @@ merge_results.vector_rsa_model <- function(obj, result_set, indices, id, ...) {
447
452
# Return standard error tibble structure
448
453
return (
449
454
tibble :: tibble(
450
- result = list (NULL ), # No results on error
451
- indices = list (indices ), # Keep indices for context
452
- performance = list (NULL ), # No performance on error
455
+ result = list (NULL ),
456
+ indices = list (indices ),
457
+ performance = list (NULL ),
453
458
id = id ,
454
459
error = TRUE ,
455
460
error_message = emessage
@@ -458,13 +463,10 @@ merge_results.vector_rsa_model <- function(obj, result_set, indices, id, ...) {
458
463
}
459
464
460
465
# Extract the scores computed by train_model.
461
- # Default processor likely stores train_model output in result_set$result[[1]].
462
- # Add checks for robustness.
463
466
if (! " result" %in% names(result_set ) || length(result_set $ result ) == 0 || is.null(result_set $ result [[1 ]])) {
464
- error_msg <- " merge_results (vector_rsa): result_set missing or has NULL/empty 'result' field."
467
+ error_msg <- " merge_results (vector_rsa): result_set missing or has NULL/empty 'result' field where scores were expected ."
465
468
futile.logger :: flog.error(" ROI/Sphere ID %s: %s" , id , error_msg )
466
- # Create NA performance matrix to avoid downstream errors
467
- # Get expected metric names (rsa_score + perm cols if needed)
469
+ # Create NA performance matrix
468
470
perf_names <- " rsa_score"
469
471
if (obj $ nperm > 0 ) {
470
472
perf_names <- c(perf_names , " p_rsa_score" , " z_rsa_score" )
@@ -492,61 +494,55 @@ merge_results.vector_rsa_model <- function(obj, result_set, indices, id, ...) {
492
494
id = id , error = TRUE , error_message = error_msg ))
493
495
}
494
496
495
- # Call evaluate_model, passing the scores and permutation parameters from obj
497
+ # Call evaluate_model to compute summary performance and permutations
496
498
perf <- evaluate_model.vector_rsa_model(
497
- object = obj , # Pass the full model spec
498
- predicted = NULL , # Not used by vector_rsa evaluate
499
- observed = scores , # Pass the scores here
500
- nperm = obj $ nperm , # Get nperm from the model spec
501
- save_distributions = obj $ save_distributions # Get save_dist from model spec
499
+ object = obj ,
500
+ predicted = NULL ,
501
+ observed = scores ,
502
+ nperm = obj $ nperm ,
503
+ save_distributions = obj $ save_distributions
502
504
)
503
505
504
- # --- Collate results into the performance matrix ---
505
- base_metrics <- c(
506
- perf $ rsa_score # Extract the primary score
507
- )
508
- base_names <- c(" rsa_score" ) # Name it
506
+ # --- Collate performance matrix ---
507
+ base_metrics <- c(perf $ rsa_score )
508
+ base_names <- c(" rsa_score" )
509
509
510
- # Add permutation results if they were computed (even if NA)
511
510
if (! is.null(perf $ permutation_results )) {
512
511
perm_p_values <- perf $ permutation_results $ p_values
513
512
perm_z_scores <- perf $ permutation_results $ z_scores
514
-
515
- # Check if p-values/z-scores are named correctly
516
513
if (is.null(names(perm_p_values )) || is.null(names(perm_z_scores ))){
517
- p_names <- paste0(" p_" , base_names ) # Fallback naming
514
+ p_names <- paste0(" p_" , base_names )
518
515
z_names <- paste0(" z_" , base_names )
519
516
} else {
520
517
p_names <- paste0(" p_" , names(perm_p_values ))
521
518
z_names <- paste0(" z_" , names(perm_z_scores ))
522
519
}
523
-
524
520
perf_values <- c(base_metrics , perm_p_values , perm_z_scores )
525
521
perf_names <- c(base_names , p_names , z_names )
526
522
} else {
527
523
perf_values <- base_metrics
528
524
perf_names <- base_names
529
525
}
530
-
531
- # Create the performance matrix
532
- perf_mat <- matrix (
533
- perf_values ,
534
- nrow = 1 ,
535
- ncol = length(perf_values ),
536
- dimnames = list (NULL , perf_names )
537
- )
538
-
539
- # Remove columns that are all NA (e.g., if permutations failed or weren't run)
526
+ perf_mat <- matrix (perf_values , nrow = 1 , ncol = length(perf_values ), dimnames = list (NULL , perf_names ))
540
527
perf_mat <- perf_mat [, colSums(is.na(perf_mat )) < nrow(perf_mat ), drop = FALSE ]
541
528
529
+ # --- Prepare results structure based on return_predictions flag ---
530
+ result_data <- if (isTRUE(obj $ return_predictions )) {
531
+ # Return scores structured for later assembly into prediction_table
532
+ # Wrap scores in a list with a standard name
533
+ list (rsa_scores = scores )
534
+ } else {
535
+ NULL # Return NULL if predictions are not requested
536
+ }
537
+
542
538
# Return the final tibble structure expected by the framework
543
539
tibble :: tibble(
544
- result = list (NULL ), # Don't store raw results after merging
540
+ result = list (result_data ), # Store list(rsa_scores=scores) or NULL here
545
541
indices = list (indices ),
546
542
performance = list (perf_mat ),
547
543
id = id ,
548
544
error = FALSE ,
549
- error_message = " ~" # Indicate success
545
+ error_message = " ~"
550
546
)
551
547
}
552
548
0 commit comments