# Author: Tyler de Zeeuw
VERSION = "1.0.0"

# Standard-library modules are always available, so they are imported outside
# the guard below. `sys` in particular must exist before the guard: its handler
# calls sys.exit(), which would itself raise NameError if `sys` were imported
# inside the `try` after the import that failed.
import io
import os
import sys
import zipfile

# Import the required third-party packages
try:
    import matplotlib.pyplot as plt
    import numpy as np
    import mne
    import requests
    import mne_nirs
    from mne.annotations import Annotations
    from mne_nirs.visualisation import plot_glm_group_topo, plot_glm_surface_projection
    import matplotlib as mpl
    from nilearn.plotting import plot_design_matrix
    from mne_nirs.channels import get_long_channels, get_short_channels, picks_pair_to_idx
    from mne_nirs.experimental_design import make_first_level_design_matrix
    from mne_nirs.statistics import run_glm, statsmodels_to_results
    from mne.preprocessing.nirs import beer_lambert_law, optical_density
    import statsmodels.formula.api as smf
    import seaborn as sns
    import pandas as pd
    from matplotlib.colors import LinearSegmentedColormap
    import xlrd
    from scipy.stats import ttest_1samp
    from mne_nirs.io.fold import fold_channel_specificity
    import qtpy
    import pyvistaqt

except Exception as e:
    print(f"Whoops! A required software package is missing! Error: {e}")
    print("This can be resolved by running the command 'pip3 install <package>'.")
    sys.exit(1)

# Configuration placeholders. Every value below starts as None and is meant to be
# filled in by set_config(), which validates that all names in REQUIRED_KEYS are
# present and then writes them into this module's globals.

# Input/output locations (BASE_SNIRF_FOLDER + SNIRF_SUBFOLDERS are joined in
# run_groups; OUTPUT_FOLDER_LOCATION is the default directory for saved images)
BASE_SNIRF_FOLDER = None
SNIRF_SUBFOLDERS = None
OPTODE_FILE_PATH = None
OUTPUT_FOLDER_LOCATION = None

# Event names used as the activity and control conditions
TARGET_ACTIVITY = None
TARGET_CONTROL = None

# Feature switches (e.g. DOWNSAMPLE gates resampling in load_snirf,
# SHORT_CHANNEL gates short-channel handling, BAD_DISTANCE_CHANNELS gates
# distance-based rejection in mark_bad_channels)
DOWNSAMPLE = None
OPTODE_FILE = None
SHORT_CHANNEL = None
SNR = None
SCI = None
BAD_DISTANCE_CHANNELS = None
FILTER = None
REJECT_PAIRS = None

# Preprocessing parameters and thresholds
DOWNSAMPLE_FREQUENCY = None
FORCE_DROPPED_CHANNELS = None
FORCE_DROP_ANNOTATIONS = None
SOURCE_DETECTOR_SEPARATOR = None
SHORT_CHANNEL_THRESH = None
EPOCH_REJECT_CRITERIA_THRESH = None
TIME_MIN_THRESH = None
TIME_MAX_THRESH = None
BAD_SNR_THRESH = None
BAD_SCI_THRESH = None
MAX_BAD_CHANNELS = None
LONG_CHANNEL_THRESH = None
PPF = None
FILTER_LOW_PASS = None
FILTER_HIGH_PASS = None
EPOCH_PAIR_TOLERANCE_WINDOW = None

# Design-matrix / GLM settings. STIM_DURATION is paired element-wise with
# SNIRF_SUBFOLDERS in run_groups.
DRIFT_MODEL = None
DURATION_BETWEEN_ACTIVITIES = None
HRF_MODEL = None
SHORT_CHANNEL_REGRESSION = None
STIM_DURATION = None
N_JOBS = None
ROI_GROUP_1 = None
ROI_GROUP_2 = None
ROI_GROUP_1_NAME = None
ROI_GROUP_2_NAME = None

# Plotting / statistics limits (presumably colour-scale bounds and significance
# thresholds for later figures — confirm against their use further down the file)
ABS_T_VALUE = None
ABS_THETA_VALUE = None
ABS_CONTRAST_THETA_VALUE = None
ABS_CONTRAST_T_VALUE = None
ABS_P_T_GRAPH_VALUE = None
P_THRESHOLD = None
BRAIN_DISTANCE = None
BRAIN_MODE = None

# Default for whether the pipeline writes its figures to OUTPUT_FOLDER_LOCATION
SAVE_IMAGES = None

# HTTP basic-auth credentials sent with every update request
# (check_remote_version / download_new_folder).
# NOTE(review): a plaintext password committed in source is readable by anyone
# who has this file; consider loading it from an environment variable or keyring.
USERNAME ="fNIRS_file"
PASSWORD =r"mDl`y]7/oMk|Yp\jy{Q^41#"

# Remote resources
# Plain-text file whose stripped contents are compared against VERSION
REMOTE_VERSION_URL = "https://research.dezeeuw.ca/version.txt"
# Zip archive of a release; {version} is filled via str.format
REMOTE_FOLDER_URL = "https://research.dezeeuw.ca/fNIRS_v{version}.zip"

def check_remote_version(branch, timeout=10):
    '''Method to ask the update server which version is currently published.\n
    Input:\n
    branch (str) - The update branch. On "current" a differing version is still reported, with a note that it will not be installed\n
    timeout (float) - (optional) Seconds to wait for the server to connect/respond before giving up. Default is 10\n
    Output:\n
    remote_version (str or None) - The server's version string when it differs from VERSION; None when it matches or when the check failed'''

    try:
        # Without a timeout an unresponsive server would block here forever,
        # even though the version check is only a best-effort convenience.
        response = requests.get(REMOTE_VERSION_URL, auth=(USERNAME, PASSWORD), timeout=timeout)
        if response.status_code != 200:
            # Report the real status (401, 404, 500, ...) rather than assuming "not found"
            print(f"Error checking version: HTTP {response.status_code} - {response.reason}")
            print("Skipping update.")
            return None
        remote_version = response.text.strip()

    except Exception as e:
        # Any failure (connection, DNS, timeout, ...) simply skips the update
        print(f"Error checking version: {e}")
        print("Skipping update.")
        return None

    if remote_version == VERSION:
        print("You are using the latest version.")
        return None

    print(f"A new version is available: {remote_version}")
    if branch == "current":
        print("Skipping update due to the branch being set to current.")
    return remote_version


def download_new_folder(remote_version, timeout=60):
    '''Method to download a published version and extract it into a "fNIRS_v<version>" folder inside the current working directory.\n
    Input:\n
    remote_version (str) - The version to download; substituted into REMOTE_FOLDER_URL\n
    timeout (float) - (optional) Seconds to wait when connecting and between received data; not a limit on total download time. Default is 60\n
    Output:\n
    status (bool) - True when the archive was downloaded and extracted, False otherwise'''

    print("Downloading the newest version...")

    folder_url = REMOTE_FOLDER_URL.format(version=remote_version)
    current_dir = os.getcwd()
    extract_folder = os.path.join(current_dir, f"fNIRS_v{remote_version}")

    try:
        # A timeout stops an unresponsive server from hanging the program indefinitely
        response = requests.get(folder_url, auth=(USERNAME, PASSWORD), timeout=timeout)
        if response.status_code != 200:
            print(f"Failed to download the update: HTTP {response.status_code} - {response.reason}")
            return False

        with zipfile.ZipFile(io.BytesIO(response.content)) as z:
            z.extractall(extract_folder)

    except Exception as e:
        print(f"Error downloading the newest version: {e}")
        print("Skipping update due to an error.")
        return False

    print(f"Downloaded and extracted new version to: {extract_folder}")

    # The update has already succeeded at this point. IPython is only needed for the
    # rich notebook message, so when it is not installed (e.g. running as a script)
    # fall back to plain text instead of reporting the successful download as a failure.
    try:
        from IPython.display import display, Markdown
    except ImportError:
        print(f"This file has been updated to version {remote_version}. Please close this window and open the new version located at: {extract_folder}.")
    else:
        display(Markdown(f"""# <b>This file has been updated to version <span style="color: green;">{remote_version}</span>. Please <span style="color: red;">close this window</span> and open the new version located at: {extract_folder}."""))

    return True



def check_for_update(branch):
    '''Method to check for a newer version and, on the "main" branch, download it.\n
    Input:\n
    branch (str) - "main" downloads a differing remote version; "current" only reports it; any other value does nothing\n
    Output:\n
    updated (bool) - True only when a new version was downloaded and extracted'''

    if branch == "main":
        remote_version = check_remote_version(branch)
        if not remote_version:
            # Already up to date, or the check failed
            return False
        return download_new_folder(remote_version)

    if branch == "current":
        # Still query the server so the user is told a newer version exists,
        # but never download anything on this branch.
        check_remote_version(branch)

    return False




# Every configuration name set_config() insists on. These mirror the
# None-initialised configuration globals declared near the top of this module;
# when adding a new setting there, add its name here too so a config that omits
# it is rejected up front instead of leaving the global as None.
REQUIRED_KEYS = [
    "BASE_SNIRF_FOLDER", "SNIRF_SUBFOLDERS", "OPTODE_FILE_PATH", "OUTPUT_FOLDER_LOCATION",
    "TARGET_ACTIVITY", "TARGET_CONTROL",
    "DOWNSAMPLE", "OPTODE_FILE", "SHORT_CHANNEL", "SNR", "SCI",
    "BAD_DISTANCE_CHANNELS", "FILTER", "REJECT_PAIRS",
    "DOWNSAMPLE_FREQUENCY", "FORCE_DROPPED_CHANNELS", "FORCE_DROP_ANNOTATIONS",
    "SOURCE_DETECTOR_SEPARATOR", "SHORT_CHANNEL_THRESH", "EPOCH_REJECT_CRITERIA_THRESH",
    "TIME_MIN_THRESH", "TIME_MAX_THRESH", "BAD_SNR_THRESH", "BAD_SCI_THRESH",
    "MAX_BAD_CHANNELS", "LONG_CHANNEL_THRESH", "PPF", "FILTER_LOW_PASS",
    "FILTER_HIGH_PASS", "EPOCH_PAIR_TOLERANCE_WINDOW",
    "DRIFT_MODEL", "DURATION_BETWEEN_ACTIVITIES", "HRF_MODEL", "SHORT_CHANNEL_REGRESSION",
    "STIM_DURATION", "N_JOBS", "ROI_GROUP_1", "ROI_GROUP_2", "ROI_GROUP_1_NAME", "ROI_GROUP_2_NAME",
    "ABS_T_VALUE", "ABS_THETA_VALUE", "ABS_CONTRAST_THETA_VALUE", "ABS_CONTRAST_T_VALUE",
    "ABS_P_T_GRAPH_VALUE", "P_THRESHOLD", "BRAIN_DISTANCE", "BRAIN_MODE",
    "SAVE_IMAGES"
]

def set_config(config):
    '''Method to validate a configuration mapping and publish it as module-level globals.\n
    Input:\n
    config (dict) - Setting name -> value. Every name in REQUIRED_KEYS must be present; extra names are published as well\n
    Raises:\n
    ValueError - If any required name is absent (listed in REQUIRED_KEYS order); nothing is changed in that case'''

    absent = []
    for key in REQUIRED_KEYS:
        if key not in config:
            absent.append(key)

    if absent:
        raise ValueError(f"Missing config keys: {absent}")

    # Overwrite this module's globals so every function here sees the new values
    module_namespace = globals()
    module_namespace.update(config)
    print("[Config] Configuration successfully set.")



def run_groups():
    '''Method to run process_folder on every configured sub-folder of BASE_SNIRF_FOLDER.\n
    Each entry of SNIRF_SUBFOLDERS is paired with the entry at the same position in STIM_DURATION.\n
    Output:\n
    all_results (dict) - Folder name -> (df_roi, df_cha, df_con) as returned by process_folder\n
    raw_haemo - The raw_haemo returned by process_folder for the LAST folder processed, or None when no folders are configured\n
    Raises:\n
    ValueError - If SNIRF_SUBFOLDERS and STIM_DURATION have different lengths'''

    folders = list(SNIRF_SUBFOLDERS)
    durations = list(STIM_DURATION)

    # zip() would silently stop at the shorter list, skipping folders (or
    # durations) without any warning, so mismatched configs are rejected.
    if len(folders) != len(durations):
        raise ValueError(
            f"SNIRF_SUBFOLDERS has {len(folders)} entries but STIM_DURATION has {len(durations)}; "
            "every folder needs exactly one stimulus duration."
        )

    all_results = {}
    # Defined up front so an empty folder list returns None instead of raising UnboundLocalError
    raw_haemo = None

    for folder, stim_duration in zip(folders, durations):
        full_path = os.path.join(BASE_SNIRF_FOLDER, folder)
        raw_haemo, df_roi, df_cha, df_con = process_folder(full_path, stim_duration)
        all_results[folder] = (df_roi, df_cha, df_con)

    return all_results, raw_haemo



def load_snirf(file_path, drop_prefixes):
    '''Method to read a snirf file into memory and apply the first, file-level preparation steps.\n
    Input:\n
    file_path (str) - Path to the snirf file\n
    drop_prefixes (list) - Channel-name prefixes to discard; every channel whose name starts with any of them is dropped\n
    Output:\n
    raw (RawSNIRF) - The loaded data, with forced drops and optional downsampling already applied'''

    try:
        loaded = mne.io.read_raw_snirf(file_path, verbose=True)
        print("Any error warning about 2D positions can be safely ignored, as we add the 3D positions later from the optode file.")

        # Bring the samples into memory so the steps below can modify them
        loaded.load_data()

        # Forced channel removal happens before any processing touches the data
        if drop_prefixes:
            unwanted = []
            for name in loaded.ch_names:
                if any(name.startswith(prefix) for prefix in drop_prefixes):
                    unwanted.append(name)
            loaded.drop_channels(unwanted)
            print("Dropped channels:", unwanted)

        # Optional early downsampling, controlled by DOWNSAMPLE / DOWNSAMPLE_FREQUENCY
        if DOWNSAMPLE:
            loaded.resample(DOWNSAMPLE_FREQUENCY)

    except Exception as e:
        print(f"Whoops! Can't read the snirf file! Error: {e}")
        sys.exit(1)

    return loaded



def get_optode_coordinates(file_path=None):
    '''Method to load 3D optode positions from file.\n
    Each useful line looks like "S1: x y z" or "D1: x y z" (the colon is optional and may be attached to the
    numbers). Coordinates are divided by 1000 (presumably millimetres -> metres). Blank lines and lines whose
    label does not start with "s" or "d" (e.g. headers or comments) are skipped.\n
    Input:\n
    file_path (str) - (optional) Location of the optode file. Default is OPTODE_FILE_PATH\n
    Output:\n
    sources (dict) - Lower-case source label (e.g. "s1") -> 3-element position array\n
    detectors (dict) - Lower-case detector label (e.g. "d1") -> 3-element position array'''

    if file_path is None:
        file_path = OPTODE_FILE_PATH

    try:
        # Dictionaries to store our sources and detectors
        sources = {}
        detectors = {}

        with open(file_path, 'r') as f:
            for line_number, line in enumerate(f, start=1):
                # Treat ':' as whitespace so "S1: 1 2 3", "S1 : 1 2 3" and "S1:1 2 3" all
                # split into the label followed by exactly the coordinates.
                parts = line.replace(':', ' ').split()
                if not parts:
                    continue

                key = parts[0].lower()
                if key.startswith('s'):
                    target = sources
                elif key.startswith('d'):
                    target = detectors
                else:
                    # Not a source/detector line; don't try to parse it as numbers
                    continue

                coords = np.array([float(x) / 1000 for x in parts[1:]])
                # Each position is later written into a 3-element slice of a channel's
                # location, so reject malformed lines here with a clear message.
                if coords.shape != (3,):
                    raise ValueError(
                        f"line {line_number}: expected 3 coordinates for '{parts[0]}', got {coords.size}"
                    )
                target[key] = coords

        # Display sources and detectors
        print("Parsed sources:  ", sources)
        print("Parsed detectors:", detectors)
        return sources, detectors

    except Exception as e:
        print(f"Whoops! Can't read the optode file! Error: {e}")
        sys.exit(1)



def update_channel_location(raw, detectors, sources):
    '''Method to replace each channel's source, detector and midpoint positions with the optode-file coordinates. Only the in-memory data changes; the file on disk is untouched.\n
    Input:\n
    raw (RawSNIRF) - The loaded snirf data; the 'loc' array and 'coord_frame' of every entry in raw.info['chs'] are edited in place\n
    detectors (dict) - Lower-case detector label -> position, as produced by get_optode_coordinates\n
    sources (dict) - Lower-case source label -> position, as produced by get_optode_coordinates'''

    try:
        print("Updating channels with the text file positions...")

        for channel in raw.info['chs']:
            # Only the first whitespace-separated token of the lower-cased name is used;
            # it holds "<source><SOURCE_DETECTOR_SEPARATOR><detector>".
            pair_token = channel['ch_name'].lower().split()[0]
            src_label, det_label = pair_token.split(SOURCE_DETECTOR_SEPARATOR)
            loc = channel['loc']

            print(f"Updating {channel['ch_name']}:")
            print(f"Looking for detector: {det_label}, source: {src_label}")

            # loc[3:6] receives the source position
            if src_label not in sources:
                print(f"Source {src_label} not found in text file!")
            else:
                new_position = sources[src_label]
                print(f"Found source {src_label}: {new_position}")
                print(f"Old source pos:   {loc[3:6]} > New: {new_position}")
                loc[3:6] = new_position

            # loc[6:9] receives the detector position
            if det_label not in detectors:
                print(f"Detector {det_label} not found in text file!")
            else:
                new_position = detectors[det_label]
                print(f"Found detector {det_label}: {new_position}")
                print(f"Old detector pos: {loc[6:9]} > New: {new_position}")
                loc[6:9] = new_position

            # loc[0:3] receives the point halfway between source and detector
            midpoint = (loc[3:6] + loc[6:9]) / 2
            loc[0:3] = midpoint
            print(f"Updated channel midpoint: {midpoint}")

            # NOTE(review): 1 appears to be MNE's FIFFV_COORD_DEVICE frame id — confirm this is the intended frame
            channel['coord_frame'] = 1

        # Echo the final positions so the update can be checked by eye
        print("Updating completed.")
        print("Verifying updates...")
        for number, channel in enumerate(raw.info['chs'], start=1):
            loc = channel['loc']
            print(f"\nChannel {number}: {channel['ch_name']}")
            print(f"Updated source pos:   {loc[3:6]}")
            print(f"Updated detector pos: {loc[6:9]}")
            print(f"Updated midpoint pos: {loc[:3]}")

    except Exception as e:
        print(f"Whoops! Something went wrong updating the channel positions! Error: {e}")
        sys.exit(1)



def get_duration_of_snirf(raw):
    '''Method to compute how long the recording lasts, in seconds, from its sample count and sampling rate (so it works at any sampling frequency).\n
    Input:\n
    raw (RawSNIRF) - The loaded snirf data\n
    Output:\n
    total_duration (float) - Number of samples divided by the sampling frequency'''

    try:
        sample_count = raw.n_times
        sample_rate = raw.info['sfreq']
        seconds = sample_count / sample_rate

        print("Freguency of the file:", sample_rate)
        print("Total duration of the file in seconds:", seconds)

    except Exception as e:
        print(f"Whoops! Something went wrong calculating the length of the data in the file! Error: {e}")
        sys.exit(1)

    return seconds



def calculate_short_and_long_channels(raw, short_channel_thresh=None, long_channel_thresh=None):
    '''Method to calculate the short and long channels in the snirf file.\n
    Input:\n
    raw (RawSNIRF) - The loaded snirf data\n
    short_channel_thresh (float) - (optional) Source-detector distance below which a channel counts as short. Default is SHORT_CHANNEL_THRESH\n
    long_channel_thresh (float) - (optional) Source-detector distance above which a channel is flagged as bad. Default is LONG_CHANNEL_THRESH\n
    Output:\n
    raw_minus_short (RawSNIRF) - Copy of raw keeping only the fNIRS channels at or beyond the short threshold (a plain copy of raw when SHORT_CHANNEL is off)\n
    raw_only_short (RawSNIRF or None) - Copy of raw keeping only the short channels (None when SHORT_CHANNEL is off)\n
    bad_channels_based_on_distance (list) - Names of the channels longer than long_channel_thresh\n
    short_channel_idx (str or int) - First whitespace-separated token (the source-detector pair) of the first short channel's name, "" if none was found, or 0 when SHORT_CHANNEL is off'''

    if short_channel_thresh is None:
        short_channel_thresh = SHORT_CHANNEL_THRESH
    if long_channel_thresh is None:
        long_channel_thresh = LONG_CHANNEL_THRESH

    # Gets the channels and distances
    picks = mne.pick_types(raw.info, fnirs=True)
    dists = mne.preprocessing.nirs.source_detector_distances(raw.info, picks=picks)
    print("Channels:", picks)
    print("Distances:", dists)

    # Channels that are too long are reported as bad
    bad_channels_based_on_distance = [raw.info['ch_names'][pick] for pick, dist in zip(picks, dists) if dist > long_channel_thresh]
    print("Bad distance channels:", bad_channels_based_on_distance)

    # User stated that a short channel does not exist in the data
    if not SHORT_CHANNEL:
        return raw.copy(), None, bad_channels_based_on_distance, 0

    # Names of the channels MNE classifies as short
    is_short = mne.preprocessing.nirs.short_channels(raw.info, short_channel_thresh)
    short_names = [name for name, flag in zip(raw.ch_names, is_short) if flag]
    for name in short_names:
        print("Short channel identified:", name)

    # short_channels() flags dist < thresh, so its complement is dist >= thresh.
    # Using '>' here dropped a channel sitting exactly on the threshold from BOTH outputs.
    raw_minus_short = raw.copy().pick(picks[dists >= short_channel_thresh])
    raw_only_short = raw.copy().pick(picks[dists < short_channel_thresh])

    # Keep only the source-detector pair of the first short channel (drop e.g. the wavelength suffix)
    short_channel_idx = short_names[0].split(' ')[0] if short_names else ""

    return raw_minus_short, raw_only_short, bad_channels_based_on_distance, short_channel_idx



def create_raw_graph(raw, total_duration, file_name, output_folder=None, save_images=None):
    '''Method to save a plot of every raw channel across the whole recording.\n
    Input:\n
    raw (RawSNIRF) - The loaded snirf data\n
    total_duration (float) - The total duration of the snirf data in seconds; the whole span is shown at once\n
    file_name (string) - The file name of the current file, used in the image name\n
    output_folder (string) - (optional) Where to save the images. Default is OUTPUT_FOLDER_LOCATION\n
    save_images (bool) - (optional) Whether to save the image at all. Default is SAVE_IMAGES
    '''

    folder = OUTPUT_FOLDER_LOCATION if output_folder is None else output_folder
    should_save = SAVE_IMAGES if save_images is None else save_images

    # Nothing to do unless images were requested
    if not should_save:
        return

    figure = raw.plot(n_channels=len(raw.ch_names), duration=total_duration, show=False)
    figure.savefig(folder + "/1. Raw Data for " + file_name + ".png")



def calculate_optical_density(raw, total_duration, file_name, output_folder=None, save_images=None):
    '''Method to convert the raw data to optical density, optionally saving a plot of the result.\n
    Input:\n
    raw (RawSNIRF) - The loaded snirf data\n
    total_duration (float) - The total duration of the snirf data in seconds; the whole span is plotted at once\n
    file_name (string) - The file name of the current file, used in the image name\n
    output_folder (string) - (optional) Where to save the images. Default is OUTPUT_FOLDER_LOCATION\n
    save_images (bool) - (optional) Whether to save the image. Default is SAVE_IMAGES
    Output:\n
    raw_od (RawSNIRF) - The optical density data\n'''

    folder = OUTPUT_FOLDER_LOCATION if output_folder is None else output_folder
    should_save = SAVE_IMAGES if save_images is None else save_images

    # MNE performs the optical-density conversion
    raw_od = mne.preprocessing.nirs.optical_density(raw)

    if should_save:
        od_figure = raw_od.plot(n_channels=len(raw_od.ch_names), duration=total_duration, show=False)
        od_figure.savefig(folder + "/2. Optical Density for " + file_name + ".png")

    return raw_od



def calculate_snr(raw_od, file_name, snr_thresh=None, output_folder=None, save_images=None):
    '''Method to calculate each channel's signal-to-noise ratio (dB) and flag noisy source-detector pairs.\n
    Signal power is taken from a 0.01-0.5 Hz band and noise power from a 1-10 Hz band of the optical density data.
    If either channel of a pair (channels sharing the name text before the last space) is below the threshold,
    both channels of that pair are flagged.\n
    Input:\n
    raw_od (RawSNIRF) - The optical density data\n
    file_name (string) - The file name of the current file\n
    snr_thresh (float) - (optional) dB of signal with anything lower considered noisy. Default is BAD_SNR_THRESH\n
    output_folder (string) - (optional) Where to save the images. Default is OUTPUT_FOLDER_LOCATION\n
    save_images (bool) - (optional) Whether to save the image. Default is SAVE_IMAGES
    Output:\n
    bad_channels (set) - Names of the identified bad channels\n
    '''

    # NOTE:
    # The threshold value is important for identifying bad channels.
    # I attempted to find what is considered 'good' but got conflicting answers.

    # This paper says anything higher than 30??
    # https://www.researchgate.net/figure/Flowchart-of-fNIRS-signals-processing-SNR-signal-noise-ratio-OD-optical-density_fig2_346800142

    # This one got values between 8 and 15?
    # https://www.mdpi.com/2076-3417/12/1/316

    # ChatGPT says higher than 20 good
    # Deepseek agrees with ChatGPT

    # This says -20 to 20?
    # https://www.u-picardie.fr/ressources-lnfp/confouning-factors-in-event-related-nirs-analysis/

    if snr_thresh is None:
        snr_thresh = BAD_SNR_THRESH
    if output_folder is None:
        output_folder = OUTPUT_FOLDER_LOCATION
    if save_images is None:
        save_images = SAVE_IMAGES

    # Compute the signal-to-noise ratio values
    # NOTE(review): the 10 Hz noise-band edge needs a sampling rate above 20 Hz; after
    # heavy downsampling MNE may refuse this filter — confirm against DOWNSAMPLE_FREQUENCY.
    signal_band = (0.01, 0.5)
    noise_band = (1.0, 10.0)
    raw_od_signal = raw_od.copy().filter(*signal_band, verbose=False)
    raw_od_noise = raw_od.copy().filter(*noise_band, verbose=False)
    signal_power = np.mean(raw_od_signal.get_data()**2, axis=1)
    noise_power = np.mean(raw_od_noise.get_data()**2, axis=1)

    # Calculate the snr using the standard formula for dB (eps avoids division by zero)
    snr = 10 * np.log10(signal_power / (noise_power + np.finfo(float).eps))

    # Look each channel's SNR up by name once, instead of rescanning every channel for every pair
    snr_by_channel = dict(zip(raw_od.ch_names, snr))

    # Group channels by the name text before the last space, so both wavelengths of a
    # source-detector pair are kept or rejected together
    groups = {}
    for ch in raw_od.ch_names:
        base = ch.rsplit(' ', 1)[0]
        groups.setdefault(base, []).append(ch)

    # If any channel of a group misses the threshold, the whole group is bad
    bad_channels = set()
    for ch_list in groups.values():
        if any(snr_by_channel[ch] < snr_thresh for ch in ch_list):
            bad_channels.update(ch_list)

    # Display the channels that failed this check
    print("Channels that failed on SNR!!!:", list(bad_channels))

    # Do we save the image?
    if save_images:
        # The colour scale normally spans 0-25 dB. LinearSegmentedColormap.from_list needs
        # stops rising from 0 to 1, so grow the scale when the threshold is near/above the
        # top (e.g. 30 dB used to crash here) and clamp negative thresholds to 0.
        scale_max = max(25, snr_thresh + 1)
        stop = max(snr_thresh, 0)
        colors = [(0.0, 'red'), (stop/scale_max, 'red'), ((stop+.5)/scale_max, 'yellow'), ((stop+1)/scale_max, 'green'), (1.0, 'green')]
        cmap = LinearSegmentedColormap.from_list('custom_snr_cmap', colors)

        snr_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained")
        norm = plt.Normalize(vmin=0, vmax=scale_max)
        scatter = ax.scatter(range(len(snr)), snr, c=snr, cmap=cmap, alpha=0.8, s=100, norm=norm)
        ax.set(xlabel="Channel Number",  ylabel="Signal-to-Noise Ratio (dB)", xlim=[0, len(snr)], ylim=[0, scale_max])
        ax.axhline(snr_thresh, color='black', linestyle='--', alpha=0.3, linewidth=1)
        cbar = snr_fig.colorbar(scatter, ax=ax, label="SNR Thresholds (dB)")
        cbar.set_ticks([0, snr_thresh, snr_thresh+1, scale_max])
        cbar.set_ticklabels(['0', str(snr_thresh), str(snr_thresh+1), f"{scale_max:g}"])
        save_path = output_folder + "/3. Signal-to-Noise for " + file_name + ".png"
        snr_fig.savefig(save_path)

    return bad_channels



def calculate_sci(raw_od, file_name, sci_thresh=None, output_folder=None, save_images=None):
    '''Method to calculate each channel's scalp coupling index (SCI) and flag poorly coupled channels.\n
    In the saved plot, values from 0.97 (or from the threshold, if higher) upward are coloured green.\n
    Input:\n
    raw_od (RawSNIRF) - The optical density data\n
    file_name (string) - The file name of the current file\n
    sci_thresh (float) - (optional) Threshold value with anything lower considered bad connectivity. Default is BAD_SCI_THRESH\n
    output_folder (string) - (optional) Where to save the images. Default is OUTPUT_FOLDER_LOCATION\n
    save_images (bool) - (optional) Whether to save the image. Default is SAVE_IMAGES
    Output:\n
    bad_channels (set) - Names of the channels whose SCI is below sci_thresh'''

    # NOTE:
    # SCI quantifies how well an fNIRS optode is coupled to the scalp by analyzing the presence of cardiac pulsations in the raw signal.

    # From https://opg.optica.org/boe/fulltext.cfm?uri=boe-7-12-5104&id=354588
    # "Previously, we used an SCI threshold of 0.8 to identify channels with acceptable scalp coupling.
    # However, when the scalp coupling is ideal, fNIRS instruments often exhibit an SCI that approaches the maximum value of 1."

    # From https://www.sciencedirect.com/science/article/abs/pii/S0378595513002803?via%3Dihub
    # "Initially, we selected the cleanest channels based on the value of the scalp contact index (SCI) in each channel.
    # With a SCI threshold of 0.75, 34 and 38 channels survived the selection in the left and right hemisphere, respectively."

    # From https://www.cortivision.com/hair-management-tips-for-fnirs/
    # The threshold for an acceptable scalp coupling value is 0.8:
    # Channels with values above 0.9 respectively are marked in green and are signed as good.
    # Channels between values of 0.8 and 0.9 are marked in orange and signed as medium
    # Channels below 0.8 are marked red and signed as bad. This means that they are rejected in subsequent analysis steps.

    # From my testing, anything lower than .94 can go.

    if sci_thresh is None:
        sci_thresh = BAD_SCI_THRESH
    if output_folder is None:
        output_folder = OUTPUT_FOLDER_LOCATION
    if save_images is None:
        save_images = SAVE_IMAGES

    # Quantify the quality of the coupling between the scalp and the optodes using the scalp coupling index
    sci = mne.preprocessing.nirs.scalp_coupling_index(raw_od)

    # If a channel was less than the threshold, it will be marked as bad
    sci_bad = [ch for ch, s in zip(raw_od.ch_names, sci) if s < sci_thresh]
    if sci_bad:
        print("Channels that failed on SCI!!!:", sci_bad)

    bad_channels = set(sci_bad)

    # Do we save the image?
    if save_images:
        # LinearSegmentedColormap.from_list needs stops rising from 0 to 1. Clamp the threshold
        # into [0, 1] and keep the "good" stop at or above it; a threshold above 0.97 used to
        # produce out-of-order stops and crash here.
        red_stop = min(max(sci_thresh, 0.0), 1.0)
        green_stop = max(0.97, red_stop)
        colors = [(0.0, 'red'), (red_stop, 'red'), ((red_stop+green_stop)/2, 'yellow'), (green_stop, 'green'), (1.0, 'green')]
        cmap = LinearSegmentedColormap.from_list('custom_red_yellow_green', colors)

        sci_fig, ax = plt.subplots(figsize=(12, 4), layout="constrained")
        norm = plt.Normalize(vmin=0, vmax=1)
        scatter = ax.scatter(range(len(sci)), sci, c=sci, cmap=cmap, alpha=0.8, s=100, norm=norm)
        ax.set(xlabel="Channel Number", ylabel="Scalp Coupling Index", xlim=[0, len(sci)], ylim=[0, 1.02])
        ax.axhline(sci_thresh, color='black', linestyle='--', alpha=0.3, linewidth=1)
        cbar = sci_fig.colorbar(scatter, ax=ax, label="Scalp Coupling Index")
        cbar.set_ticks([0, sci_thresh, green_stop, 1])
        save_path = output_folder + "/4. Scalp-Coupling-Index for " + file_name + ".png"
        sci_fig.savefig(save_path)

    return bad_channels



def mark_bad_channels(raw_od, bad_channels_snr, bad_channels_sci, bad_channels_distance):
    '''Method to combine all bad channels and store them on the optical density data.\n
    Any bads already present in raw_od.info["bads"] are replaced, not merged.\n
    Input:\n
    raw_od (RawSNIRF) - The optical density data; its info["bads"] is overwritten\n
    bad_channels_snr (set) - Channel names flagged by SNR\n
    bad_channels_sci (set) - Channel names flagged by SCI\n
    bad_channels_distance (list) - Channel names flagged by distance; only used when BAD_DISTANCE_CHANNELS is enabled'''

    bad_channels = set(bad_channels_snr) | set(bad_channels_sci)

    if BAD_DISTANCE_CHANNELS:
        bad_channels |= set(bad_channels_distance)

    # Sort so the stored list (and the printed log) is identical between runs;
    # the iteration order of a set of strings can change from one run to the next.
    raw_od.info["bads"] = sorted(bad_channels)
    print("Channels that are marked bad!!:", raw_od.info["bads"])



def calculate_haemoglobin_concentration(raw_od, total_duration, file_name, ppf=None, output_folder=None, save_images=None):
    '''Method to convert optical density into haemoglobin concentration with the modified Beer-Lambert law.\n
    Input:\n
    raw_od (RawSNIRF) - The optical density data\n
    total_duration (float) - The total duration of the snirf data in seconds; the whole span is plotted at once\n
    file_name (string) - The file name of the current file\n
    ppf (float) - (optional) The partial pathlength factor passed to beer_lambert_law. Default is PPF\n
    output_folder (string) - (optional) Where to save the images. Default is OUTPUT_FOLDER_LOCATION\n
    save_images (bool) - (optional) Whether to save the image. Default is SAVE_IMAGES
    Output:\n
    raw_haemo (RawSNIRF) - The haemoglobin concentration data'''

    if ppf is None:
        ppf = PPF
    if output_folder is None:
        output_folder = OUTPUT_FOLDER_LOCATION
    if save_images is None:
        save_images = SAVE_IMAGES

    # Get the haemoglobin concentration using the Beer-Lambert law
    raw_haemo = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=ppf)

    # Do we save the image?
    if save_images:
        # show=False like every other plotting step: the figure is saved, not popped up
        haemo_graph = raw_haemo.plot(n_channels=len(raw_haemo.ch_names), duration=total_duration, show=False)
        save_path = output_folder + "/5. Unfiltered Haemoglobin Concentration for " + file_name + ".png"
        haemo_graph.savefig(save_path)

    return raw_haemo



def calculate_filtered_haemoglobin(raw_haemo, file_name, low_freq=None, high_freq=None, output_folder=None, save_images=None):
    '''Method that band-pass filters the haemoglobin concentration data.\n
    MNE's Raw.filter works in place, so raw_haemo itself is filtered and the returned object is that same instance.\n
    Input:\n
    raw_haemo (RawSNIRF) - The haemoglobin concentration data (modified in place)\n
    file_name (string) - The file name of the current file\n
    low_freq (float) - (optional) Frequencies below this are attenuated (passed as l_freq). Default is FILTER_LOW_PASS\n
    high_freq (float) - (optional) Frequencies above this are attenuated (passed as h_freq). Default is FILTER_HIGH_PASS\n
    output_folder (string) - (optional) Where to save the images. Default is OUTPUT_FOLDER_LOCATION\n
    save_images (bool) - (optional) Whether to save before/after power-spectrum images. Default is SAVE_IMAGES
    Output:\n
    raw_haemo_filtered (RawSNIRF) - The filtered haemoglobin concentration data (same object as raw_haemo)'''

    if low_freq is None:
        low_freq = FILTER_LOW_PASS
    if high_freq is None:
        high_freq = FILTER_HIGH_PASS
    if output_folder is None:
        output_folder = OUTPUT_FOLDER_LOCATION
    if save_images is None:
        save_images = SAVE_IMAGES

    # The unfiltered copy is only needed for the "Before" plot; it must be taken
    # before filtering because the filter below modifies raw_haemo in place.
    raw_haemo_unfiltered = raw_haemo.copy() if save_images else None
    raw_haemo_filtered = raw_haemo.filter(low_freq, high_freq, h_trans_bandwidth=0.2, l_trans_bandwidth=0.02)

    # Do we save the image?
    if save_images:
        for i, (when, _raw) in enumerate(dict(Before=raw_haemo_unfiltered, After=raw_haemo_filtered).items()):
            # show=False: the title is added below, so displaying now would show the figure without it
            fig = _raw.compute_psd().plot(average=True, amplitude=False, picks="data", exclude="bads", show=False)
            fig.suptitle(f"{when} filtering", weight="bold", size="x-large")
            save_path = output_folder + "/" + str(i+6) + ". " + when + " filtering data for " + file_name + ".png"
            fig.savefig(save_path)

    return raw_haemo_filtered



def calculate_annotations(raw_haemo_filtered, file_name, output_folder=None, save_images=None):
    '''Method to turn the recording's annotations into an MNE events array, optionally saving an events plot.\n
    Input:\n
    raw_haemo_filtered (RawSNIRF) - The filtered haemoglobin concentration data\n
    file_name (string) - The file name of the current file, used in the image name\n
    output_folder (string) - (optional) Where to save the images. Default is OUTPUT_FOLDER_LOCATION\n
    save_images (bool) - (optional) Whether to save the image. Default is SAVE_IMAGES
    Output:\n
    events (ndarray) - The events array returned by mne.events_from_annotations\n
    event_dict (dict) - Annotation description -> event id'''

    folder = OUTPUT_FOLDER_LOCATION if output_folder is None else output_folder
    should_save = SAVE_IMAGES if save_images is None else save_images

    events, event_dict = mne.events_from_annotations(raw_haemo_filtered)

    if should_save:
        sample_rate = raw_haemo_filtered.info["sfreq"]
        events_figure = mne.viz.plot_events(events, event_id=event_dict, sfreq=sample_rate, show=False)
        events_figure.savefig(folder + "/8. Annotations for " + file_name + ".png")

    return events, event_dict



def calculate_good_epochs(raw_haemo_filtered, events, event_dict, file_name, tmin=None, tmax=None, reject_thresh=None, target_activity=None, target_control=None, output_folder=None, save_images=None, reject_pairs=None):
    '''Calculates what epochs are good and creates a graph showing if any are dropped.\n
    Input:\n
    raw_haemo_filtered (RawSNIRF) - The filtered haemoglobin concentration data\n
    events (ndarray) - Array containing row number and what index the event is\n
    event_dict (dict) - Contains the names of the events\n
    file_name (string) - The file name of the current file\n
    tmin (float) - (optional) Time in seconds to display before the event. Default is TIME_MIN_THRESH\n
    tmax (float) - (optional) Time in seconds to display after the event. Default is TIME_MAX_THRESH\n
    reject_thresh (float) - (optional) Peak-to-peak HbO threshold above which an epoch is rejected. Default is EPOCH_REJECT_CRITERIA_THRESH\n
    target_activity (string) - (optional) The target activity. Default is TARGET_ACTIVITY\n
    target_control (string) - (optional) The target control. Default is TARGET_CONTROL\n
    output_folder (string) - (optional) Where to save the images. Default is OUTPUT_FOLDER_LOCATION\n
    save_images (bool) - (optional) Whether to save the drop-log image. Default is SAVE_IMAGES\n
    reject_pairs (bool) - (optional) If truthy, a rejected control also drops the activity right after it, and a rejected activity also drops the control right before it. Default is REJECT_PAIRS\n
    Output:\n
    good_epochs (Epochs) - The remaining good epochs\n
    all_epochs (Epochs) - All of the epochs (no amplitude rejection applied)\n
    Raises:\n
    KeyError - If reject_pairs is enabled and target_activity or target_control is not in event_dict'''

    if tmin is None:
        tmin = TIME_MIN_THRESH
    if tmax is None:
        tmax = TIME_MAX_THRESH
    if reject_thresh is None:
        reject_thresh = EPOCH_REJECT_CRITERIA_THRESH
    if target_activity is None:
        target_activity = TARGET_ACTIVITY
    if target_control is None:
        target_control = TARGET_CONTROL
    if output_folder is None:
        output_folder = OUTPUT_FOLDER_LOCATION
    if save_images is None:
        save_images = SAVE_IMAGES
    if reject_pairs is None:
        reject_pairs = REJECT_PAIRS

    # Get all the good epochs (HbO amplitude rejection applied)
    good_epochs = mne.Epochs(
        raw_haemo_filtered,
        events,
        event_id=event_dict,
        tmin=tmin,
        tmax=tmax,
        reject=dict(hbo=reject_thresh),
        reject_by_annotation=True,
        proj=True,
        baseline=(None, 0),
        preload=True,
        detrend=None,
        verbose=True,
    )

    # Get all the epochs (same settings, but without the amplitude rejection)
    all_epochs = mne.Epochs(
        raw_haemo_filtered,
        events,
        event_id=event_dict,
        tmin=tmin,
        tmax=tmax,
        proj=True,
        baseline=(None, 0),
        preload=True,
        detrend=None,
        verbose=True,
    )

    if reject_pairs:
        # `selection` holds each kept epoch's index into the original events array,
        # so the two selections can be compared directly
        all_idx = all_epochs.selection
        good_idx = good_epochs.selection

        # Sets/dicts give O(1) lookups inside the loop below
        bad_set = set(np.setdiff1d(all_idx, good_idx).tolist())
        position_in_good = {int(idx): pos for pos, idx in enumerate(good_idx)}

        # Split the controls and the activities
        event_ids = all_epochs.events[:, 2]
        control_id = event_dict[target_control]
        activity_id = event_dict[target_activity]

        to_reject_extra = set()
        n_all = len(all_idx)

        for i, idx in enumerate(all_idx):
            if int(idx) not in bad_set:
                continue
            ev = event_ids[i]
            # If the control was bad, drop the following activity
            if ev == control_id and i + 1 < n_all and event_ids[i + 1] == activity_id:
                to_reject_extra.add(int(all_idx[i + 1]))
            # If the activity was bad, drop the preceding control
            if ev == activity_id and i - 1 >= 0 and event_ids[i - 1] == control_id:
                to_reject_extra.add(int(all_idx[i - 1]))

        # Epochs.drop() expects positions within good_epochs, so translate the event indices,
        # keeping only those that are currently classified as good
        drop_idx_in_good = [position_in_good[idx] for idx in to_reject_extra if idx in position_in_good]

        # Drop the pairings of the bad epochs
        good_epochs.drop(drop_idx_in_good)

    # Do we save the image?
    if save_images:
        drop_log_fig = good_epochs.plot_drop_log(show=False)
        save_path = output_folder + "/8. Epoch drops for " + file_name + ".png"
        drop_log_fig.savefig(save_path)

    return good_epochs, all_epochs



def bad_check(raw_od, max_bad_channels=None):
    '''Method to see if we have more bad channels than our allowed threshold.\n
    Inputs:\n
    raw_od (RawSNIRF) - The optical density data\n
    max_bad_channels (int) - (optional) The max amount of bad channels we want to tolerate. Default is MAX_BAD_CHANNELS\n
    Output\n
    (bool) - True if there are more bad channels than the threshold, False otherwise'''

    if max_bad_channels is None:
        max_bad_channels = MAX_BAD_CHANNELS

    # Exactly max_bad_channels bad channels is still tolerated, so only strictly more counts as bad.
    # `or []` also covers an explicit None stored under 'bads'.
    bad_channels = raw_od.info.get('bads') or []
    return len(bad_channels) > max_bad_channels
    
    

def remove_bad_epoch_pairings(raw_haemo_filtered_minus_short, good_epochs, epoch_pair_tolerance_window=None, force_drop_annotations=None):
    '''Method to apply our new epochs to the loaded data in working memory. This is to ensure that the GLM does not see these epochs.\n
    Only annotations whose onset lies within the tolerance window of a good epoch's event sample are kept.
    Annotations listed in force_drop_annotations are always dropped.\n
    Inputs:\n
    raw_haemo_filtered_minus_short (RawSNIRF) - The filtered haemoglobin concentration data (not modified)\n
    good_epochs (Epochs) - The epochs we want the loaded data to take on\n
    epoch_pair_tolerance_window (int) - (optional) The amount of data points the paired epoch can deviate from the expected amount. Default is EPOCH_PAIR_TOLERANCE_WINDOW\n
    force_drop_annotations (iterable of str) - (optional) Annotation descriptions that are always dropped. Default is FORCE_DROP_ANNOTATIONS\n
    Output:\n
    raw_haemo_filtered_good_epochs (RawSNIRF) - A copy of the data with the bad channels removed and only the kept annotations'''

    if epoch_pair_tolerance_window is None:
        epoch_pair_tolerance_window = EPOCH_PAIR_TOLERANCE_WINDOW
    if force_drop_annotations is None:
        force_drop_annotations = FORCE_DROP_ANNOTATIONS

    # A falsy value (None / empty) disables forced drops
    force_drop = set(force_drop_annotations) if force_drop_annotations else set()

    # Copy the input haemoglobin concentration data and drop the bad channels
    raw_haemo_filtered_good_epochs = raw_haemo_filtered_minus_short.copy()
    raw_haemo_filtered_good_epochs = raw_haemo_filtered_good_epochs.drop_channels(raw_haemo_filtered_good_epochs.info['bads'])

    # Unique sample indices of the good events (first column of the MNE events array)
    good_event_samples = np.unique(good_epochs.events[:, 0])
    print(f"Total good events (epochs): {len(good_event_samples)}")

    # Get the current annotations and the frequency of the file
    raw_annots = raw_haemo_filtered_good_epochs.annotations
    sfreq = raw_haemo_filtered_good_epochs.info['sfreq']

    clean_descriptions = []
    clean_onsets = []
    clean_durations = []
    dropped = []

    for desc, onset, dur in zip(raw_annots.description, raw_annots.onset, raw_annots.duration):
        # Forced drops must skip the matching below entirely, so they can never be kept
        if desc in force_drop:
            dropped.append((desc, onset))
            continue

        # Convert annotation onset time to sample index
        # NOTE(review): this assumes onsets and event samples share the same origin (first_samp == 0) - confirm
        sample = int(onset * sfreq)

        # Check if the annotation is within the tolerance of any good event
        matched = bool(np.any(np.abs(good_event_samples - sample) <= epoch_pair_tolerance_window))

        if matched:
            clean_descriptions.append(desc)
            clean_onsets.append(onset)
            clean_durations.append(dur)
        else:
            dropped.append((desc, onset))

    # Create the new filtered annotations. Keep the original time reference so the onsets are
    # interpreted relative to the same origin as the annotations they were copied from.
    new_annots = Annotations(
        onset=clean_onsets,
        duration=clean_durations,
        description=clean_descriptions,
        orig_time=raw_annots.orig_time,
    )

    # Assign the new annotations
    raw_haemo_filtered_good_epochs.set_annotations(new_annots)

    # Print out the results
    print(f"Original annotations: {len(raw_annots)}")
    print(f"Kept annotations: {len(clean_descriptions)}")
    print("Kept annotation types:", set(clean_descriptions))
    if dropped:
        print(f"Dropped annotations: {len(dropped)}")
        print("Dropped annotations:")
        for desc, onset in dropped:
            print(f"  - {desc} at {onset:.2f}s")
    else:
        print("No annotations were dropped!")

    return raw_haemo_filtered_good_epochs



def _short_channel_regressors(raw_only_short, short_channel):
    '''Computes the mean HbO and mean HbR time courses of this participant's short channel(s), for use as GLM nuisance regressors.\n
    Inputs:\n
    raw_only_short (RawSNIRF) - Data containing only the short channels (not modified)\n
    short_channel (string) - Channel-name prefix identifying this participant's short channel\n
    Outputs:\n
    short_hbo (ndarray) - Mean HbO time course over the matched short channels\n
    short_hbr (ndarray) - Mean HbR time course over the matched short channels\n
    Raises:\n
    ValueError - If no channel matches short_channel, or HbO/HbR cannot be obtained from the matched channels'''

    # Find the correct short channel for this participant
    target_channels = [ch for ch in raw_only_short.ch_names if ch.startswith(short_channel)]

    # Could not find the short channel!
    if not target_channels:
        raise ValueError("No matching short channel found!")

    short_data = raw_only_short.copy().pick(picks=target_channels)

    # raw_only_short is split off the loaded recording before any optical density / haemoglobin conversion,
    # so it is usually still raw intensity. Convert it so HbO and HbR can be separated.
    # Averaging both wavelengths into one regressor would not separate them.
    channel_types = set(short_data.get_channel_types())
    if "fnirs_cw_amplitude" in channel_types:
        short_data = optical_density(short_data)
        channel_types = set(short_data.get_channel_types())
    if "fnirs_od" in channel_types:
        # The PPF only rescales these regressors, which changes their own coefficients but no others
        bll_kwargs = {} if PPF is None else {"ppf": PPF}
        short_data = beer_lambert_law(short_data, **bll_kwargs)
        channel_types = set(short_data.get_channel_types())

    if not {"hbo", "hbr"} <= channel_types:
        raise ValueError(f"Could not obtain HbO and HbR short channel data, found channel types: {sorted(channel_types)}")

    short_hbo = np.mean(short_data.copy().pick(picks="hbo").get_data(), axis=0)
    short_hbr = np.mean(short_data.copy().pick(picks="hbr").get_data(), axis=0)
    return short_hbo, short_hbr



def individual_GLM_analysis(file_path, ID, stim_duration=5.0):
    '''Method to perform an individual GLM analysis on a participant.\n
    Inputs:\n
    file_path (string) - Path to the SNIRF file\n
    ID (string) - The filename of the SNIRF file, used as the participant ID\n
    stim_duration (float) - (optional) The stimulus duration in seconds. Default is 5.0\n
    Outputs:\n
    raw_haemo_filtered_minus_short (RawSNIRF) - The (optionally filtered) haemoglobin concentration data of the long channels\n
    roi (DataFrame) - df containing region of interest results\n
    cha (DataFrame) - df containing channel results\n
    con (DataFrame) - df containing contrast results
    '''

    # Load the file, get the sources and detectors, update their position, and calculate the short channel and any large distance channels
    raw = load_snirf(file_path, FORCE_DROPPED_CHANNELS)

    # Did the user want to load new channel positions from an optode file?
    if OPTODE_FILE:
        raw_sources, raw_detectors = get_optode_coordinates()
        update_channel_location(raw, raw_detectors, raw_sources)

    # Calculate the short and long channels, if any, and get the file duration
    raw_minus_short, raw_only_short, bad_channel_dist, short_channel = calculate_short_and_long_channels(raw)
    total_duration = get_duration_of_snirf(raw)
    create_raw_graph(raw, total_duration, ID)

    # Calculate the optical density
    raw_od_minus_short = calculate_optical_density(raw_minus_short, total_duration, ID)

    # Calculate the SNR and SCI values (only if the user asked for them)
    bad_channels_snr = calculate_snr(raw_od_minus_short, ID) if SNR else None
    bad_channels_sci = calculate_sci(raw_od_minus_short, ID) if SCI else None

    # Mark all of the bad channels as bad. Bad distance channels are handled inside this method
    mark_bad_channels(raw_od_minus_short, bad_channels_snr, bad_channels_sci, bad_channel_dist)

    # TODO: Re-implement this. While we know that nobody included is bad on channels, this prevents the need for other files to pre check
    # Figure out if the participant is bad based on bad channels
    # is_bad = bad_check(raw_od_minus_short, MAX_BAD_CHANNELS)
    # if is_bad:
    #     print("Participant is bad on channels!")
    #     raise Exception("Participant is bad on channels!")

    # Calculate the haemoglobin concentration
    raw_haemo_minus_short = calculate_haemoglobin_concentration(raw_od_minus_short, total_duration, ID)

    # Determine if we apply the second filter or not
    if FILTER:
        raw_haemo_filtered_minus_short = calculate_filtered_haemoglobin(raw_haemo_minus_short, ID)
    else:
        raw_haemo_filtered_minus_short = raw_haemo_minus_short

    # Calculate the annotations in the file
    events, event_dict = calculate_annotations(raw_haemo_filtered_minus_short, ID)

    # Calculate all and only good epochs
    # TODO: Bring back using the good epochs I think?
    good_epochs, all_epochs = calculate_good_epochs(raw_haemo_filtered_minus_short, events, event_dict, ID)
    raw_haemo_filtered_good_epochs = remove_bad_epoch_pairings(raw_haemo_filtered_minus_short, good_epochs)

    # Create the design matrix from the annotations that survived epoch rejection
    design_matrix = make_first_level_design_matrix(
        raw_haemo_filtered_good_epochs,
        drift_model=DRIFT_MODEL,
        high_pass=1/(2*DURATION_BETWEEN_ACTIVITIES),
        hrf_model=HRF_MODEL,
        stim_dur=stim_duration,
    )

    if SHORT_CHANNEL_REGRESSION:
        # Append short channels mean to the design matrix
        # In theory these channels contain only systemic components, so including them in the design matrix allows us
        # to estimate the neural component related to each experimental condition uncontaminated by systemic effects.
        # HbO and HbR get separate regressors; identical columns would make the design matrix rank deficient.
        short_hbo, short_hbr = _short_channel_regressors(raw_only_short, short_channel)
        design_matrix["ShortHbO"] = short_hbo
        design_matrix["ShortHbR"] = short_hbr

        print("Design matrix columns:", design_matrix.columns.tolist())

    # Run the glm on the design matrix
    # NOTE(review): the design matrix comes from raw_haemo_filtered_good_epochs (bad channels dropped), but the GLM
    # runs on raw_haemo_filtered_minus_short, which still contains the bad channels. Confirm this is intended.
    glm_est = run_glm(raw_haemo_filtered_minus_short, design_matrix, n_jobs=N_JOBS)

    # Add the regions of interest to the groups
    groups = dict(
        group_1_picks = picks_pair_to_idx(raw_haemo_filtered_minus_short, ROI_GROUP_1, on_missing="ignore"),
        group_2_picks = picks_pair_to_idx(raw_haemo_filtered_minus_short, ROI_GROUP_2, on_missing="ignore"),
    )

    # Extract the channel metrics
    cha = glm_est.to_dataframe()

    # Compute region of interest results from the channel data
    roi = glm_est.to_dataframe_region_of_interest(
        groups, design_matrix.columns, demographic_info=True
    )

    # One-hot contrast vector per design-matrix column
    contrast_matrix = np.eye(design_matrix.shape[1])
    basic_conts = {column: contrast_matrix[i] for i, column in enumerate(design_matrix.columns)}

    # Create and compute the contrast for the target activity
    contrast_t = basic_conts[TARGET_ACTIVITY]
    contrast = glm_est.compute_contrast(contrast_t)
    con = contrast.to_dataframe()

    # Add the participant ID to the dataframes
    roi["ID"] = cha["ID"] = con["ID"] = ID

    # Convert to uM for nicer plotting below.
    cha["theta"] = cha["theta"] * 1.0e6
    roi["theta"] = roi["theta"] * 1.0e6
    con["effect"] = con["effect"] * 1.0e6

    return raw_haemo_filtered_minus_short, roi, cha, con



def process_folder(folder_path, stim_duration):
    '''This method is the entrypoint for all of the functions. It loops over each file in the folder and processes them.\n
    Files are processed in sorted name order. A file that fails is reported and skipped.\n
    Inputs:\n
    folder_path (str) - The folder's absolute location\n
    stim_duration (float) - The stimulus duration for the files in this folder\n
    Outputs:\n
    raw_haemo (RawSNIRF or None) - The snirf data of the last participant processed successfully, or None if no file succeeded\n
    df_roi (DataFrame) - A dataframe containing all region of interest results\n
    df_cha (DataFrame) - A dataframe containing all channel level results\n
    df_con (DataFrame) - A dataframe containing all contrast results'''

    # Must exist even if no file is processed successfully
    raw_haemo = None

    # Collect per-participant frames and concatenate once at the end (concat inside the loop is quadratic)
    roi_frames = []
    cha_frames = []
    con_frames = []

    # Sorted so the processing order, and therefore the returned raw_haemo, is deterministic
    for file_name in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, file_name)
        if not os.path.isfile(file_path):
            continue

        try:
            # Run the full analysis
            participant_haemo, roi, channel, contrast = individual_GLM_analysis(file_path, file_name, stim_duration)
        except Exception as e:
            # Best effort: report the failing participant and carry on with the rest
            print(f"Error processing {file_name}: {e}")
            continue

        # Add this participants results
        raw_haemo = participant_haemo
        roi_frames.append(roi)
        cha_frames.append(channel)
        con_frames.append(contrast)

    def _combine(frames):
        # pd.concat rejects an empty list, so fall back to an empty DataFrame
        return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    return raw_haemo, _combine(roi_frames), _combine(cha_frames), _combine(con_frames)



def _read_optode_label_positions(optode_file_path):
    '''Reads optode labels and coordinates from a text file with one "name: x y z" entry per line.\n
    Inputs:\n
    optode_file_path (string) - Path to the optode file\n
    Output:\n
    positions (list) - (name, [coordinates]) tuples for every non-empty line that contains a ":"'''

    positions = []
    with open(optode_file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or ':' not in line:
                continue  # skip empty/malformed lines
            name, coords = line.split(':', 1)
            positions.append((name.strip(), [float(x) for x in coords.strip().split()]))
    return positions



def brain_3d_visualization(all_results, raw_haemo, t_or_theta='theta'):
    '''This method visualizes the t or theta values in 3d on a brain surface.\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n
    raw_haemo (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channels and their positions. Could cause an issue with a weird number of drops. NOTE: its bad channels are dropped in place\n
    t_or_theta (string) - Whether to display the 't' values or the 'theta' values. Default is 'theta'\n
    Raises:\n
    ValueError - If t_or_theta is not 't' or 'theta' '''

    # Determine if we are visualizing t or theta to set the appropriate limit.
    # Validate before anything else; an unknown value would otherwise fail later with an unbound variable.
    if t_or_theta == 't':
        clim = dict(kind="value", pos_lims=(0, ABS_T_VALUE/2, ABS_T_VALUE))
    elif t_or_theta == 'theta':
        clim = dict(kind="value", pos_lims=(0, ABS_THETA_VALUE/2, ABS_THETA_VALUE))
    else:
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")

    # Need to drop the bad channels completely or plot_glm_surface_projection() will complain.
    # This modifies the caller's raw_haemo. The bads come from the HbO-only copy.
    raw_hbo = raw_haemo.copy().pick(picks="hbo")
    raw_haemo.drop_channels(raw_hbo.info['bads'])
    hbo_channel_names = raw_haemo.copy().pick("hbo").ch_names

    # The OLS and mixed effects models share one formula; the response column is named by t_or_theta
    formula = f"{t_or_theta} ~ -1 + ch_name"

    # The optode labels do not depend on the group or condition, so read the file only once
    optode_positions = _read_optode_label_positions(OPTODE_FILE_PATH) if all_results else []

    # Loop over all groups
    for index, i in enumerate(all_results):

        # We only care for their channel results
        (_, df_cha, _) = all_results[i]

        # Get all activity conditions
        for cond in [TARGET_ACTIVITY]:

            # Filter for the condition and chromophore
            ch_summary = df_cha.query("Condition in [@cond] and Chroma == 'hbo'")

            # Use ordinary least squares (OLS) if only one participant, otherwise a mixed effects model grouped by ID
            if ch_summary["ID"].nunique() == 1:
                ch_model = smf.ols(formula, ch_summary).fit()
                print("OLS model is being used as there is only one participant!")
            else:
                ch_model = smf.mixedlm(formula, ch_summary, groups=ch_summary["ID"]).fit(method="nm")

            # Convert model results, ordered like the HbO channels of raw_haemo
            model_df = statsmodels_to_results(ch_model, order=hbo_channel_names)

            # Plot brain figure
            brain = plot_glm_surface_projection(
                raw_haemo.copy().pick(picks="hbo"),
                model_df,
                view="dorsal",
                distance=BRAIN_DISTANCE,
                colorbar=True,
                clim=clim,
                mode=BRAIN_MODE,
                size=(800, 700),
            )

            brain.add_sensors(
                raw_haemo.info,
                trans="fsaverage",
                fnirs=["channels", "pairs", "sources", "detectors"],
            )

            # Label each optode: names starting with 's' are drawn red, everything else blue
            for name, (x, y, z) in optode_positions:
                # NOTE: Please check in with this! The -30 / -40 offsets are hard-coded label shifts
                brain._renderer.text3d(x, y-30, z-40, name, color='red' if name.startswith('s') else 'blue', scale=0.002)

            # Set the display text for the brain image
            display_text = (
                f"Folder: {BASE_SNIRF_FOLDER}\n"
                f"Group: {i}\n"
                f"Condition: {cond}\n"
                f"Reject Criteria Threshold: {EPOCH_REJECT_CRITERIA_THRESH}\n"
                f"Min Time Threshold: {TIME_MIN_THRESH}s\n"
                f"Max Time Threshold: {TIME_MAX_THRESH}s\n"
                f"Short Channel Regression: {SHORT_CHANNEL_REGRESSION}\n"
                f"Stim Duration: {STIM_DURATION[index]}s\n"
                f"Looking at: {t_or_theta} values\n"
                f"Brain Distance: {BRAIN_DISTANCE}\n"
                f"Brain Mode: {BRAIN_MODE}"
            )

            # Apply the text onto the brain
            brain.add_text(0.12, 0.64, display_text, "title", font_size=11, color="k")



def brain_3d_contrast(raw_haemo, con_model_df, first_name, second_name, first_stim, second_stim, t_or_theta='theta'):
    '''This method creates a 3d visualization on a brain showing the contrasts between two groups.\n
    Inputs:\n
    raw_haemo (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channels and their positions. Could cause an issue with a weird number of drops\n
    con_model_df (DataFrame) - Contains the contrast information for the two groups (needs a 'ch_name' column)\n
    first_name (str) - The folder name of the first group\n
    second_name (str) - The folder name of the second group\n
    first_stim (float) - The stimulus duration of the first group\n
    second_stim (float) - The stimulus duration of the second group\n
    t_or_theta (string) - Whether to display the 't' values or the 'theta' values. Default is 'theta'\n
    Raises:\n
    ValueError - If t_or_theta is not 't' or 'theta' '''

    # Colour limits depend on what is being shown; reject unknown values instead of failing later on an unbound clim
    if t_or_theta == 't':
        clim = dict(kind="value", pos_lims=(0, ABS_T_VALUE/2, ABS_T_VALUE))
    elif t_or_theta == 'theta':
        clim = dict(kind="value", pos_lims=(0, ABS_THETA_VALUE/2, ABS_THETA_VALUE))
    else:
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")

    # The HbO channels as MNE knows them
    raw_hbo = raw_haemo.copy().pick(picks="hbo")
    mne_channels = raw_hbo.ch_names

    # Keep only GLM rows for channels MNE knows about, in MNE's channel order.
    # MNE channels with no GLM result become NaN rows after the reindex.
    con_model_df_filtered = (
        con_model_df[con_model_df['ch_name'].isin(mne_channels)]
        .set_index('ch_name')
        .reindex(mne_channels)
        .reset_index()
    )

    # Plot brain figure
    brain = plot_glm_surface_projection(
        raw_hbo,
        con_model_df_filtered,
        view="dorsal",
        distance=BRAIN_DISTANCE,
        colorbar=True,
        mode=BRAIN_MODE,
        clim=clim,
        size=(800, 700),
    )

    brain.add_sensors(
        raw_haemo.info,
        trans="fsaverage",
        fnirs=["channels", "pairs", "sources", "detectors"],
    )

    # Set the display text for the brain image
    display_text = (
        f"Folder: {BASE_SNIRF_FOLDER}\n"
        f"Contrast: {first_name} - {second_name}\n"
        f"Reject Criteria Threshold: {EPOCH_REJECT_CRITERIA_THRESH}\n"
        f"Min Time Threshold: {TIME_MIN_THRESH}s\n"
        f"Max Time Threshold: {TIME_MAX_THRESH}s\n"
        f"Short Channel Regression: {SHORT_CHANNEL_REGRESSION}\n"
        f"Stim Duration: {first_stim}, {second_stim}\n"
        f"Brain Distance: {BRAIN_DISTANCE}\n"
        f"Brain Mode: {BRAIN_MODE}\n"
        f"Looking at: {t_or_theta} values"
    )

    # Apply the text onto the brain
    brain.add_text(0.12, 0.70, display_text, "title", font_size=11, color="k")
 


def _channel_from_param_name(param_name):
    '''Extracts the channel name from a statsmodels parameter name of the group:ch_name:Chroma model.\n
    e.g. "group[A]:ch_name[S1_D1 hbo]:Chroma[hbo]" -> "S1_D1 hbo"\n
    Inputs:\n
    param_name (str) - The parameter name\n
    Output:\n
    (str) - The channel name'''

    # TODO: Does this work for all separators?
    return param_name.split(":")[1].split("[")[1].split("]")[0]



def plot_2d_3d_contrasts_between_groups(all_results, raw_haemo, t_or_theta='theta'):
    '''This method will plot both 2d and 3d representations of the contrasts between groups, in both directions for every pair of groups.\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group (not modified)\n
    raw_haemo (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channels. NOTE: its bad channels are dropped in place\n
    t_or_theta (string) - Whether to display the 't' values or the 'theta' values. Default is 'theta'\n
    Raises:\n
    ValueError - If t_or_theta is not 't' or 'theta' '''

    if t_or_theta not in ('t', 'theta'):
        raise ValueError(f"t_or_theta must be 't' or 'theta', got {t_or_theta!r}")

    # Remove the bad channels (taken from the HbO-only copy) from raw_haemo
    raw_hbo = raw_haemo.copy().pick(picks="hbo")
    raw_haemo.drop_channels(raw_hbo.info['bads'])

    # Tag each group's contrasts with the group name. assign() returns a copy so the caller's frames are untouched.
    group_dfs = {
        group_name: df_con.assign(group=group_name)
        for group_name, (_, _, df_con) in all_results.items()
    }

    # Concatenate all groups together
    df_combined = pd.concat(group_dfs.values(), ignore_index=True)

    # TODO: Is this redundant? We already are only picking hbo
    # Filter for only HbO
    con_summary = df_combined.query("Chroma in ['hbo']")

    # Fit the mixed-effects model
    con_model = smf.mixedlm(
        "effect ~ -1 + group:ch_name:Chroma", con_summary, groups=con_summary["ID"]
    ).fit(method="nm")

    # Pick the model estimates being compared and the symmetric colour limit for the 2d plots
    if t_or_theta == 't':
        estimates = con_model.tvalues
        vlim_value = ABS_CONTRAST_T_VALUE
    else:
        estimates = con_model.params
        vlim_value = ABS_CONTRAST_THETA_VALUE

    group_names = list(group_dfs.keys())
    n_groups = len(group_names)

    # Every unordered pair of groups
    for i in range(n_groups):
        for j in range(i + 1, n_groups):
            group1_name = group_names[i]
            group2_name = group_names[j]

            # Estimates belonging to each group, indexed by channel name
            group1_vals = estimates.filter(like=f"group[{group1_name}]")
            group2_vals = estimates.filter(like=f"group[{group2_name}]")
            df_group1 = pd.DataFrame(
                {"Coef.": group1_vals.values}, index=[_channel_from_param_name(name) for name in group1_vals.index]
            )
            df_group2 = pd.DataFrame(
                {"Coef.": group2_vals.values}, index=[_channel_from_param_name(name) for name in group2_vals.index]
            )

            # Keep only channels present in both groups
            df_contrast = df_group1.join(df_group2, how="inner", lsuffix=f"_{group1_name}", rsuffix=f"_{group2_name}")
            coef_1 = df_contrast[f"Coef._{group1_name}"]
            coef_2 = df_contrast[f"Coef._{group2_name}"]

            # Plot the a-b / 1-2 contrast, then the b-a / 2-1 contrast
            for first_name, second_name, first_idx, second_idx, contrast_values in (
                (group1_name, group2_name, i, j, coef_1 - coef_2),
                (group2_name, group1_name, j, i, coef_2 - coef_1),
            ):
                # The order and names of the keys in the DataFrame are important!
                con_model_df = pd.DataFrame({
                    "ch_name": df_contrast.index,
                    "Coef.": contrast_values,
                    "Chroma": "hbo"
                })

                # Create the 3d visualization
                brain_3d_contrast(raw_haemo, con_model_df, first_name, second_name, STIM_DURATION[first_idx], STIM_DURATION[second_idx], t_or_theta)

                # Create the 2d visualization
                plot_glm_group_topo(raw_haemo.copy().pick(picks="hbo"), con_model_df, names=True, res=128, vlim=(-vlim_value, vlim_value))

                # TODO: The title currently goes on the colorbar. Low priority
                plt.title(f"Contrast: {first_name} vs {second_name}")



def plot_2d_theta_graph(all_results):
    '''This method will create a 2d boxplot showing the theta values for each channel and group as independent ranges on the same graph.\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group (not modified)\n
    '''

    # Combine every group's channel results, tagged with the group name.
    # assign() returns a copy so the caller's dataframes are not modified.
    df_all_cha = pd.concat(
        [df_cha.assign(group=group_name) for group_name, (_, df_cha, _) in all_results.items()],
        ignore_index=True,
    )

    # Filter for the target activity
    df_target = df_all_cha[df_all_cha["Condition"] == TARGET_ACTIVITY]

    # One colour per group
    palette = sns.color_palette("Set2", df_target["group"].nunique())

    # Create the boxplot
    plt.figure(figsize=(15, 6))
    sns.boxplot(
        data=df_target,
        x="ch_name",
        y="theta",
        hue="group",
        palette=palette
    )

    # Format the boxplot (theta was converted to micromolar upstream)
    plt.title("Theta Coefficients by Channel and Group")
    plt.xticks(rotation=90)
    plt.ylabel("Theta (\u00b5M)")
    plt.xlabel("Channel")
    plt.legend(title="Group")
    plt.tight_layout()



def plot_individual_theta_averages(all_results):
    '''This method will create one catplot per group showing the HbO theta values for each region of interest for each participant.\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n'''

    # Friendly names for the ROI groups
    roi_label_map = {
        "group_1_picks": ROI_GROUP_1_NAME,
        "group_2_picks": ROI_GROUP_2_NAME,
    }

    # Referenced via @ in query(), so condition names containing quotes cannot break the expression
    conditions = [TARGET_ACTIVITY, TARGET_CONTROL]

    # Iterate over all the groups
    for group_name in all_results:

        # Store the region of interest data
        (df_roi, _, _) = all_results[group_name]

        # Keep only the target activity/control and HbO
        grp_results = df_roi.query("Condition in @conditions and Chroma == 'hbo'").copy()

        # Rename the ROI's to be the friendly name
        grp_results["ROI"] = grp_results["ROI"].replace(roi_label_map)

        # Create the catplot
        sns.catplot(
            x="Condition",
            y="theta",
            col="ID",
            hue="ROI",
            data=grp_results,
            col_wrap=5,
            errorbar=None,
            palette="muted",
            height=4,
            s=10,
        )



def plot_group_theta_averages(all_results):
    '''This method fits a mixed linear model per group and draws a stripplot of the HbO coefficient
    for each region of interest and condition, one subplot per group (two subplots per row).\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n'''

    # Nothing to plot: plt.subplots(nrows=0) would fail, so bail out early
    if not all_results:
        print("No group results to plot.")
        return

    # Rename the ROI's to be the friendly name
    roi_label_map = {
        "group_1_picks": ROI_GROUP_1_NAME,
        "group_2_picks": ROI_GROUP_2_NAME,
    }

    # Setup subplot grid
    n_groups = len(all_results)
    ncols = 2
    nrows = -(-n_groups // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 5 * nrows), squeeze=False)
    flat_axes = axes.flatten()

    # Iterate over all groups, each paired with its own subplot
    for group_name, ax in zip(all_results, flat_axes):

        # Store the region of interest data
        (df_roi, _, _) = all_results[group_name]

        # Filter the results down to what we want
        grp_results = df_roi.query(f"Condition in ['{TARGET_ACTIVITY}', '{TARGET_CONTROL}']").copy()

        # Run a mixedlm model on the data, with participant ID as the grouping factor
        roi_model = smf.mixedlm(
            "theta ~ -1 + ROI:Condition:Chroma", grp_results, groups=grp_results["ID"]
        ).fit(method="nm")

        # Apply the friendly names. Use replace (not map) so any ROI missing from the
        # mapping keeps its original label instead of becoming NaN and being dropped.
        df = statsmodels_to_results(roi_model)
        df["ROI"] = df["ROI"].replace(roi_label_map)

        # Create a stripplot:
        sns.stripplot(
            x="Condition",
            y="Coef.",
            hue="ROI",
            data=df.query("Chroma == 'hbo'"),
            dodge=False,
            jitter=False,
            size=5,
            palette="muted",
            ax=ax,
        )

        # Format the stripplot
        ax.set_title(f"Results for {group_name}")
        ax.legend(title="ROI", loc="upper right")

    # Remove the unused trailing axes (present when the number of groups is odd)
    for unused_ax in flat_axes[n_groups:]:
        fig.delaxes(unused_ax)
    fig.tight_layout()
    fig.suptitle("Theta Averages Across Groups", fontsize=16, y=1.02)
    
    
    
def get_bad_src_det_pairs(raw):
    '''Collects the (source, detector) number pairs of every channel listed in raw.info['bads'],
    so the 2D t+p graph can skip them.\n
    Channel names that cannot be parsed are reported and ignored.\n
    Inputs:\n
    raw (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channels\n
    Outputs:\n
    bad_pairs (set) - Set of (source number, detector number) integer tuples'''

    pairs = set()

    for bad_name in raw.info['bads']:
        try:
            # The first whitespace-separated token holds "<letter><src><SEP><letter><det>"
            label = bad_name.split()[0]
            source_part, detector_part = label.split(SOURCE_DETECTOR_SEPARATOR)
            # Drop the one-character prefix of each part before converting to int
            pairs.add((int(source_part[1:]), int(detector_part[1:])))
        except Exception as err:
            print(f"Could not parse bad channel '{bad_name}': {err}")

    return pairs



def compute_p_group_stats(df_cha, bad_pairs=None):
    '''This method computes a group-level p value per source/detector pair by running a one-sample
    t-test (against 0) on the per-participant t values of the TARGET_ACTIVITY HbO rows.
    Inputs:\n
    df_cha (DataFrame) - DataFrame containing the group's channel data (needs Condition, Chroma, Source, Detector and t columns)\n
    bad_pairs (set) - Optional set of (source, detector) pairs to skip. Defaults to no pairs.\n
    Output:\n
    results (DataFrame) - Columns Source, Detector, t (mean of the input t values) and p_value.
    The columns are present even when no pair had enough data.'''

    # None sentinel instead of a shared mutable set() default
    if bad_pairs is None:
        bad_pairs = set()

    # Filter the channel data down to what we want
    grouped = df_cha.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'").groupby(['Source', 'Detector'])

    # Create an empty list to store the data for our result
    data = []

    # Iterate over the filtered channel data
    for (src, det), group in grouped:

        # If it is a bad channel pairing, do not process it
        if (src, det) in bad_pairs:
            print(f"Skipping bad channel Source {src} - Detector {det}")
            continue

        # Drop any missing values that could exist
        t_values = group['t'].dropna().values

        # A t-test needs at least two values; otherwise do not process this pairing
        if len(t_values) < 2:
            print(f"Skipping Source {src} - Detector {det}: not enough data (n={len(t_values)})")
            continue

        # Perform one-sample t-test on t-values across subjects
        _, pval = ttest_1samp(t_values, popmean=0)

        # Store all of the data for this ttest using the mean t-value for visualization
        data.append({
            'Source': src,
            'Detector': det,
            't': np.mean(t_values),
            'p_value': pval
        })

    # Fixed columns so callers can still index 'Source', 't', ... on an empty result
    result = pd.DataFrame(data, columns=['Source', 'Detector', 't', 'p_value'])
    if result.empty:
        print("No valid channel pairs with enough data for group-level testing.")

    return result



def _optode_xy_positions(raw):
    '''Builds lookup tables of 2D source and detector positions from the located channels of raw.\n
    Inputs:\n
    raw (RawSNIRF) - Only raw.info['chs'] is read\n
    Output:\n
    (src_pos, det_pos) (tuple of dict) - Map source / detector number to the channel's loc[3:5] / loc[6:8] (X, Y)'''
    src_pos, det_pos = {}, {}
    for ch in raw.info['chs']:
        ch_name = ch['ch_name']
        # Skip unnamed channels and channels without any location information
        if not ch_name or not ch['loc'].any():
            continue
        # First token is expected to be "<letter><src><SOURCE_DETECTOR_SEPARATOR><letter><det>"
        src_str, det_str = ch_name.split()[0].split(SOURCE_DETECTOR_SEPARATOR)
        src_pos[int(src_str[1:])] = ch['loc'][3:5]  # X, Y
        det_pos[int(det_str[1:])] = ch['loc'][6:8]  # X, Y
    return src_pos, det_pos


def plot_avg_activity_hbo_tvals(raw, all_results):
    '''This method plots the average t values for each group on a 2D optode layout (one figure per group).
    Pairs with p values less than or equal to P_THRESHOLD are solid lines, while greater p values are dashed lines.\n
    Groups without TARGET_ACTIVITY HbO data are reported and skipped; the remaining groups are still plotted.\n
    Inputs:\n
    raw (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the bad channels and channel locations.\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n'''

    if not all_results:
        return

    # The bad pairs and optode positions depend only on raw, not on the group, so compute them once
    bad_pairs = get_bad_src_det_pairs(raw)
    src_pos, det_pos = _optode_xy_positions(raw)

    # Iterate over all the groups
    for group_name in all_results:
        (_, df_cha, _) = all_results[group_name]

        # Filter to only the activity data
        activity_df = df_cha.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'")

        num_tests = len(activity_df.groupby(['Source', 'Detector']))
        print(f"Number of tests: {num_tests}")

        # Should only occur if the GLM was not ran previously; skip this group but keep plotting the others
        if activity_df.empty:
            print("No data found.")
            continue

        # Compute average t-value across individuals for each channel pairing
        avg_df = compute_p_group_stats(df_cha, bad_pairs)

        print(f"Average t-values and p-values for {TARGET_ACTIVITY}:")
        for _, row in avg_df.iterrows():
            print(f"Source {row['Source']} <-> Detector {row['Detector']}: "
                f"Avg t-value = {row['t']:.3f}, Avg p-value = {row['p_value']:.3f}")

        # Set up the plot
        _, ax = plt.subplots(figsize=(8, 6))

        # Plot the sources
        for pos in src_pos.values():
            ax.scatter(pos[0], pos[1], s=120, c='k', marker='o',
                    edgecolors='white', linewidths=1, zorder=3)

        # Plot the detectors
        for pos in det_pos.values():
            ax.scatter(pos[0], pos[1], s=120, c='k', marker='s',
                    edgecolors='white', linewidths=1, zorder=3)

        # Ensure that the colors stay within the boundaries even if they are over or under the max/min values
        norm = plt.Normalize(vmin=-ABS_P_T_GRAPH_VALUE, vmax=ABS_P_T_GRAPH_VALUE)
        cmap = plt.cm.seismic

        # Plot connections with avg t-values
        for _, row in avg_df.iterrows():
            src = row['Source']
            det = row['Detector']
            tval = row['t']
            pval = row['p_value']

            if src in src_pos and det in det_pos:
                x = [src_pos[src][0], det_pos[det][0]]
                y = [src_pos[src][1], det_pos[det][1]]
                style = '-' if pval <= P_THRESHOLD else '--'
                ax.plot(x, y, linestyle=style,
                        color=cmap(norm(tval)),
                        linewidth=4, alpha=0.9, zorder=2)

        # Format the Colorbar
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, shrink=0.85)
        cbar.set_label(f'Average {TARGET_ACTIVITY} t-value (hbo)', fontsize=11)

        # Formatting the subplots
        ax.set_aspect('equal')
        ax.set_title(f"Average t-values for {TARGET_ACTIVITY} (HbO) for {group_name}", fontsize=14)
        ax.set_xlabel('X position (m)', fontsize=11)
        ax.set_ylabel('Y position (m)', fontsize=11)
        ax.grid(True, alpha=0.3)

        # Set axis limits to be 1cm more than the optode positions (min() would fail with no optodes)
        all_x = [pos[0] for pos in src_pos.values()] + [pos[0] for pos in det_pos.values()]
        all_y = [pos[1] for pos in src_pos.values()] + [pos[1] for pos in det_pos.values()]
        if all_x:
            ax.set_xlim(min(all_x)-0.01, max(all_x)+0.01)
            ax.set_ylim(min(all_y)-0.01, max(all_y)+0.01)

        plt.tight_layout()



def fold_channel(raw, all_results):
    '''This method runs mne_nirs' fOLD lookup (Brodmann atlas, interpolated) on every HbO channel of raw
    and prints the landmark/specificity table of each channel under a header for each group.\n
    The lookup only uses raw, so it is computed once and the same tables are printed for every group.\n
    Inputs:\n
    raw (RawSNIRF) - Contains all the snirf data for the last participant processed. Only used to get the channel locations.\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group (only the group names are used)\n'''

    # Locate the fOLD excel files
    mne.set_config('MNE_NIRS_FOLD_PATH', '~/mne_data/fOLD/fOLD-public-master/Supplementary')

    if not all_results:
        return

    # Run the (slow, interpolating) fOLD lookup once per HbO channel; it does not depend on the group
    channel_tables = []
    for channel_name in raw.copy().pick(picks='hbo').ch_names:
        channel_data = raw.copy().pick(picks=channel_name)
        output = fold_channel_specificity(channel_data, interpolate=True, atlas='Brodmann')

        # Keep only the relevant columns of each DataFrame returned for this channel
        for df_data in output:
            channel_tables.append((channel_name, df_data[['Landmark', 'Specificity']]))

    # Print the same tables under each group's header
    for group_name in all_results:
        print("*" * 42)
        print(f'Landmark Specificity for {group_name}:')
        print("*" * 42)

        for channel_name, useful_data in channel_tables:
            print(f"Channel: {channel_name}")
            print(useful_data)
            print("-" * 42)
                
                
                
def data_to_csv(all_results):
    '''Writes, for every group, the TARGET_ACTIVITY HbO rows of its channel data to "<group_name>.csv"
    in the current working directory (the DataFrame index is written as the first column).\n
    Inputs:\n
    all_results (dict) - Contains the df_roi, df_cha, and df_con for each group\n'''

    for group_name, group_frames in all_results.items():
        _, df_cha, _ = group_frames
        csv_path = group_name + '.csv'

        target_rows = df_cha.query(f"Condition == '{TARGET_ACTIVITY}' and Chroma == 'hbo'")
        target_rows.to_csv(csv_path)
        
        
        
def brain_landmarks_3d(raw_haemo):
    '''Renders an fsaverage brain with text labels for each optode listed in OPTODE_FILE_PATH,
    the fNIRS montage from raw_haemo, and a fixed set of right-hemisphere Brodmann areas.\n
    Inputs:\n
    raw_haemo (RawSNIRF) - Haemoglobin data whose info supplies the sensor locations\n'''
    from mne.viz import Brain

    # Create an empty brain plot
    brain = Brain("fsaverage", background="white", size=(800, 700))

    # Read "name: x y z" entries; blank lines and lines without a colon are ignored
    with open(OPTODE_FILE_PATH, 'r') as optode_file:
        entries = (raw_line.strip() for raw_line in optode_file)
        optode_positions = [
            (label.strip(), [float(value) for value in rest.strip().split()])
            for label, rest in (entry.split(':', 1) for entry in entries if entry and ':' in entry)
        ]

    # Label each optode; names starting with 's' are red, everything else blue.
    # The fixed y/z offsets shift the text relative to the file coordinates.
    for label, (x, y, z) in optode_positions:
        text_colour = 'red' if label.startswith('s') else 'blue'
        brain._renderer.text3d(x, y - 30, z - 40, label, color=text_colour, scale=0.002)

    brain.add_sensors(
        raw_haemo.info,
        trans="fsaverage",
        fnirs=["channels", "pairs", "sources", "detectors"],
    )

    # Right-hemisphere Brodmann areas to highlight, with their colours
    label_colors = {
        "Brodmann.39-rh": "blue",
        "Brodmann.40-rh": "green",
        "Brodmann.6-rh": "pink",
        "Brodmann.7-rh": "orange",
        "Brodmann.17-rh": "red",
        "Brodmann.1-rh": "yellow",
        "Brodmann.2-rh": "yellow",
        "Brodmann.3-rh": "yellow",
        "Brodmann.18-rh": "red",
        "Brodmann.19-rh": "red",
        "Brodmann.4-rh": "purple",
        "Brodmann.8-rh": "white"
    }
    brodmann_labels = mne.read_labels_from_annot("fsaverage", "PALS_B12_Brodmann", "rh")

    for selected in (lbl for lbl in brodmann_labels if lbl.name in label_colors):
        brain.add_label(selected, borders=False, color=label_colors[selected.name])
        
        
        
def nope(all_results, raw_haemo):
    '''Placeholder for a disabled experiment; it performs no work and always returns True.\n
    The experiment was left as commented-out code. It fetched the HCP-MMP1 / HCPMMP1_combined
    parcellations for fsaverage (left hemisphere) and plotted the optode montage with
    mne_nirs.visualisation.plot_3d_montage using a per-view channel map, saving it as "montage.png".
    That dead code has been removed; recover it from version control if the experiment is revived.\n
    Inputs:\n
    all_results (dict) - Unused\n
    raw_haemo (RawSNIRF) - Unused\n
    Output:\n
    (bool) - Always True'''
    return True