# Source code for tfrStats.plot_dmvtfr_stats

import numpy as np
from statsmodels.distributions.empirical_distribution import ECDF
import matplotlib.pyplot as plt
from scipy.interpolate import interp2d
import matplotlib.colors as colors
from tfrStats.stats_tfrs_davg import stats_tfrs_davg as stats_tfrs_davg  


import warnings
# NOTE(review): this installs a process-wide "ignore all warnings" filter as
# soon as the module is imported, not just while plotting. It would also hide
# deprecation warnings from the libraries used here (e.g. interp2d). Consider
# scoping it with warnings.catch_warnings() inside the plotting function.
warnings.filterwarnings('ignore')

def _color_offset(min_val, max_val, k):
    """Return the value a fraction ``k`` of the way from ``min_val`` to ``max_val``.

    Used to place the centre of the diverging colour scale.

    :param float min_val: lower end of the range.
    :param float max_val: upper end of the range.
    :param float k: relative position; must lie in ``[0, 1]``.
    :return: ``min_val + k * (max_val - min_val)``.
    :raises ValueError: if ``k`` is outside ``[0, 1]`` (the former nested
        helper failed with ``UnboundLocalError`` in that case).
    """
    if not 0 <= k <= 1:
        raise ValueError(f'k must be between 0 and 1, got {k}')
    return min_val + k * (max_val - min_val)


def _peak_band_slice(pk, sigma, n_freqs):
    """Return the slice of ``2 * sigma`` frequency bins averaged around peak bin ``pk``.

    Near the low edge the window starts at the peak. Near the high edge it
    ends just before the peak. Otherwise it spans ``[pk - sigma, pk + sigma)``.
    """
    if pk - sigma <= 0:
        return slice(pk, pk + 2 * sigma)
    if pk + sigma >= n_freqs:
        # NOTE(review): this window stops *before* ``pk``, so the peak bin
        # itself is excluded. Kept unchanged so existing results reproduce.
        return slice(pk - 2 * sigma, pk)
    return slice(pk - sigma, pk + sigma)


def _upsample(x, y, z, x_new, y_new):
    """Linearly interpolate ``z`` (rows along ``y``, columns along ``x``) onto a new grid.

    ``z`` is flipped vertically before fitting, which matches the orientation
    used by the plots in this module.

    NOTE(review): ``scipy.interpolate.interp2d`` is deprecated (SciPy 1.10)
    and removed in SciPy 1.14; migrate to ``RegularGridInterpolator``.
    """
    return interp2d(x, y, np.flipud(z), kind='linear')(x_new, y_new)


def plot_dmvtfr_stats(cond, tfr_emp, tfr_null, fband, alpha, correction):
    """Plot a peak-band-averaged empirical TFR together with its statistics.

    Steps:

    1. Average the empirical TFR over its first axis, then over the
       ``stats_range`` time window, to get a power-per-frequency profile.
    2. Use the peak of that profile to select a ``2 * sigma``-bin
       frequency band.
    3. Restrict both the empirical and the null data to that band, and
       average the empirical map across the band.
    4. Pass both to ``stats_tfrs_davg``, whose output is plotted as p-values.

    The top panel shows the band-averaged empirical map, with the
    ``p <= alpha`` region outlined in blue. The bottom panel shows the
    p-values. Both maps are upsampled by linear interpolation for display.
    Input shapes, the selected band and its mean power are printed to stdout.

    .. todo::
        * Handle parameters with a dictionary.

    :param cond: condition index. Not used by this function; kept for
        signature compatibility.
    :param numpy.ndarray tfr_emp: empirical TFR, indexed here as
        ``(row, frequency, time)``. The display grid assumes 12 rows and
        ``tps[fband]`` time bins. Presumably the shape is
        ``(n_sites, n_freqs, n_time)``; TODO confirm.
    :param numpy.ndarray tfr_null: null-distribution TFR. It is 4-D, with the
        frequency axis at index 2, and is sliced with the same band as
        ``tfr_emp``.
    :param int fband: frequency-band index: 0, 1 or 2 (low, high, higher).
    :param float alpha: p-value threshold for the significance outline.
    :param correction: forwarded to ``stats_tfrs_davg`` and echoed in the
        figure caption.
    :return: None. The figure is drawn with ``matplotlib.pyplot``.

    @author: Nicolas Gravel, 19.09.2023
    """
    # Per-band parameters, indexed by ``fband``.
    tps = [57, 113, 141, 140]  # number of time bins
    fps = [19, 16, 11, 1]  # number of frequency bins
    lp = [2, 20, 80]  # low cut
    hp = [20, 80, 200]  # high cut

    t_start, t_stop = -800, 2000  # time-axis limits (ms)
    stats_range = [400, 1000]  # time window (ms) used to locate the spectral peak
    sigma = 2  # half-width of the averaged frequency band, in bins
    n_rows = 12  # rows of the band-averaged map
    row_start, row_stop = -550, 550  # y-coordinates assigned to those rows
    ups = 4  # vertical upsampling factor for display
    n_time_up = 280  # horizontal resolution for display
    twindow = [65, 80]  # upsampled columns trimmed from the start / end
    overlay_range = [1, -1]  # upsampled rows covered by the outline
    cmap = 'cubehelix_r'
    vmin, vmax = 0, 0.7  # fixed colour limits of the empirical map
    coloff = 0.5  # relative position of the colour-scale centre

    print(tfr_emp.shape)
    print(tfr_null.shape)

    # --- Locate the spectral peak -------------------------------------------
    t = np.linspace(start=t_start, stop=t_stop, num=tps[fband])
    freqs = np.linspace(lp[fband], hp[fband], num=fps[fband])

    gavg = np.squeeze(np.nanmean(tfr_emp, axis=0))
    gavg[np.isnan(gavg)] = 0
    tt0 = np.searchsorted(t, stats_range[0], side='left')
    ttf = np.searchsorted(t, stats_range[1], side='left')
    pwr = np.mean(gavg[:, tt0:ttf], axis=1)  # mean power per frequency bin

    # For fband == 2 the peak is searched among the first 5 bins only.
    pk = int(np.argmax(pwr[0:5] if fband == 2 else pwr))
    band = _peak_band_slice(pk, sigma, fps[fband])
    band_freqs = freqs[band]
    # Report the first and last bins actually averaged.
    print('peak frequency range : ', band_freqs[0], band_freqs[-1])
    print('power average within peak:', np.mean(pwr[band]))

    # --- Restrict the data to the peak band ---------------------------------
    davg = np.squeeze(np.nanmean(tfr_emp[:, band, :], axis=1))
    davg[np.isnan(davg)] = 0
    null = tfr_null[:, :, band, :]
    print(davg.shape)

    # --- Display grids ------------------------------------------------------
    y = np.linspace(start=row_start, stop=row_stop, num=n_rows).astype(int)
    t_up = np.linspace(start=t_start, stop=t_stop, num=n_time_up)
    y_up = np.linspace(start=row_start, stop=row_stop, num=n_rows * ups).astype(int)
    T_up, Y_up = np.meshgrid(t_up, y_up)
    cols = slice(twindow[0], -twindow[1])
    rows = slice(overlay_range[0], overlay_range[1])

    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(6, 4))

    # Top panel: empirical map on a diverging colour scale.
    norm = colors.TwoSlopeNorm(
        vmin=vmin, vcenter=_color_offset(vmin, vmax, coloff), vmax=vmax
    )
    emp_up = _upsample(t, y, davg, t_up, y_up)
    im_spwr = ax[0].pcolormesh(
        T_up[:, cols], Y_up[:, cols], emp_up[:, cols], cmap=cmap, norm=norm
    )

    # Bottom panel: p-values. The p <= alpha region is outlined on the top panel.
    stats = stats_tfrs_davg(null, davg, correction)
    pvals_up = _upsample(t, y, stats, t_up, y_up)
    significant = pvals_up <= alpha
    im_pvals = ax[1].pcolormesh(T_up[:, cols], Y_up[:, cols], pvals_up[:, cols])
    ax[0].contour(
        T_up[rows, cols],
        Y_up[rows, cols],
        significant[rows, cols],
        origin='upper',
        colors='dodgerblue',
        linestyles='solid',
        linewidths=0.5,
    )

    cbar = plt.colorbar(im_spwr, cax=fig.add_axes([0.95, 0.6, 0.02, 0.15]), extend='both')
    cbar.ax.tick_params(labelsize=10)
    cbar.set_label('Spearman rho', fontsize=10)
    cbar = plt.colorbar(im_pvals, cax=fig.add_axes([0.95, 0.2, 0.02, 0.15]), extend='both')
    cbar.ax.tick_params(labelsize=10)
    cbar.set_label('p-value', fontsize=10)

    ax[1].set_xlabel('Time (ms)', fontsize=12)
    # NOTE(review): the y-axis of both panels spans the rows of the
    # band-averaged map (-550..550), not frequency. These labels look copied
    # from a frequency-resolved plot; confirm the intended label.
    ax[0].set_ylabel('frequency (Hz)', rotation=90, fontsize=10)
    ax[1].set_ylabel('frequency (Hz)', rotation=90, fontsize=10)
    ax[0].title.set_text('RDM reliability obtained using different stimulus choices')

    # str() so that a non-string ``correction`` (e.g. an int) does not raise TypeError.
    txt = 'Cutoff (blue outline) is corrected across ' + str(correction)
    fig.text(0.5, -0.06, txt, ha='center')