from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import yaml
if TYPE_CHECKING:
import matplotlib
import submitit
from probeinterface import Probe
import time
from pathlib import Path
import numpy as np
import spikeinterface.full as si
from spikewrap.configs import config_utils
from spikewrap.configs._backend import canon
from spikewrap.process import _loading
from spikewrap.structure._preprocess_run import PreprocessedRun
from spikewrap.structure._raw_run import (
ConcatRawRun,
SeparateRawRun,
)
from spikewrap.structure._sorting_run import (
ConcatSortingRun,
SeparateSortingRun,
)
from spikewrap.utils import _utils
class Session:
"""
Represents an electrophysiological recording session, consisting of one or more runs.
Exposes functions for preprocessing and sorting data from the session.
Parameters
----------
subject_path
The path to the subject's directory. This should contain the ``session_name`` directory.
session_name
The name of this session. Must match the session folder name in the `subject_path`.
file_format
Acquisition software used for recording, either ``"spikeglx"`` or ``"openephys"``.
Determines how a session's runs are discovered.
run_names
Specifies which runs within the session to include. If ``"all"`` (default), includes all
runs detected within the session. Otherwise, a ``list of str`` of specific run names.
Each name must correspond to a run folder within the session. The order passed
determines the concatenation order.
output_path
The path where preprocessed data will be saved (in NeuroBlueprint style).
probe
A ProbeInterface probe object to set on the recordings. If `None`,
auto-detection of probe is attempted.
Notes
-----
The responsibility of this class is to manage the processing of runs
contained within the session. Raw data are held in ``self._raw_runs``, a list of
``SeparateRawRun``. The sync channel (if available) is held on the raw run in a
recording object (the other channels are unused) and may be mutated.
``load_raw_data()`` will refresh the run sync channels.
``self._pp_runs`` is filled by ``self.preprocess()``. Copies of ``self._raw_runs`` are
concatenated and / or split by shank before preprocessing, which fills ``self._pp_runs``.
The sync channel is attached in-memory (no longer on a recording).
``self._sorting_runs`` contains the sorting based on ``self._pp_runs``. Again, copies
of the preprocessed runs are concatenated and / or split by shank before sorting.
Properties of this class should never be modified from outside the class.
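Examples
--------
A minimal sketch of constructing a session. The import and paths below are
illustrative assumptions (``Session`` exported at the package top level,
NeuroBlueprint-style folders), not a prescription:
>>> from spikewrap import Session
>>> session = Session(
...     subject_path="/my_data/rawdata/sub-001",
...     session_name="ses-001",
...     file_format="spikeglx",
...     run_names="all",
... )
>>> session.load_raw_data()
>>> session.get_raw_run_names()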
"""
def __init__(
self,
subject_path: Path | str,
session_name: str,
file_format: Literal["spikeglx", "openephys"],
run_names: Literal["all"] | list[str] = "all",
output_path: Path | None = None,
probe: Probe | None = None,
):
parent_input_path = Path(subject_path)
self._check_input_path(parent_input_path)
self._check_file_format(file_format)
if isinstance(run_names, str) and not run_names == "all":
run_names = [run_names]
self._passed_run_names = run_names
self._file_format = file_format
self._probe = probe
self._parent_input_path = parent_input_path
self._ses_name = session_name
self._output_path = (
Path(output_path) if output_path else self._output_from_parent_input_path()
)
self._running_slurm_jobs: list[submitit.Job] = []
self._raw_runs: list[SeparateRawRun] = []
self._pp_runs: list[PreprocessedRun] = []
self._sorting_runs: list[SeparateSortingRun | ConcatSortingRun] = []
# ---------------------------------------------------------------------------
# Public Functions
# ---------------------------------------------------------------------------
def preprocess(
self,
configs: dict | str | Path,
concat_runs: bool = False,
per_shank: bool = False,
) -> None:
"""
Preprocess recordings for all runs in this session.
This step is lazy: under the hood, the preprocessing steps from ``configs``
are set up on SpikeInterface recording objects. The data are preprocessed on
the fly only when required (e.g. for plotting, saving or sorting).
Parameters
----------
configs
- If a ``str`` is provided, expects the name of a stored configuration file.
See ``show_available_configs()`` and ``save_config_dict()`` for details.
- If a ``Path`` is provided, expects the path to a valid spikewrap config YAML file.
- A spikewrap configs dictionary, either including the ``"preprocessing"`` level
or the ``"preprocessing"`` level itself. See documentation for details.
concat_runs
If ``True``, all runs will be concatenated together before preprocessing.
Use ``session.get_raw_run_names()`` to check the order of concatenation.
per_shank
If ``True``, perform preprocessing on each shank separately.
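Examples
--------
A minimal sketch, assuming ``session`` is an instantiated ``Session``. The
inline configs dictionary below is illustrative only; check the documentation
for the exact step names and schema available:
>>> pp_steps = {
...     "preprocessing": {
...         "1": ["bandpass_filter", {"freq_min": 300, "freq_max": 6000}],
...         "2": ["common_reference", {"operator": "median"}],
...     }
... }
>>> session.preprocess(configs=pp_steps, concat_runs=False, per_shank=True)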
"""
pp_steps = self._infer_steps_from_configs_argument(configs, "preprocessing")
_utils.show_preprocessing_configs(pp_steps)
if not any(self._raw_runs):
self._load_raw_data(internal_overwrite=False)
runs_to_preprocess: list[SeparateRawRun | ConcatRawRun]
if concat_runs:
runs_to_preprocess = [self._get_concat_raw_run()]
else:
runs_to_preprocess = self._raw_runs # type: ignore
self._pp_runs = []
for run in runs_to_preprocess:
preprocessed_run = run.preprocess(pp_steps, per_shank)
orig_run_names = run._orig_run_names
self._pp_runs.append(
PreprocessedRun(
raw_data_path=run._parent_input_path,
ses_name=self._ses_name,
run_name=run._run_name,
file_format=run._file_format,
session_output_path=self._output_path,
preprocessed_data=preprocessed_run,
pp_steps=pp_steps,
orig_run_names=orig_run_names,
)
)
def save_preprocessed(
self,
overwrite: bool = False,
chunk_duration_s: float = 2,
n_jobs: int = 1,
slurm: dict | bool = False,
) -> None:
"""
Save preprocessed data for all runs in the current session.
This method iterates over each run in the session and invokes the `save_preprocessed` method
to persist the preprocessed data. It supports options to overwrite existing data, specify
chunk sizes for data saving, utilize parallel processing, and integrate with SLURM for
job scheduling.
Parameters
----------
overwrite
If `True`, existing preprocessed run data will be overwritten.
Otherwise, an error will be raised.
chunk_duration_s
Size, in seconds, of the chunks that are separately preprocessed and written.
n_jobs
Number of parallel jobs to run for saving preprocessed data.
Sets SpikeInterface's `set_global_job_kwargs`.
slurm
Configuration for submitting the save jobs to a SLURM workload manager.
If `False` (default), jobs will be run locally. If `True`, jobs will be run on SLURM
with default arguments. If a `dict` is provided, it should contain SLURM arguments.
See `tutorials` in the documentation for details.
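Examples
--------
A minimal sketch, assuming ``session.preprocess(...)`` has already been called
(the job settings are illustrative):
>>> session.save_preprocessed(overwrite=True, chunk_duration_s=2, n_jobs=4)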
"""
for run in self._pp_runs:
job_if_slurm = run.save_preprocessed(
overwrite, chunk_duration_s, n_jobs, slurm
)
if slurm:
self._running_slurm_jobs.append(job_if_slurm)
def plot_preprocessed(
self,
run_idx: Literal["all"] | int = "all",
mode: Literal["map", "line"] = "map",
time_range: tuple[float, float] = (0.0, 1.0),
show_channel_ids: bool = True,
show: bool = False,
figsize: tuple[int, int] = (10, 6),
) -> dict[str, matplotlib.Figure]:
"""
Plot preprocessed run data.
One plot will be generated for each run. For preprocessing
per-shank, each shank will appear in a subplot. Under the hood,
calls SpikeInterface's ``plot_traces()`` function.
Parameters
----------
run_idx
- If ``"all"``, plots preprocessed data for all runs in the session.
- If an integer, plots preprocessed data for the run at the specified index in ``self._pp_runs``.
mode
Determines the plotting style, a heatmap-style or line plot.
time_range
Time range (start, end), in seconds, to plot. e.g. (0.0, 1.0)
show_channel_ids
If ``True``, displays the channel identifiers on the plots.
show
If ``True``, displays the plots immediately. If ``False``, the
plots are generated and returned without being displayed.
figsize
Specifies the size of the figure in inches as ``(width, height)``.
Returns
-------
dict of {str: matplotlib.Figure}
A dictionary mapping each run's name to its corresponding ``matplotlib.Figure`` object.
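Examples
--------
A minimal sketch, assuming preprocessing has already been run for this
session (the output filename is illustrative):
>>> figs = session.plot_preprocessed(
...     run_idx="all", mode="map", time_range=(0.0, 0.5), show=False
... )
>>> for run_name, fig in figs.items():
...     fig.savefig(f"{run_name}_traces.png")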
"""
time_range = np.array(time_range, dtype=np.float64)
all_runs = self._pp_runs if run_idx == "all" else [self._pp_runs[run_idx]]
all_figs = {}
for run in all_runs:
fig = run.plot_preprocessed(
mode=mode,
time_range=time_range,
show_channel_ids=show_channel_ids,
show=show,
figsize=figsize,
)
all_figs[run._run_name] = fig
return all_figs
def sort(
self,
configs: dict | str | Path,
run_sorter_method: str = "local",
per_shank: bool = False,
concat_runs: bool = False,
overwrite: bool = True,
slurm: dict | bool = False,
):
"""
Sort the preprocessed recordings (self.get_preprocessed_run_names())
and save the output to disk.
Parameters
----------
configs
- If a ``str`` is provided, expects the name of a stored configuration file.
See ``show_available_configs()`` and ``save_config_dict()`` for details.
- If a ``Path`` is provided, expects the path to a valid spikewrap config YAML file.
- A spikewrap configs dictionary, either including the ``"sorting"`` level
or the ``"sorting"`` level itself. Only 1 sorter at a time currently supported.
run_sorter_method
- If "local", use the current Python environment to run sorting (e.g. for `kilosort4`).
- If a path (either string or Path), must be the path to a MATLAB repository
used to run the sorter (e.g. a Kilosort 2.5 repository).
- If "docker" or "singularity", Docker or Singularity will be used to run the sorting.
Singularity images are stored in a folder at the level of "rawdata" and "derivatives".
per_shank
If `True`, preprocessed recordings will be split by shank. If preprocessed recordings
were already split by shank for preprocessing, this should be `False`.
concat_runs
If `True`, preprocessed runs will be concatenated together before sorting. If runs were
already concatenated before preprocessing, this should be `False`.
overwrite
If `True`, existing outputs will be overwritten. Otherwise, an error will be raised
if existing outputs are found.
slurm
Configuration for submitting the save jobs to a SLURM workload manager.
If `False` (default), jobs will be run locally. If `True`, jobs will be run on SLURM
with default arguments. If a `dict` is provided, it should contain SLURM arguments.
See `tutorials` in the documentation for details.
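Examples
--------
A minimal sketch; the sorting configs dictionary below is illustrative only,
see ``show_available_configs()`` and the documentation for supported sorters:
>>> session.sort(
...     configs={"sorting": {"kilosort4": {}}},
...     run_sorter_method="local",
...     concat_runs=True,
... )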
"""
pp_runs: list[PreprocessedRun]
if not any(self._pp_runs):
pp_runs = self._load_pp_runs_from_disk()
else:
pp_runs = self._pp_runs
sorting_configs = self._infer_steps_from_configs_argument(configs, "sorting")
self._sorting_runs = []
if concat_runs:
if len(pp_runs) == 1:
raise ValueError(
f"`concat_runs=True` but there is only one preprocessed run: {pp_runs[0]._run_name}"
)
else:
self._sorting_runs = [ConcatSortingRun(pp_runs, self._output_path)]
else:
self._sorting_runs = [
SeparateSortingRun(pp_run, self._output_path) for pp_run in pp_runs
]
for run in self._sorting_runs:
job_if_slurm = run.sort(
sorting_configs, run_sorter_method, per_shank, overwrite, slurm
)
if slurm:
self._running_slurm_jobs.append(job_if_slurm)
def _load_pp_runs_from_disk(self) -> list[PreprocessedRun]:
"""
Load a list of previously saved PreprocessedRun objects from disk.
This is used to allow sorting from previously saved data, so that `session.preprocess()`
can be skipped. Runs are loaded in the order of the run names passed to
this class on instantiation. If this is "all", they are loaded from disk in glob order.
TODO
----
It would be nice to load these into self._pp_runs so that they can be re-visualised.
TODO
----
This badly needs a tidy up
"""
if self._passed_run_names == "all":
# Find all folders with a "preprocessed" folder inside,
# assuming these are already-preprocessed runs.
run_folders = list(self._output_path.glob("*"))
passed_run_names = []
for run_path in run_folders:
if run_path.is_dir() and any(run_path.glob("preprocessed")):
passed_run_names.append(run_path.name)
else:
passed_run_names = self._passed_run_names
if "concat_run" in passed_run_names and len(passed_run_names) != 1:
raise ValueError(
"Cannot load `concat_run` alongside separate runs. Specify "
"the exact runs to sort with the `run_names` argument of `Session`."
)
_utils.message_user(
"Preprocessed runs were loaded from disk, in the order: "
f"{passed_run_names}"
)
pp_runs = []
for run_name in passed_run_names:
# For each run, load the stored spikewrap info.
run_folder = self._output_path / run_name
file_path = run_folder / "preprocessed" / canon.spikewrap_info_filename()
if not file_path.is_file():
raise FileNotFoundError(
f"No saved preprocessed data found at: {file_path}"
)
with file_path.open("r") as file:
run_info = yaml.safe_load(file)
# Based on the number of saved shanks, load the recording objects
# into a dictionary (keys are shank_id)
preprocessed_shanks = {}
for shank_id in run_info["shank_ids"]:
if shank_id == "grouped":
recording_path = run_folder / "preprocessed"
else:
recording_path = run_folder / "preprocessed" / shank_id
if not recording_path.is_dir():
raise FileNotFoundError(
f"No preprocessed data found at: {recording_path}"
)
recording = si.load_extractor(recording_path)
prepro_dict = {
run_info["prepro_key"]: recording
} # TODO: check naming of this and make sure it's consistent
preprocessed_shanks[shank_id] = prepro_dict
# Instantiate the class based on the above parameters
run = PreprocessedRun(
Path(run_info["raw_data_path"]),
run_info["ses_name"],
run_info["run_name"],
run_info["file_format"],
Path(run_info["session_output_path"]),
preprocessed_shanks,
pp_steps=run_info["pp_steps"],
orig_run_names=run_info["orig_run_names"],
)
pp_runs.append(run)
return pp_runs
def wait_for_slurm(self):
"""
Block until all submitted SLURM jobs for this session are complete,
polling every 5 seconds.
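Examples
--------
A minimal sketch of submitting jobs to SLURM and blocking until they finish
(assumes a SLURM cluster is available):
>>> session.save_preprocessed(slurm=True)
>>> session.wait_for_slurm()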
"""
while True:
self._running_slurm_jobs = [
job for job in self._running_slurm_jobs if not job.done()
]
if not any(self._running_slurm_jobs):
break
time.sleep(5)
# Getters -----------------------------------------------------------------
def get_raw_run_names(self) -> list[str]:
"""
Return a list of run names for the raw data.
Their order is the order in which any concatenation
will be performed.
"""
return [run._run_name for run in self._raw_runs]
def get_preprocessed_run_names(self) -> list[str]:
"""
Return a list of names of the preprocessed data.
If data was concatenated, the run name will be "concat_run".
If not concatenated, their order is the order in which
concatenation will take place prior to sorting
(if ``concat_runs=True``).
"""
return [run._run_name for run in self._pp_runs]
def get_output_path(self) -> Path:
"""
The path where processed data will be output for this session.
"""
return self._output_path
def load_raw_data(self, overwrite: bool = False) -> None:
"""
Load the raw data, to allow editing of the sync channel.
Note this will overwrite any previous edits to the sync channel.
Parameters
----------
overwrite
If ``True``, reload even if raw runs have already been loaded,
discarding any previous sync-channel edits. Otherwise, an error
is raised if runs are already loaded.
"""
if any(self._raw_runs) and not overwrite:
raise RuntimeError(
"Runs have already been loaded for this session. "
"Use `overwrite=True` to load again."
)
self._load_raw_data(internal_overwrite=overwrite)
def get_sync_channel(self, run_idx: int):
"""
Return the sync channel as a numpy array. Currently only
Neuropixels is supported (the sync channel is the 385th channel).
Parameters
----------
run_idx
Index of the run to get the sync channel from,
as ordered by ``self.get_raw_run_names()``.
"""
self._assert_sync_channel_checks()
return self._raw_runs[run_idx].get_sync_channel()
def plot_sync_channel(
self, run_idx: int, show: bool = True
) -> list[matplotlib.lines.Line2D]:
"""
Plot the sync channel for the run.
Parameters
----------
run_idx
Index of the run to get the sync channel from,
as ordered by ``self.get_raw_run_names()``.
show
If ``True``, plt.show() is called.
"""
self._assert_sync_channel_checks()
return self._raw_runs[run_idx].plot_sync_channel(show)
def silence_sync_channel(
self, run_idx: int, periods_to_silence: list[tuple]
) -> None:
"""
Set periods on the sync channel to zero.
Parameters
----------
run_idx
Index of the run to get the sync channel from,
as ordered by ``self.get_raw_run_names()``.
periods_to_silence
A list of 2-tuples, where each tuple gives the (start, stop)
samples to silence. For example,
[(0, 10), (50, 500)] will set samples 0 - 10 and
50 - 500 to zero.
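Examples
--------
A minimal sketch of the sync-channel workflow, which must happen before
preprocessing (the sample indices are illustrative):
>>> session.load_raw_data()
>>> sync = session.get_sync_channel(run_idx=0)
>>> session.silence_sync_channel(run_idx=0, periods_to_silence=[(0, 250), (1000, 1500)])
>>> session.save_sync_channel(overwrite=True)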
"""
self._assert_sync_channel_checks()
self._raw_runs[run_idx].silence_sync_channel(periods_to_silence)
def save_sync_channel(
self, overwrite: bool = False, slurm: dict | bool = False
) -> None:
"""
Save the sync channel of all loaded runs to disk.
Parameters
----------
overwrite
If `True`, any previously saved sync channel will be overwritten.
slurm
If `False` (default), jobs are run locally. If `True` or a `dict`,
the save jobs are submitted to SLURM (see `save_preprocessed`).
"""
if not self.raw_runs_loaded():
self._load_raw_data(internal_overwrite=False)
for run in self._raw_runs:
job_if_slurm = run.save_sync_channel(overwrite, slurm)
if slurm:
self._running_slurm_jobs.append(job_if_slurm)
def raw_runs_loaded(self) -> bool:
"""
Return ``True`` if raw data has been loaded for all runs in this session.
"""
return all(run._raw is not None for run in self._raw_runs)
def _assert_sync_channel_checks(self):
"""
Ensure sync-channel operations are valid: raw data must be loaded,
and preprocessing or sorting must not have been run yet.
"""
if not any(self._raw_runs):
raise RuntimeError(
"Cannot work with sync channels until raw data loaded. "
"Use `session.load_raw_data()`."
)
if any(self._pp_runs) or any(self._sorting_runs):
raise RuntimeError(
"Cannot work with the sync channel after preprocessing. "
"Sync channel must be handled prior to preprocessing. "
"Instantiate a new Session object to begin a new workflow. "
"If this is annoying please contact spikewrap."
)
# ---------------------------------------------------------------------------
# Private Functions
# ---------------------------------------------------------------------------
# Manage `Runs` -------------------------------------------------------------
def _load_raw_data(self, internal_overwrite: bool = False) -> None:
"""
Fill self._raw_runs with a list of `SeparateRawRun` objects,
each holding data for the run.
Parameters
----------
internal_overwrite
Safety flag to ensure overwriting existing runs is intended.
Notes
-----
The session may include an `ephys` folder (NeuroBlueprint), but non-NeuroBlueprint
layouts (no `ephys` folder) are also currently supported. `session_path` is the
path to the `ses-xxx` folder; `get_raw_run_paths` will detect and
include the `ephys` folder if necessary.
"""
if self._raw_runs and not internal_overwrite:
raise RuntimeError(f"Cannot overwrite _raw_runs for session {self._ses_name}")
session_path = self._parent_input_path / self._ses_name
run_paths = _loading.get_raw_run_paths(
self._file_format,
session_path,
self._passed_run_names,
)
runs: list[SeparateRawRun] = []
for run_path in run_paths:
separate_run = SeparateRawRun(
parent_input_path=run_path.parent,
parent_ses_name=self._ses_name,
run_name=run_path.name,
file_format=self._file_format,
probe=self._probe,
sync_output_path=self._output_path / run_path.name,
)
separate_run.load_raw_data()
runs.append(separate_run)
self._raw_runs = runs
def _get_concat_raw_run(self) -> ConcatRawRun:
"""
Concatenate multiple separate runs into a single consolidated `ConcatRawRun`.
`SeparateRawRun` and `ConcatRawRun` both expose processing functionality and
can be substituted for one another in this context.
"""
if len(self._raw_runs) == 1:
raise RuntimeError("Cannot concatenate runs, only one run found.")
assert all(
[isinstance(run, SeparateRawRun) for run in self._raw_runs]
), "All runs must be type `SeparateRawRun` for `ConcatRawRun`"
_utils.message_user(
"Concatenating raw recordings in the following order: "
f"{self.get_raw_run_names()}"
)
return ConcatRawRun(
self._raw_runs,
self._parent_input_path,
self._ses_name,
self._file_format,
)
# Path Resolving -----------------------------------------------------------
def _output_from_parent_input_path(self) -> Path:
"""
Infer the output path for the processed session data given the input path.
Assumes NeuroBlueprint style, in which the output path root is
a `derivatives` folder in the same location as the `rawdata` folder.
If this is not a NeuroBlueprint-organised dataset, a ValueError is raised
and `output_path` must be passed explicitly.
Returns
-------
Path
The inferred NeuroBlueprint session output path.
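Examples
--------
Illustrative mapping (paths are placeholders):
``/data/rawdata/sub-001`` (subject) with session ``ses-001`` maps to
``/data/derivatives/sub-001/ses-001/ephys``.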
"""
sub_path, sub_name = self._resolve_subject_input_path()
rawdata_path = sub_path.parent
if rawdata_path.name != "rawdata":
raise ValueError(
f"Cannot infer `output_path` from non-NeuroBlueprint "
f"folder structure (expected 'rawdata'->subject->session "
f"in path {self._parent_input_path}). "
f"Pass the session output folder explicitly as `output_path`."
)
return rawdata_path.parent / "derivatives" / sub_name / self._ses_name / "ephys"
# Checkers ------------------------------------------------------------------
def _resolve_subject_input_path(self) -> tuple[Path, str]:
"""
Return the path to the subject folder, and the subject name.
Assumes the subject path is immediately above the session path.
This is true whether in NeuroBlueprint or other accepted
folder formats found in the documentation.
"""
sub_path = self._parent_input_path
sub_name = sub_path.name
return sub_path, sub_name
@staticmethod
def _check_input_path(parent_input_path: Path) -> None:
"""
Ensure the path to the parent (subject) folder exists.
"""
if not parent_input_path.is_dir():
raise FileNotFoundError(f"{parent_input_path} is not a directory.")
@staticmethod
def _check_file_format(file_format: Literal["spikeglx", "openephys"]) -> None:
"""
Ensure the `file_format` is a permitted value.
"""
supported_formats = ["spikeglx", "openephys"]
if file_format not in supported_formats:
raise ValueError(
f"`file_format` not recognised. Must be one of: {supported_formats}"
)
@staticmethod
def _infer_steps_from_configs_argument(
configs, preprocessing_or_sorting
) -> dict[str, list]: # TODO: RENAME !
"""
Given the possible arguments for `configs` in `preprocess()` or `sort()`,
infer the dictionary of preprocessing or sorting settings to use.
TODO
----
This function should be moved when it has use outside this class e.g.
when Subject class is introduced.
"""
if not isinstance(configs, dict):
if isinstance(configs, Path) or "/" in configs or "\\" in configs:
settings = config_utils.load_config_dict(configs)[
preprocessing_or_sorting
]
else:
pp_steps, sorting = config_utils.get_configs(configs)
settings = (
pp_steps if preprocessing_or_sorting == "preprocessing" else sorting
)
else:
# maybe the user did not include the "preprocessing" or "sorting" top level
settings = configs.get(preprocessing_or_sorting, configs)
return settings