Source code for cfmap.CF_mapping

import os
import sys
import argparse
import h5py
import numpy as np
import pandas as pd
import seaborn as sns
import nibabel as nib
import neuropythy as ny
import scipy.stats as stats
from nilearn import signal
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.stats import pearsonr, kendalltau
from scipy.spatial import procrustes
from pingouin import circ_corrcl

from prfpy.stimulus import CFStimulus
from prfpy.model import CFGaussianModel
from prfpy.fit import CFFitter 

from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import Wedge


def optimize_connfield_dfree(
    gf,
    r2_threshold=0.05,
    sigma_bounds=[(0.1, 20.0)],
    method='L-BFGS-B',
    tol=0.001,
    verbose=True
):
    """Optimize connective field parameters with bounded scipy.optimize.minimize
    (L-BFGS-B by default), starting from grid search results.

    Only optimizes sigma while keeping the vertex center fixed (standard CF approach).

    Args:
        gf (CFFitter): The CFFitter object after grid_fit has been run
        r2_threshold (float): R² threshold for optimization (default 0.05)
        sigma_bounds (list of tuple): Lower and upper bound on sigma (mm)
        method (str): Optimization method passed to scipy.optimize.minimize
        tol (float): Optimizer tolerance
        verbose (bool): Print progress information

    Returns:
        dict: Contains optimized parameters:
            - 'vertex_indices': Fixed vertex centers from grid search
            - 'sigma': Optimized sigma values
            - 'beta': Optimized beta (amplitude) values
            - 'baseline': Optimized baseline values
            - 'r2': r2 values for optimized fits
    """
    from scipy.optimize import minimize
    from tqdm import tqdm

    ## Get grid search results
    best_vertex_indices = gf.gridsearch_params[:, 0].astype(int)
    best_sigmas = gf.gridsearch_params[:, 1]
    r2_mask = gf.gridsearch_r2 >= r2_threshold
    #r2_mask = gf.gridsearch_params[:, 4] >= r2_threshold

    n_sites = gf.data.shape[0]
    n_optimize = r2_mask.sum()

    if verbose:
        print(f"Optimizing {n_optimize} of {n_sites} sites (r2 >= {r2_threshold})")

    # Get distance matrix and source data from stimulus
    distance_matrix = gf.model.stimulus.distance_matrix
    source_data = gf.model.stimulus.design_matrix  # n_vertices x n_timepoints

    def optimize_single_voxel(voxel_idx):
        """Optimize sigma for a single voxel"""
        vertex_idx = int(best_vertex_indices[voxel_idx])
        sigma_init = best_sigmas[voxel_idx]
        target_ts = gf.data[voxel_idx]

        # Get distances from this vertex center to all source vertices
        distances = distance_matrix[vertex_idx, :]  # Shape: (n_vertices,)

        def objective(sigma):
            """Objective function: minimize negative correlation"""
            # Calculate CF weights
            weights = np.exp(-(distances**2 / (2 * sigma[0]**2)))  # Shape: (n_vertices,)
            # Create CF timecourse - weighted sum across vertices
            # source_data is (n_vertices, n_timepoints), weights is (n_vertices,)
            cf_ts = np.dot(weights, source_data)  # Shape: (n_timepoints,)
            # Calculate correlation
            r = np.corrcoef(target_ts, cf_ts)[0, 1]
            # Return negative correlation (to minimize)
            return -r if not np.isnan(r) else 1.0

        # Optimize
        result = minimize(
            objective,
            x0=[sigma_init],
            bounds=sigma_bounds,
            method=method,
            tol=tol
        )

        # Calculate final r2
        sigma_opt = result.x[0]
        weights_opt = np.exp(-(distances**2 / (2 * sigma_opt**2)))
        cf_ts_opt = np.dot(weights_opt, source_data)
        r_opt = np.corrcoef(target_ts, cf_ts_opt)[0, 1]
        r2_opt = r_opt**2 if not np.isnan(r_opt) else 0.0

        # Estimate beta and baseline using least squares
        X = np.column_stack([cf_ts_opt, np.ones_like(cf_ts_opt)])
        params = np.linalg.lstsq(X, target_ts, rcond=None)[0]
        beta_opt = params[0]
        baseline_opt = params[1]

        return {
            'vertex_idx': vertex_idx,
            'sigma': sigma_opt,
            'beta': beta_opt,
            'baseline': baseline_opt,
            'r2': r2_opt
        }

    # Run optimization with progress bar
    sites_to_optimize = np.where(r2_mask)[0]
    results = []
    for voxel_idx in tqdm(sites_to_optimize, desc="Optimizing CFs", disable=not verbose):
        results.append(optimize_single_voxel(voxel_idx))

    # Initialize output arrays with grid search values
    optimized_params = {
        'vertex_indices': best_vertex_indices.copy(),
        'sigma': best_sigmas.copy(),
        'beta': gf.gridsearch_params[:, 2].copy(),
        'baseline': gf.gridsearch_params[:, 3].copy(),
        'r2': gf.gridsearch_r2.copy()
    }

    # Update with optimized values
    for i, voxel_idx in enumerate(sites_to_optimize):
        optimized_params['sigma'][voxel_idx] = results[i]['sigma']
        optimized_params['beta'][voxel_idx] = results[i]['beta']
        optimized_params['baseline'][voxel_idx] = results[i]['baseline']
        optimized_params['r2'][voxel_idx] = results[i]['r2']

    if verbose:
        sigma_improvement = optimized_params['sigma'][r2_mask] - best_sigmas[r2_mask]
        r2_improvement = optimized_params['r2'][r2_mask] - gf.gridsearch_r2[r2_mask]
        print(f"\nOptimization complete!")
        print(f"Mean sigma change: {np.mean(np.abs(sigma_improvement)):.3f} mm")
        print(f"Mean r2 improvement: {np.mean(r2_improvement):.4f}")
        print(f"Optimized sigma range: {optimized_params['sigma'][r2_mask].min():.2f} - {optimized_params['sigma'][r2_mask].max():.2f} mm")

    return optimized_params
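
## Example (illustrative sketch, not part of the original module): given a prfpy
## CFFitter `gf` on which grid_fit has already been run, refine the CF sizes with
## the derivative-free optimizer above and pull out the refined maps. The `gf`
## object and its grid-search attributes are assumed to exist; the function name
## is hypothetical.
def _example_refine_sigma_dfree(gf):
    # Refine sigma only; vertex centers stay at the grid-search solution.
    params = optimize_connfield_dfree(gf, r2_threshold=0.05, verbose=False)
    refined_sigma = params['sigma']   # per-site CF size (mm)
    refined_r2 = params['r2']         # variance explained after refinement
    return refined_sigma, refined_r2
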
def optimize_connfield_gdescent(
    gf,
    r2_threshold=0.05,
    sigma_bounds=(0.1, 20.0),
    learning_rate=0.01,
    max_iterations=1000,
    convergence_threshold=1e-4,
    batch_size=128,
    verbose=True
):
    """GPU-accelerated PARALLEL CF optimization using TensorFlow."""
    import tensorflow as tf
    import numpy as np
    from tqdm import tqdm

    tf.keras.backend.clear_session()
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError:
            pass

    ## Get grid search results
    best_vertex_indices = gf.gridsearch_params[:, 0].astype(np.int32)
    best_sigmas = gf.gridsearch_params[:, 1].astype(np.float32)
    r2_mask = gf.gridsearch_r2 >= r2_threshold

    n_sites = gf.data.shape[0]
    n_optimize = r2_mask.sum()

    if verbose:
        print(f"\n{'='*60}")
        print("TENSORFLOW GPU PARALLEL OPTIMIZATION")
        print(f"{'='*60}")
        print(f"Optimizing {n_optimize} of {n_sites} sites (r2 >= {r2_threshold})")
        print(f"Batch size: {batch_size}")
        if gpus:
            print(f"GPU: {gpus[0].name}")
        else:
            print("⚠ WARNING: No GPU detected!")

    # Get data
    distance_matrix = gf.model.stimulus.distance_matrix.astype(np.float32)
    source_data = gf.model.stimulus.design_matrix.astype(np.float32)
    target_data = gf.data.astype(np.float32)

    # Convert to TF constants (stays on GPU)
    distance_matrix_tf = tf.constant(distance_matrix, dtype=tf.float32)
    source_data_tf = tf.constant(source_data, dtype=tf.float32)

    # Convert bounds to TF constants
    sigma_min = tf.constant(sigma_bounds[0], dtype=tf.float32)
    sigma_max = tf.constant(sigma_bounds[1], dtype=tf.float32)

    # Initialize outputs
    optimized_params = {
        'vertex_indices': best_vertex_indices.copy(),
        'sigma': best_sigmas.copy(),
        'beta': gf.gridsearch_params[:, 2].copy(),
        'baseline': gf.gridsearch_params[:, 3].copy(),
        'r2': gf.gridsearch_r2.copy()
    }

    sites_to_optimize = np.where(r2_mask)[0]
    n_batches = int(np.ceil(n_optimize / batch_size))

    @tf.function
    def optimize_batch_parallel(log_sigmas, vertex_indices, distances_batch,
                                target_batch, source_data_matrix,
                                sigma_min_val, sigma_max_val):
        """Optimize entire batch in parallel on GPU - ALL VARIABLES AS ARGUMENTS"""
        # Transform sigmas (all at once)
        sigmas = tf.exp(log_sigmas)
        sigmas = tf.clip_by_value(sigmas, sigma_min_val, sigma_max_val)

        # Compute weights for all sites: [batch_size, n_source_vertices]
        sigmas_expanded = tf.expand_dims(sigmas, axis=1)  # [batch_size, 1]
        weights = tf.exp(-(distances_batch ** 2) / (2 * sigmas_expanded ** 2))

        # Compute CF timecourses for all sites: [batch_size, n_timepoints]
        cf_timecourses = tf.matmul(weights, source_data_matrix)

        # Normalize CF timecourses
        cf_mean = tf.reduce_mean(cf_timecourses, axis=1, keepdims=True)
        cf_std = tf.math.reduce_std(cf_timecourses, axis=1, keepdims=True) + 1e-8
        cf_normalized = (cf_timecourses - cf_mean) / cf_std

        # Normalize target timecourses
        target_mean = tf.reduce_mean(target_batch, axis=1, keepdims=True)
        target_std = tf.math.reduce_std(target_batch, axis=1, keepdims=True) + 1e-8
        target_normalized = (target_batch - target_mean) / target_std

        # Compute correlations
        correlations = tf.reduce_mean(cf_normalized * target_normalized, axis=1)

        # Return negative correlation as loss
        loss = -tf.reduce_mean(correlations)
        return loss, sigmas

    # Process in batches
    for batch_idx in tqdm(range(n_batches), desc="GPU batches", disable=not verbose):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, n_optimize)
        batch_voxel_indices = sites_to_optimize[start_idx:end_idx]
        batch_size_actual = len(batch_voxel_indices)

        # Get batch data
        batch_vertices = best_vertex_indices[batch_voxel_indices]
        batch_init_sigmas = best_sigmas[batch_voxel_indices]

        # Initialize log sigmas for batch
        log_sigmas_init = np.log(np.clip(batch_init_sigmas, sigma_bounds[0] + 0.01, sigma_bounds[1] - 0.01))
        log_sigmas = tf.Variable(log_sigmas_init, dtype=tf.float32, trainable=True)

        # Get distances for all sites in batch: [batch_size, n_source_vertices]
        distances_batch = tf.gather(distance_matrix_tf, batch_vertices, axis=0)

        # Get target data for batch: [batch_size, n_timepoints]
        target_batch = tf.constant(target_data[batch_voxel_indices], dtype=tf.float32)

        # Create single optimizer for entire batch
        optimizer = tf.optimizers.Adam(learning_rate=learning_rate)

        prev_loss = float('inf')
        no_improvement = 0

        # Optimize all sites in batch simultaneously
        for iteration in range(max_iterations):
            with tf.GradientTape() as tape:
                # Pass ALL variables as arguments (no closures!)
                loss, sigmas_opt = optimize_batch_parallel(
                    log_sigmas, batch_vertices, distances_batch,
                    target_batch, source_data_tf,
                    sigma_min, sigma_max  # Pass bounds as arguments
                )

            # Compute gradients for all sigmas at once
            gradients = tape.gradient(loss, [log_sigmas])
            optimizer.apply_gradients(zip(gradients, [log_sigmas]))

            # Check convergence every 10 iterations
            if iteration % 10 == 0:
                current_loss = float(loss.numpy())
                if abs(prev_loss - current_loss) < convergence_threshold:
                    no_improvement += 1
                    if no_improvement >= 3:
                        break
                else:
                    no_improvement = 0
                prev_loss = current_loss

        # Extract final optimized sigmas
        _, sigmas_final = optimize_batch_parallel(
            log_sigmas, batch_vertices, distances_batch,
            target_batch, source_data_tf,
            sigma_min, sigma_max
        )
        sigmas_opt_np = sigmas_final.numpy()

        # Compute final parameters (this part is CPU-bound but fast)
        for i, voxel_idx in enumerate(batch_voxel_indices):
            vertex_idx = int(batch_vertices[i])
            sigma_opt = float(sigmas_opt_np[i])

            # Compute final CF timecourse
            distances_np = distance_matrix[vertex_idx, :]
            weights_opt = np.exp(-(distances_np ** 2) / (2 * sigma_opt ** 2))
            cf_ts_opt = np.dot(weights_opt, source_data)

            # Fit beta and baseline
            target_ts_np = target_data[voxel_idx]
            X = np.column_stack([cf_ts_opt, np.ones_like(cf_ts_opt)])
            params = np.linalg.lstsq(X, target_ts_np, rcond=None)[0]
            beta_opt = float(params[0])
            baseline_opt = float(params[1])

            # Compute R²
            r_opt = np.corrcoef(target_ts_np, cf_ts_opt)[0, 1]
            r2_opt = float(r_opt**2) if not np.isnan(r_opt) else 0.0

            # Store results
            optimized_params['sigma'][voxel_idx] = sigma_opt
            optimized_params['beta'][voxel_idx] = beta_opt
            optimized_params['baseline'][voxel_idx] = baseline_opt
            optimized_params['r2'][voxel_idx] = r2_opt

    tf.keras.backend.clear_session()

    if verbose:
        sigma_improvement = optimized_params['sigma'][r2_mask] - best_sigmas[r2_mask]
        r2_improvement = optimized_params['r2'][r2_mask] - gf.gridsearch_r2[r2_mask]
        print(f"\n{'='*60}")
        print("OPTIMIZATION COMPLETE")
        print(f"{'='*60}")
        print(f"\nSigma changes:")
        print(f"  Mean |Δσ|: {np.mean(np.abs(sigma_improvement)):.3f} mm")
        print(f"  Max |Δσ|: {np.max(np.abs(sigma_improvement)):.3f} mm")
        print(f"\nR² improvements:")
        print(f"  Mean ΔR²: {np.mean(r2_improvement):.4f}")
        print(f"  Sites improved: {(r2_improvement > 0).sum()} ({100*(r2_improvement > 0).sum()/n_optimize:.1f}%)")
        print(f"{'='*60}")

    return optimized_params
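
## Example (illustrative sketch, hypothetical function name): the TensorFlow
## variant exposes the same interface as the scipy path; `batch_size` trades GPU
## memory for throughput and `max_iterations` caps the Adam updates per batch.
## `gf` is assumed to be a CFFitter with grid-search results, as above.
def _example_refine_sigma_gpu(gf):
    params = optimize_connfield_gdescent(
        gf,
        r2_threshold=0.05,
        batch_size=256,
        learning_rate=0.01,
        max_iterations=500,
        verbose=False,
    )
    return params['sigma'], params['r2']
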
def optimize_connfield_joint(
    gf,
    r2_threshold=0.05,
    sigma_bounds=(0.1, 20.0),
    max_outer_iterations=3,
    max_inner_iterations=300,
    search_radius=10.0,
    batch_size=256,
    learning_rate=0.01,
    verbose=True
):
    """GPU-parallelized alternating position-sigma optimization.

    Processes multiple sites in parallel on GPU for both:
    1. Position search (vectorized evaluation)
    2. Sigma optimization (batched gradient descent)

    Returns:
        dict: Contains optimized parameters and cycle-wise statistics
    """
    import tensorflow as tf
    import numpy as np
    from tqdm import tqdm

    tf.keras.backend.clear_session()
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            try:
                tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError:
                pass

    ## Get grid search results
    best_vertex_indices = gf.gridsearch_params[:, 0].astype(np.int32)
    best_sigmas = gf.gridsearch_params[:, 1].astype(np.float32)
    r2_mask = gf.gridsearch_r2 >= r2_threshold

    n_sites = gf.data.shape[0]
    n_optimize = r2_mask.sum()

    if verbose:
        print(f"\n{'='*60}")
        print("GPU-PARALLELIZED ALTERNATING OPTIMIZATION")
        print(f"{'='*60}")
        print(f"Optimizing {n_optimize} of {n_sites} sites")
        print(f"Cycles: {max_outer_iterations}, Batch: {batch_size}")
        if gpus:
            print(f"GPU: {gpus[0].name}")
        else:
            print("⚠ WARNING: No GPU!")

    # Get data
    distance_matrix = gf.model.stimulus.distance_matrix.astype(np.float32)
    source_data = gf.model.stimulus.design_matrix.astype(np.float32)
    target_data = gf.data.astype(np.float32)

    # Move to GPU
    distance_matrix_tf = tf.constant(distance_matrix, dtype=tf.float32)
    source_data_tf = tf.constant(source_data, dtype=tf.float32)
    target_data_tf = tf.constant(target_data, dtype=tf.float32)

    # Convert bounds to TF constants
    sigma_min = tf.constant(sigma_bounds[0], dtype=tf.float32)
    sigma_max = tf.constant(sigma_bounds[1], dtype=tf.float32)

    # Pre-compute neighbor lists (CPU)
    n_source_vertices = distance_matrix.shape[0]
    max_neighbors = 0
    neighbor_lists = []
    for vertex_idx in range(n_source_vertices):
        neighbors = np.where(distance_matrix[vertex_idx, :] <= search_radius)[0]
        neighbor_lists.append(neighbors)
        max_neighbors = max(max_neighbors, len(neighbors))

    # Pad neighbor lists for GPU (all same length)
    neighbor_array = np.full((n_source_vertices, max_neighbors), -1, dtype=np.int32)
    neighbor_counts = np.zeros(n_source_vertices, dtype=np.int32)
    for i, neighbors in enumerate(neighbor_lists):
        n = len(neighbors)
        neighbor_array[i, :n] = neighbors
        neighbor_counts[i] = n

    neighbor_array_tf = tf.constant(neighbor_array, dtype=tf.int32)
    neighbor_counts_tf = tf.constant(neighbor_counts, dtype=tf.int32)

    if verbose:
        print(f"Max neighbors: {max_neighbors}, Search radius: {search_radius} mm")

    # Initialize current state
    current_vertices = best_vertex_indices.copy()
    current_sigmas = best_sigmas.copy()
    current_r2 = gf.gridsearch_r2.copy()

    sites_to_optimize = np.where(r2_mask)[0]
    n_batches = int(np.ceil(n_optimize / batch_size))

    # Initialize cycle statistics storage
    cycle_stats = {
        'cycle': [],
        'n_improved': [],
        'percent_improved': [],
        'n_position_changes': [],
        'percent_position_changes': [],
        'mean_sigma_change': [],
        'median_sigma_change': [],
        'max_sigma_change': [],
        'mean_r2': [],
        'mean_r2_improvement': [],
        'median_r2_improvement': []
    }

    @tf.function
    def evaluate_positions_batch(vertex_candidates, sigmas, distances_batch,
                                 target_batch, source_data_matrix,
                                 sigma_min_val, sigma_max_val):
        """Evaluate all candidate positions for a batch of sites in parallel."""
        sigmas_expanded = tf.reshape(sigmas, [-1, 1, 1])
        weights = tf.exp(-(distances_batch ** 2) / (2 * sigmas_expanded ** 2))
        cf_timeseries = tf.einsum('bcv,vt->bct', weights, source_data_matrix)
        cf_mean = tf.reduce_mean(cf_timeseries, axis=2, keepdims=True)
        cf_std = tf.math.reduce_std(cf_timeseries, axis=2, keepdims=True) + 1e-8
        cf_normalized = (cf_timeseries - cf_mean) / cf_std
        target_expanded = tf.expand_dims(target_batch, axis=1)
        target_mean = tf.reduce_mean(target_expanded, axis=2, keepdims=True)
        target_std = tf.math.reduce_std(target_expanded, axis=2, keepdims=True) + 1e-8
        target_normalized = (target_expanded - target_mean) / target_std
        correlations = tf.reduce_mean(cf_normalized * target_normalized, axis=2)
        return correlations

    @tf.function
    def optimize_sigma_batch(log_sigmas, vertex_indices, distances_batch,
                             target_batch, source_data_matrix,
                             sigma_min_val, sigma_max_val):
        """Optimize sigma for a batch of sites (fixed positions) in parallel."""
        sigmas = tf.exp(log_sigmas)
        sigmas = tf.clip_by_value(sigmas, sigma_min_val, sigma_max_val)
        sigmas_expanded = tf.expand_dims(sigmas, axis=1)
        weights = tf.exp(-(distances_batch ** 2) / (2 * sigmas_expanded ** 2))
        cf_timeseries = tf.matmul(weights, source_data_matrix)
        cf_mean = tf.reduce_mean(cf_timeseries, axis=1, keepdims=True)
        cf_std = tf.math.reduce_std(cf_timeseries, axis=1, keepdims=True) + 1e-8
        cf_normalized = (cf_timeseries - cf_mean) / cf_std
        target_mean = tf.reduce_mean(target_batch, axis=1, keepdims=True)
        target_std = tf.math.reduce_std(target_batch, axis=1, keepdims=True) + 1e-8
        target_normalized = (target_batch - target_mean) / target_std
        correlations = tf.reduce_mean(cf_normalized * target_normalized, axis=1)
        loss = -tf.reduce_mean(correlations)
        return loss, sigmas

    # Store initial state for comparison
    initial_vertices = current_vertices.copy()
    initial_sigmas = current_sigmas.copy()
    initial_r2 = current_r2.copy()

    # ALTERNATING OPTIMIZATION CYCLES
    for outer_iter in range(max_outer_iterations):
        if verbose:
            print(f"\n--- Cycle {outer_iter + 1}/{max_outer_iterations} ---")

        # Store state at beginning of cycle
        cycle_start_vertices = current_vertices.copy()
        cycle_start_sigmas = current_sigmas.copy()
        cycle_start_r2 = current_r2.copy()
        cycle_improvements = 0

        # Process in batches
        for batch_idx in tqdm(range(n_batches), desc=f"Cycle {outer_iter+1}", disable=not verbose):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, n_optimize)
            batch_voxel_indices = sites_to_optimize[start_idx:end_idx]
            batch_size_actual = len(batch_voxel_indices)

            batch_vertices = current_vertices[batch_voxel_indices]
            batch_sigmas = current_sigmas[batch_voxel_indices]
            batch_r2 = current_r2[batch_voxel_indices]
            batch_target = target_data[batch_voxel_indices]

            # STEP 1: OPTIMIZE POSITION
            batch_neighbors = neighbor_array[batch_vertices]
            batch_n_neighbors = neighbor_counts[batch_vertices]

            batch_neighbors_tf = tf.constant(batch_neighbors, dtype=tf.int32)
            batch_sigmas_tf = tf.constant(batch_sigmas, dtype=tf.float32)
            batch_target_tf = tf.constant(batch_target, dtype=tf.float32)

            distances_batch = tf.gather(distance_matrix_tf, batch_neighbors_tf, axis=0)
            valid_mask = batch_neighbors_tf >= 0

            correlations = evaluate_positions_batch(
                batch_neighbors_tf, batch_sigmas_tf, distances_batch,
                batch_target_tf, source_data_tf,
                sigma_min, sigma_max
            )
            correlations = tf.where(valid_mask, correlations, -1.0)

            best_candidate_indices = tf.argmax(correlations, axis=1, output_type=tf.int32)
            best_candidate_indices_np = best_candidate_indices.numpy()
            new_vertices = np.array([
                batch_neighbors[i, best_candidate_indices_np[i]]
                for i in range(batch_size_actual)
            ])

            # STEP 2: OPTIMIZE SIGMA
            log_sigmas_init = np.log(np.clip(batch_sigmas, sigma_bounds[0] + 0.01, sigma_bounds[1] - 0.01))
            log_sigmas = tf.Variable(log_sigmas_init, dtype=tf.float32, trainable=True)

            new_vertices_tf = tf.constant(new_vertices, dtype=tf.int32)
            distances_batch_sigma = tf.gather(distance_matrix_tf, new_vertices_tf, axis=0)

            optimizer = tf.optimizers.Adam(learning_rate=learning_rate)

            prev_loss = float('inf')
            no_improvement = 0

            for inner_iter in range(max_inner_iterations):
                with tf.GradientTape() as tape:
                    loss, _ = optimize_sigma_batch(
                        log_sigmas, new_vertices_tf, distances_batch_sigma,
                        batch_target_tf, source_data_tf,
                        sigma_min, sigma_max
                    )

                gradients = tape.gradient(loss, [log_sigmas])
                optimizer.apply_gradients(zip(gradients, [log_sigmas]))

                if inner_iter % 10 == 0:
                    current_loss = float(loss.numpy())
                    if abs(prev_loss - current_loss) < 1e-4:
                        no_improvement += 1
                        if no_improvement >= 3:
                            break
                    else:
                        no_improvement = 0
                    prev_loss = current_loss

            _, sigmas_opt_tf = optimize_sigma_batch(
                log_sigmas, new_vertices_tf, distances_batch_sigma,
                batch_target_tf, source_data_tf,
                sigma_min, sigma_max
            )
            new_sigmas = sigmas_opt_tf.numpy()

            # STEP 3: Compute final R² and update if improved
            for i in range(batch_size_actual):
                voxel_idx = batch_voxel_indices[i]
                vertex_idx = int(new_vertices[i])
                sigma = float(new_sigmas[i])
                target_ts = batch_target[i]

                distances_np = distance_matrix[vertex_idx, :]
                weights = np.exp(-(distances_np ** 2) / (2 * sigma ** 2))
                cf_ts = np.dot(weights, source_data)

                r = np.corrcoef(target_ts, cf_ts)[0, 1]
                r2 = r**2 if not np.isnan(r) else 0.0

                if r2 > batch_r2[i]:
                    current_vertices[voxel_idx] = vertex_idx
                    current_sigmas[voxel_idx] = sigma
                    current_r2[voxel_idx] = r2
                    cycle_improvements += 1

        # Calculate cycle statistics
        position_changes = (current_vertices[r2_mask] != cycle_start_vertices[r2_mask])
        n_position_changes = position_changes.sum()
        sigma_changes = current_sigmas[r2_mask] - cycle_start_sigmas[r2_mask]
        abs_sigma_changes = np.abs(sigma_changes)
        r2_improvements = current_r2[r2_mask] - cycle_start_r2[r2_mask]

        # Store cycle statistics
        cycle_stats['cycle'].append(outer_iter + 1)
        cycle_stats['n_improved'].append(int(cycle_improvements))
        cycle_stats['percent_improved'].append(float(100 * cycle_improvements / n_optimize))
        cycle_stats['n_position_changes'].append(int(n_position_changes))
        cycle_stats['percent_position_changes'].append(float(100 * n_position_changes / n_optimize))
        cycle_stats['mean_sigma_change'].append(float(abs_sigma_changes.mean()))
        cycle_stats['median_sigma_change'].append(float(np.median(abs_sigma_changes)))
        cycle_stats['max_sigma_change'].append(float(abs_sigma_changes.max()))
        cycle_stats['mean_r2'].append(float(current_r2[r2_mask].mean()))
        cycle_stats['mean_r2_improvement'].append(float(r2_improvements.mean()))
        cycle_stats['median_r2_improvement'].append(float(np.median(r2_improvements)))

        if verbose:
            print(f"  Improved: {cycle_improvements}/{n_optimize} sites ({100*cycle_improvements/n_optimize:.1f}%)")
            print(f"  Position changes: {n_position_changes} sites ({100*n_position_changes/n_optimize:.1f}%)")
            print(f"  Mean |Δσ|: {abs_sigma_changes.mean():.3f} mm")
            print(f"  Mean R²: {current_r2[r2_mask].mean():.4f}")
            print(f"  Mean ΔR²: {r2_improvements.mean():.4f}")

        # Stop if converged
        if cycle_improvements == 0 and outer_iter > 0:
            if verbose:
                print(f"\nConverged after {outer_iter + 1} cycles")
            break

    # Compute final beta and baseline
    optimized_params = {
        'vertex_indices': current_vertices,
        'sigma': current_sigmas,
        'beta': gf.gridsearch_params[:, 2].copy(),
        'baseline': gf.gridsearch_params[:, 3].copy(),
        'r2': current_r2,
        'cycle_statistics': cycle_stats  # NEW: Add cycle-wise statistics
    }

    # Recompute beta/baseline for optimized sites
    for voxel_idx in sites_to_optimize:
        vertex_idx = int(current_vertices[voxel_idx])
        sigma = float(current_sigmas[voxel_idx])
        target_ts = target_data[voxel_idx]

        distances_np = distance_matrix[vertex_idx, :]
        weights = np.exp(-(distances_np ** 2) / (2 * sigma ** 2))
        cf_ts = np.dot(weights, source_data)

        X = np.column_stack([cf_ts, np.ones_like(cf_ts)])
        params = np.linalg.lstsq(X, target_ts, rcond=None)[0]
        optimized_params['beta'][voxel_idx] = float(params[0])
        optimized_params['baseline'][voxel_idx] = float(params[1])

    tf.keras.backend.clear_session()

    if verbose:
        print(f"\n{'='*60}")
        print("FINAL RESULTS (from initial grid search)")
        print(f"{'='*60}")
        position_changed = (optimized_params['vertex_indices'][r2_mask] != initial_vertices[r2_mask])
        sigma_change = np.abs(optimized_params['sigma'][r2_mask] - initial_sigmas[r2_mask])
        r2_improvement = optimized_params['r2'][r2_mask] - initial_r2[r2_mask]
        print(f"\nTotal position changes:")
        print(f"  Sites moved: {position_changed.sum()} ({100*position_changed.sum()/n_optimize:.1f}%)")
        print(f"\nTotal sigma changes:")
        print(f"  Mean |Δσ|: {sigma_change.mean():.3f} mm")
        print(f"  Median |Δσ|: {np.median(sigma_change):.3f} mm")
        print(f"  Max |Δσ|: {sigma_change.max():.3f} mm")
        print(f"\nTotal R² improvements:")
        print(f"  Mean ΔR²: {r2_improvement.mean():.4f}")
        print(f"  Median ΔR²: {np.median(r2_improvement):.4f}")
        print(f"  Sites improved: {(r2_improvement > 0).sum()} ({100*(r2_improvement > 0).sum()/n_optimize:.1f}%)")
        print(f"  Mean final R²: {optimized_params['r2'][r2_mask].mean():.4f}")
        print(f"{'='*60}")

    return optimized_params
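
## Example (illustrative sketch, hypothetical function name): alternating
## position/sigma refinement. The returned dictionary carries per-cycle
## convergence statistics under 'cycle_statistics', which the plotting helper
## below turns into a summary table. `gf` is assumed to be a CFFitter with
## grid-search results.
def _example_joint_refinement(gf):
    params = optimize_connfield_joint(
        gf,
        r2_threshold=0.05,
        max_outer_iterations=3,
        search_radius=10.0,
        batch_size=256,
        verbose=False,
    )
    return params
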
## Joint optimization cycles
def plot_convergence_summary_table(cycle_stats):
    """Create a summary table of convergence statistics."""
    import pandas as pd
    import matplotlib.pyplot as plt

    # Create DataFrame
    df = pd.DataFrame(cycle_stats)

    # Add computed columns
    df['convergence_rate'] = df['percent_improved'].diff().fillna(0)

    # Create figure
    fig, ax = plt.subplots(figsize=(14, max(4, len(df) * 0.4)), dpi=150)
    ax.axis('tight')
    ax.axis('off')

    # Format DataFrame for display
    display_df = df[[
        'cycle', 'n_improved', 'percent_improved',
        'n_position_changes', 'percent_position_changes',
        'mean_sigma_change', 'mean_r2', 'mean_r2_improvement'
    ]].copy()

    display_df.columns = [
        'Cycle', 'N Improved', '% Improved',
        'N Pos. changes', '% Pos. changes',
        'Mean |Δσ| (mm)', 'Mean R²', 'Mean ΔR²'
    ]

    # Format numeric columns
    display_df['% Improved'] = display_df['% Improved'].map('{:.2f}%'.format)
    display_df['% Pos. changes'] = display_df['% Pos. changes'].map('{:.2f}%'.format)
    display_df['Mean |Δσ| (mm)'] = display_df['Mean |Δσ| (mm)'].map('{:.3f}'.format)
    display_df['Mean R²'] = display_df['Mean R²'].map('{:.4f}'.format)
    display_df['Mean ΔR²'] = display_df['Mean ΔR²'].map('{:.5f}'.format)

    # Create table
    table = ax.table(cellText=display_df.values,
                     colLabels=display_df.columns,
                     cellLoc='center',
                     loc='center',
                     bbox=[0, 0, 1, 1])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Style header
    for i in range(len(display_df.columns)):
        cell = table[(0, i)]
        cell.set_facecolor('#2E86AB')
        cell.set_text_props(weight='bold', color='white')

    # Alternate row colors
    for i in range(1, len(display_df) + 1):
        for j in range(len(display_df.columns)):
            cell = table[(i, j)]
            if i % 2 == 0:
                cell.set_facecolor('#F0F0F0')
            else:
                cell.set_facecolor('#FFFFFF')

    plt.title('Joint-optimization summary', fontsize=14, fontweight='bold', pad=20)

    return fig
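
## Example (illustrative sketch, hypothetical function name and output path):
## build the convergence table from the cycle-wise statistics returned by
## optimize_connfield_joint and save it to disk.
def _example_convergence_table(joint_params):
    fig = plot_convergence_summary_table(joint_params['cycle_statistics'])
    fig.savefig('joint_optimization_summary.png', bbox_inches='tight')
    return fig
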
## Compare two optimization strategies
def compare_cf_results(
    results_A,
    results_B,
    ecc_pRF,
    pol_pRF,
    target_roi_mask,
    flatmaps,
    colors_ecc,
    colors_polar,
    h='lh',
    r2_threshold=0.1,
    ecc_div=2,
    figsize=(16, 8),
    dpi=300
):
    """Compare connective field results from grid search vs optimization.

    Args:
        results_A (dict): Dictionary with keys 'centers', 'sigma', 'r2', 'opType'
        results_B (dict): Dictionary with keys 'centers', 'sigma', 'r2', 'opType'
        ecc_pRF (np.array): Eccentricity values for source vertices
        pol_pRF (np.array): Polar angle values for source vertices
        target_roi_mask (np.array): Boolean mask for target vertices
        flatmaps (dict): Dictionary of flatmap objects
        colors_ecc (dict): Eccentricity color palette
        colors_polar (dict): Polar angle color palette
        h (str): Hemisphere ('lh' or 'rh')
        r2_threshold (float): R² threshold for visualization
        ecc_div (int): Divisor applied to the maximum eccentricity to set the upper color limit
        figsize (tuple): Figure size
        dpi (int): Figure DPI

    Returns:
        matplotlib.figure.Figure: Comparison figure
    """
    # Create figure with two rows (grid search top, optimized bottom)
    # Add extra column for row labels
    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = fig.add_gridspec(2, 5, width_ratios=[1, 1, 1, 1, 0.15])

    CF_ecc_A = ecc_pRF[results_A['centers'].astype(int)]
    CF_pol_A = pol_pRF[results_A['centers'].astype(int)]
    sigma_A = results_A['sigma']
    r2_A = results_A['r2']

    CF_ecc_B = ecc_pRF[results_B['centers'].astype(int)]
    CF_pol_B = pol_pRF[results_B['centers'].astype(int)]
    sigma_B = results_B['sigma']
    r2_B = results_B['r2']

    # Plot both rows
    for row_idx, (CF_ecc, CF_polar, sigma, r2, row_label) in enumerate([
        (CF_ecc_A, CF_pol_A, sigma_A, r2_A, results_A['opType']),
        (CF_ecc_B, CF_pol_B, sigma_B, r2_B, results_B['opType'])
    ]):
        # Create maps
        ecc_map = np.full(target_roi_mask.shape[0], np.nan)
        ecc_map[target_roi_mask] = CF_ecc
        polar_map = np.full(target_roi_mask.shape[0], np.nan)
        polar_map[target_roi_mask] = CF_polar
        sigma_map = np.full(target_roi_mask.shape[0], np.nan)
        sigma_map[target_roi_mask] = sigma
        r2_map = np.full(target_roi_mask.shape[0], np.nan)
        r2_map[target_roi_mask] = r2

        # Create mask for R² threshold
        threshold_mask = r2_map >= r2_threshold

        # Get axes for this row
        left_ax = fig.add_subplot(gs[row_idx, 0])
        left_middle_ax = fig.add_subplot(gs[row_idx, 1])
        right_middle_ax = fig.add_subplot(gs[row_idx, 2])
        right_ax = fig.add_subplot(gs[row_idx, 3])
        label_ax = fig.add_subplot(gs[row_idx, 4])

        # Eccentricity plot
        ny.cortex_plot(
            flatmaps[h],
            axes=left_ax,
            color=ecc_map,
            cmap=colors_ecc['matplotlib_cmap'],
            mask=threshold_mask,
            vmin=np.nanmin(ecc_map),
            vmax=np.nanmax(ecc_map)/ecc_div,
        )
        left_ax.set_aspect('equal')
        if row_idx == 0:  # Only add column titles to top row
            left_ax.set_title('CF eccentricity', pad=8, fontsize=20)
        left_ax.axis('off')

        # Polar angle plot
        ny.cortex_plot(
            flatmaps[h],
            axes=left_middle_ax,
            color=polar_map,
            cmap=colors_polar['matplotlib_cmap'],
            mask=threshold_mask,
        )
        left_middle_ax.set_aspect('equal')
        if row_idx == 0:
            left_middle_ax.set_title('CF polar angle', pad=8, fontsize=20)
        left_middle_ax.axis('off')

        # CF Size plot
        size_vmin = np.nanmin(sigma_map)
        size_vmax = np.nanmax(sigma_map)
        size_cmap = plt.cm.jet
        ny.cortex_plot(
            flatmaps[h],
            axes=right_middle_ax,
            color=sigma_map,
            cmap=size_cmap,
            mask=threshold_mask,
            vmin=size_vmin,
            vmax=size_vmax,
        )
        right_middle_ax.set_aspect('equal')
        if row_idx == 0:
            right_middle_ax.set_title('CF size', pad=8, fontsize=20)
        right_middle_ax.axis('off')

        # Variance explained plot
        varex_cmap = plt.cm.inferno
        ny.cortex_plot(
            flatmaps[h],
            axes=right_ax,
            color=r2_map,
            cmap=varex_cmap,
            mask=threshold_mask,
            vmin=0,
            vmax=1,
        )
        right_ax.set_aspect('equal')
        if row_idx == 0:
            right_ax.set_title('Variance explained', pad=8, fontsize=20)
        right_ax.axis('off')

        # Add vertical row label on the right
        label_ax.axis('off')
        label_ax.text(0.5, 0.5, row_label, rotation=90, fontsize=14,
                      fontweight='bold', ha='center', va='center',
                      transform=label_ax.transAxes)

        # Add colorbars only to bottom row
        if row_idx == 1:
            # Eccentricity inset
            ecc_inset = inset_axes(left_ax, width="50%", height="50%",
                                   loc="lower right", borderpad=-6)
            ecc_inset.set_aspect('equal')
            ecc_inset.set_xlim(-1.5, 1.5)
            ecc_inset.set_ylim(-1.5, 1.5)
            ecc_inset.text(0.5, -0.05, r'CF center $\rho\ (\mathit{deg})$',
                           ha='center', va='top', fontsize=14,
                           transform=ecc_inset.transAxes)
            ecc_inset.set_axis_off()
            num_ecc_colors = len(colors_ecc["hex"])
            for i, color in enumerate(colors_ecc["hex"]):
                inner_r = i / num_ecc_colors
                outer_r = (i + 1) / num_ecc_colors
                ring = Wedge((0, 0), outer_r, 0, 360, width=outer_r - inner_r, color=color)
                ecc_inset.add_patch(ring)

            # Polar angle inset
            polar_inset = inset_axes(left_middle_ax, width="40%", height="40%",
                                     loc="lower right", borderpad=-6)
            polar_inset.set_aspect('equal')
            polar_inset.set_axis_off()
            polar_inset.pie([1]*len(colors_polar["hex"]), colors=colors_polar["hex"],
                            startangle=180, counterclock=False)
            polar_inset.text(0.5, -0.05, r'CF center $\theta\ (\mathit{rad})$',
                             ha='center', va='top', fontsize=14,
                             transform=polar_inset.transAxes)

            # CF Size colorbar
            sigma_rect_ax = inset_axes(right_middle_ax, width="30%", height="10%",
                                       loc="lower right", borderpad=-3)
            gradient = np.linspace(0, 1, 256).reshape(1, -1)
            gradient = np.vstack((gradient, gradient))
            sigma_rect_ax.imshow(gradient, aspect='auto', cmap=size_cmap, extent=[0, 1, 0, 1])
            sigma_rect_ax.text(0, -0.3, f'{size_vmin:.2f}', ha='left', va='top', fontsize=10)
            sigma_rect_ax.text(1, -0.3, f'{size_vmax:.2f}', ha='right', va='top', fontsize=10)
            sigma_rect_ax.text(0.5, 1.3, r'CF $\sigma\ (\mathit{mm})$', ha='center',
                               va='bottom', fontsize=14, transform=sigma_rect_ax.transAxes)
            sigma_rect_ax.axis('off')

            # Variance explained colorbar
            varex_rect_ax = inset_axes(right_ax, width="30%", height="10%",
                                       loc="lower right", borderpad=-3)
            varex_rect_ax.imshow(gradient, aspect='auto', cmap=varex_cmap, extent=[0, 1, 0, 1])
            varex_rect_ax.text(0, -0.3, '0', ha='left', va='top', fontsize=12)
            varex_rect_ax.text(1, -0.3, '1', ha='right', va='top', fontsize=12)
            varex_rect_ax.text(0.5, 1.3, r'$\mathit{r}\!{}^2$', ha='center', va='bottom',
                               fontsize=14, transform=varex_rect_ax.transAxes)
            varex_rect_ax.axis('off')

    plt.tight_layout()

    # Print comparison statistics
    print("\n" + "="*60)
    print("Optimization approach comparison")
    print("="*60)
    mask = r2_A > r2_threshold
    print(f"\nR² threshold: {r2_threshold}")
    print(f"Sites above threshold: {mask.sum()}")
    print(f"\n")
    print(results_A['opType'])
    print(f"Sigma range: {sigma_A[mask].min():.2f} - {sigma_A[mask].max():.2f} mm")
    print(f"Mean sigma: {sigma_A[mask].mean():.2f} mm")
    print(f"Mean R²: {r2_A[mask].mean():.4f}")
    print(f"\n")
    print(results_B['opType'])
    print(f"Sigma range: {sigma_B[mask].min():.2f} - {sigma_B[mask].max():.2f} mm")
    print(f"Mean sigma: {sigma_B[mask].mean():.2f} mm")
    print(f"Mean R²: {r2_B[mask].mean():.4f}")

    print("\n--- Improvements ---")
    sigma_change = np.abs(sigma_B[mask] - sigma_A[mask])
    r2_improvement = r2_B[mask] - r2_A[mask]
    print(f"Mean |Δσ|: {sigma_change.mean():.3f} mm")
    print(f"Mean ΔR²: {r2_improvement.mean():.4f}")
    print(f"Median ΔR²: {np.median(r2_improvement):.4f}")
    print(f"Sites with improved R²: {(r2_improvement > 0).sum()} ({100*(r2_improvement > 0).sum()/mask.sum():.1f}%)")
    print("="*60)

    return fig