# Source code for spikewrap.configs.config_utils

import glob
import itertools
import os
import shutil
from pathlib import Path

import yaml

from spikewrap.process import _preprocessing
from spikewrap.utils import _utils


def get_configs(name: str) -> tuple[dict, dict]:
    """
    Load a config YAML file, either by name from the spikewrap user
    config directory or from an explicit file path.

    Parameters
    ----------
    name: name of the configs to load from the user config directory
          (without the .yaml / .yml suffix), or a full path to a
          .yaml / .yml configuration file.

    Returns
    -------

    pp_steps
        a dictionary containing the preprocessing
        step order (keys) and a [pp_name, kwargs]
        list containing the spikeinterface preprocessing
        step and keyword options. Empty if the file has
        no "preprocessing" section.

    sorter_options
        a dictionary with sorter name (key) and
        a dictionary of kwargs to pass to the
        spikeinterface sorter class. Empty if the file
        has no "sorting" section.

    Raises
    ------
    FileNotFoundError
        If `name` is neither an available config name nor an existing file.
    ValueError
        If `name` is a path to a file that does not end in .yaml or .yml.
    """
    config_dir = get_configs_path()

    # Map config name (file stem) -> file path for both accepted suffixes,
    # so every config listed by `show_available_configs` is loadable by name.
    # "*.yaml" is chained last so it takes precedence if both "x.yml" and
    # "x.yaml" exist (the original only ever looked at "*.yaml").
    available_configs = {
        path_.stem: path_
        for path_ in itertools.chain(
            config_dir.glob("*.yml"), config_dir.glob("*.yaml")
        )
    }

    if name in available_configs:
        config_filepath = available_configs[name]
    else:
        # Otherwise, assume `name` is a full path to a config file.
        # Explicit raises (not `assert`) so validation survives `python -O`.
        config_filepath = Path(name)

        if not config_filepath.is_file():
            raise FileNotFoundError(
                f"{name} is neither the name of an existing "
                f"config or valid path to configuration file."
            )

        if config_filepath.suffix not in (".yaml", ".yml"):
            raise ValueError(f"{name} is not the path to a .yaml file")

    with open(config_filepath, encoding="utf-8") as file:
        # PyYAML returns None for an empty document; fall back to an empty
        # dict so the `.get` calls below return empty sections instead of
        # raising AttributeError.
        config = yaml.full_load(file) or {}

    pp_steps = config.get("preprocessing", {})
    sorting = config.get("sorting", {})

    return pp_steps, sorting


def get_configs_path() -> Path:
    """
    Return the folder in the user's home directory that holds all
    spikewrap config YAML files (``~/.spikewrap/configs``).

    If the folder does not exist yet, it is created (and seeded) by
    `_create_user_configs_folder` before being returned.

    Returns
    -------
    Path
        The path to the spikewrap `configs` directory.
    """
    user_configs_dir = Path.home().joinpath(".spikewrap", "configs")

    if user_configs_dir.is_dir():
        return user_configs_dir

    _create_user_configs_folder(user_configs_dir)
    return user_configs_dir


def _create_user_configs_folder(configs_path: Path) -> None:
    """
    Create the user configs folder and seed it with the default configs.

    Every ``*.yaml`` file found in ``_backend/_default_configs`` (relative
    to the real location of this module) is copied into `configs_path`.
    The YAMLs are copied rather than used in place so that spikewrap never
    manages files inside the installation directory (which may have
    permission restrictions). After this, all config YAMLs are managed in
    the user directory.
    """
    configs_path.mkdir(parents=True)

    module_dir = Path(os.path.realpath(__file__)).parent
    bundled_configs_dir = module_dir / "_backend" / "_default_configs"

    # TODO: store canon suffix
    for default_yaml in bundled_configs_dir.glob("*.yaml"):
        shutil.copy(default_yaml, configs_path)


def show_available_configs() -> None:
    """
    Print the file names of every config file in the user config path:
    first the ``*.yaml`` files, then the ``*.yml`` files.
    """
    user_configs_dir = get_configs_path()

    found_names = []
    for pattern in ("*.yaml", "*.yml"):
        found_names.extend(
            config_file.name for config_file in user_configs_dir.glob(pattern)
        )

    _utils.message_user(f"The available configs are:\n" f"{found_names}")


def save_config_dict(
    config_dict: dict, name: str, folder: Path | None = None
) -> None:
    """
    Save a configuration dictionary to a YAML file.

    Parameters
    ----------
    config_dict
        The configs dictionary to save.
    name
        The name of the YAML file (with or without the `.yaml` extension).
        If `name` does not already end in `.yaml` or `.yml`, `.yaml` is
        appended. Any other dotted part of the name (e.g. ``"v1.2"``) is
        kept as part of the file name rather than being replaced.
    folder
        If None (default), the config is saved in the spikewrap-managed
        user configs folder. Otherwise, save in `folder`.
    """
    if folder is None:
        folder = get_configs_path()

    output_filepath = Path(folder) / name

    # Compare against the accepted YAML suffixes rather than "has any suffix":
    # names like "my.config" or "v1.2" have a non-empty `.suffix` and would
    # otherwise be written without a YAML extension, so `load_config_dict`
    # would reject them and `show_available_configs` would not list them.
    # `with_name` (not `with_suffix`) so the dotted part of the name is kept.
    if output_filepath.suffix not in (".yaml", ".yml"):
        output_filepath = output_filepath.with_name(
            output_filepath.name + ".yaml"
        )

    _utils._dump_dict_to_yaml(output_filepath, config_dict)


[docs] def load_config_dict(filepath: Path) -> dict: """ Load a configuration dictionary from a YAML file. Parameters ---------- filepath The full path to the YAML file, including the file name and extension. Returns ------- dict The configs dict loaded from the YAML file. """ if not filepath.is_file(): raise FileNotFoundError(f"No file found at {filepath}.") if filepath.suffix not in [".yml", ".yaml"]: # TODO: centralise raise ValueError( f"File {filepath.name} is not a yaml file, must end in .yml or .yaml" ) return _utils._load_dict_from_yaml(filepath)
def show_configs(name: str) -> None:
    """
    Print the preprocessing and sorting options of a config.

    Parameters
    ----------
    name
        A config name or config file path, as accepted by `get_configs`.
    """
    preprocessing_configs, sorting_configs = get_configs(name)

    _utils.show_preprocessing_configs(preprocessing_configs)
    _utils.show_sorting_configs(sorting_configs)


def show_supported_preprocessing_steps() -> None:
    """
    Print the (currently supported) SpikeInterface preprocessing steps,
    i.e. the keys of the mapping returned by `_preprocessing._get_pp_funcs()`.
    """
    supported_steps = list(_preprocessing._get_pp_funcs().keys())

    _utils.message_user(
        f"Currently supported preprocessing steps are:\n"
        f"{supported_steps}"
    )