How to use the spikeextractors.SortingExtractor class in spikeextractors

To help you get started, we’ve selected a few spikeextractors examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github SpikeInterface / spiketoolkit / spiketoolkit / tools.py View on Github external
def exportToPhy(recording, sorting, output_folder, nPCchan=3, nPC=5, filter=False, electrode_dimensions=None,
                max_num_waveforms=np.inf):
    """Export a recording and its sorting to a folder in the Phy file layout.

    The part of the function shown here validates the inputs, prepares the
    output folder, optionally band-pass filters the recording, dumps the
    traces to ``recording.dat`` as int16 and starts writing ``params.py``.

    Parameters
    ----------
    recording : se.RecordingExtractor
        Recording whose traces are written to ``recording.dat``.
    sorting : se.SortingExtractor
        Sorting result associated with ``recording``.
    output_folder : str
        Destination folder; converted to an absolute path and created if
        missing.
    nPCchan, nPC, electrode_dimensions, max_num_waveforms
        Not used in the portion of the function shown here; presumably
        consumed further down -- TODO confirm.
    filter : bool
        If true, the recording is band-pass filtered (300-6000) before being
        written. (The name shadows the builtin but is part of the public
        signature, so it is kept.)

    Raises
    ------
    AttributeError
        If ``recording`` is not a RecordingExtractor or ``sorting`` is not a
        SortingExtractor.
    """
    # Validate first, so a wrong argument type fails with a clear message
    # rather than with whatever error Analyzer would raise on it.
    if not (isinstance(recording, se.RecordingExtractor) and isinstance(sorting, se.SortingExtractor)):
        raise AttributeError("'recording' must be a RecordingExtractor and "
                             "'sorting' must be a SortingExtractor")

    # Built from the recording as passed in, i.e. before optional filtering.
    analyzer = Analyzer(recording, sorting)

    output_folder = os.path.abspath(output_folder)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(output_folder, exist_ok=True)

    if filter:
        recording = bandpass_filter(recording, freq_min=300, freq_max=6000)

    # save dat file
    dat_path = join(output_folder, 'recording.dat')
    se.writeBinaryDatFormat(recording, dat_path, dtype='int16')

    # write params.py
    with open(join(output_folder, 'params.py'), 'w') as f:
        # repr() emits a valid Python string literal even when the path holds
        # backslashes (Windows) or quotes; for ordinary paths the output is
        # identical to the previous hand-quoted form.
        f.write("dat_path =" + repr(dat_path) + '\n')
        f.write('n_channels_dat = ' + str(recording.getNumChannels()) + '\n')
        f.write("dtype = 'int16'\n")
github SpikeInterface / spikeextractors / spikeextractors / extractors / mdaextractors / mdaextractors.py View on Github external
dtype = 'int16'

        with save_file_path.open('wb') as f:
            header = MdaHeader(dt0=dtype, dims0=(num_chan, num_frames))
            header.write(f)
            # takes care of the chunking
            write_to_binary_dat_format(recording, file_handle=f, dtype=dtype, chunk_size=chunk_size,
                                       chunk_mb=chunk_mb)

        params["samplerate"] = recording.get_sampling_frequency()
        with (parent_dir / params_fname).open('w') as f:
            json.dump(params, f)
        np.savetxt(str(parent_dir / geom_fname), geom, delimiter=',')


class MdaSortingExtractor(SortingExtractor):
    """Sorting extractor that reads its spikes from an MDA firings file.

    The array returned by ``readmda`` is split row-wise: row 0 is kept as the
    max channels, row 1 as the spike times and row 2 as the unit labels.
    """
    extractor_name = 'MdaSortingExtractor'
    installed = True  # check at class level if installed or not
    is_writable = True
    mode = 'file'
    installation_mesg = ""  # error message when not installed

    def __init__(self, file_path, sampling_frequency=None):
        """Load the firings array and cache its rows on the instance.

        Parameters
        ----------
        file_path
            Path handed to ``readmda`` to load the firings array.
        sampling_frequency
            Stored as-is on the instance; defaults to None.
        """
        SortingExtractor.__init__(self)

        self._firings_path = file_path
        firings = readmda(file_path)
        self._firings = firings

        # Rows 0, 1 and 2 of the firings array, in that order.
        self._max_channels, self._times, self._labels = (
            firings[0, :],
            firings[1, :],
            firings[2, :],
        )

        # Unit ids are the distinct labels, cast to integers.
        self._unit_ids = np.unique(self._labels).astype(int)
        self._sampling_frequency = sampling_frequency
github SpikeInterface / spikeextractors / spikeextractors / extractors / matsortingextractor / matsortingextractor.py View on Github external
try:
    from scipy.io.matlab import loadmat, savemat

    HAVE_LOADMAT = True
except ImportError:
    HAVE_LOADMAT = False

HAVE_MAT = HAVE_H5PY & HAVE_LOADMAT

from spikeextractors import SortingExtractor

PathType = Union[str, Path]


class MATSortingExtractor(SortingExtractor):
    extractor_name = "MATSortingExtractor"
    installed = HAVE_MAT  # check at class level if installed or not
    is_writable = False
    mode = "file"
    installation_mesg = "To use the MATSortingExtractor install h5py and scipy: \n\n pip install h5py scipy\n\n"  # error message when not installed

    def __init__(self, file_path: PathType):
        assert HAVE_MAT, self.installation_mesg
        super().__init__()

        file_path = Path(file_path) if isinstance(file_path, str) else file_path
        if not isinstance(file_path, Path):
            raise TypeError(f"Expected a str or Path file_path but got '{type(file_path).__name__}'")

        file_path = file_path.resolve()  # get absolute path to this file
        if not file_path.is_file():
github SpikeInterface / spikeextractors / spikeextractors / extractors / npzsortingextractor / npzsortingextractor.py View on Github external
from spikeextractors import SortingExtractor
from pathlib import Path

import numpy as np


class NpzSortingExtractor(SortingExtractor):
    """
    Dead-simple, very lightweight sorting format based on the NumPy NPZ format.
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez.html#numpy.savez

    An NPZ file is in fact an archive of several .npy files.
    All spikes are stored in a two-column manner: index + labels.
    """
    extractor_name = 'NpzSortingExtractor'
    exporter_name = 'NpzSortingExporter'
    # Describes the exporter's single parameter: the destination file path.
    exporter_gui_params = [
        {'name': 'save_path', 'type': 'file', 'title': "Save path (.npz)"},
    ]
    installed = True  # depends only on numpy
    installation_mesg = "Always installed"
github SpikeInterface / spikeextractors / spikeextractors / extractors / phyextractors / phyextractors.py View on Github external
if (phy_folder / 'channel_groups.npy').is_file():
            channel_groups = np.load(phy_folder / 'channel_groups.npy')
            assert len(channel_groups) == self.get_num_channels()
            for (ch, cg) in zip(self.get_channel_ids(), channel_groups):
                self.set_channel_property(ch, 'group', cg)

        if (phy_folder / 'channel_positions.npy').is_file():
            channel_locations = np.load(phy_folder / 'channel_positions.npy')
            assert len(channel_locations) == self.get_num_channels()
            for (ch, loc) in zip(self.get_channel_ids(), channel_locations):
                self.set_channel_property(ch, 'location', loc)

        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}


class PhySortingExtractor(SortingExtractor):
    """Sorting extractor that loads its data from a Phy output folder."""

    extractor_name = 'PhySortingExtractor'
    exporter_name = 'PhySortingExporter'
    # Describes the exporter's single parameter: the destination folder.
    exporter_gui_params = [
        {'name': 'save_path', 'type': 'folder', 'title': "Save path"},
    ]
    installed = True  # check at class level if installed or not
    is_writable = True
    mode = 'folder'
    installation_mesg = ""  # error message when not installed

    def __init__(self, folder_path, exclude_cluster_groups=None, load_waveforms=False, verbose=False):
        """Initialise the extractor from the files in ``folder_path``.

        ``folder_path`` is wrapped in a ``Path`` and ``spike_times.npy`` is
        loaded from it. ``exclude_cluster_groups``, ``load_waveforms`` and
        ``verbose`` are not used in the lines shown here; presumably they are
        consumed further down -- TODO confirm.
        """
        SortingExtractor.__init__(self)
        phy_folder = Path(folder_path)

        # Loaded with np.load from the phy folder.
        spike_times = np.load(phy_folder / 'spike_times.npy')
github SpikeInterface / spikeextractors / spikeextractors / extractors / phyextractors / phyextractors.py View on Github external
def __init__(self, folder_path, exclude_cluster_groups=None, load_waveforms=False, verbose=False):
        SortingExtractor.__init__(self)
        phy_folder = Path(folder_path)

        spike_times = np.load(phy_folder / 'spike_times.npy')
        spike_templates = np.load(phy_folder / 'spike_templates.npy')

        if (phy_folder /'spike_clusters.npy').is_file():
            spike_clusters = np.load(phy_folder / 'spike_clusters.npy')
        else:
            spike_clusters = spike_templates

        if (phy_folder / 'amplitudes.npy').is_file():
            amplitudes = np.load(phy_folder / 'amplitudes.npy')
        else:
            amplitudes = np.ones(len(spike_times))

        if (phy_folder /'pc_features.npy').is_file():
github SpikeInterface / spiketoolkit / spiketoolkit / postprocessing / postprocessing_tools.py View on Github external
def _get_phy_data(recording, sorting, compute_pc_features, compute_amplitudes,
                  max_channels_per_template, **kwargs):
    """Validate the inputs and unpack the parameters for the Phy export.

    Parameters
    ----------
    recording : se.RecordingExtractor
        Recording to export.
    sorting : se.SortingExtractor
        Sorting to export; must contain at least one unit.
    compute_pc_features, compute_amplitudes, max_channels_per_template
        Not used in the portion of the function shown here; presumably
        consumed further down -- TODO confirm.
    **kwargs
        Passed to ``update_all_param_dicts_with_kwargs``; the resulting dict
        supplies the individual settings read below.

    Raises
    ------
    AttributeError
        If ``recording`` is not a RecordingExtractor or ``sorting`` is not a
        SortingExtractor.
    Exception
        If the sorting contains no units.
    """
    if not isinstance(recording, se.RecordingExtractor) or not isinstance(sorting, se.SortingExtractor):
        # Exception type kept for backward compatibility; the message is new
        # so that the failure is diagnosable.
        raise AttributeError("'recording' must be a RecordingExtractor and "
                             "'sorting' must be a SortingExtractor")
    if len(sorting.get_unit_ids()) == 0:
        raise Exception("No units in the sorting result, can't compute phy information.")

    # Every setting below is read from the dict built from kwargs.
    params_dict = update_all_param_dicts_with_kwargs(kwargs)
    n_comp = params_dict['n_comp']
    max_spikes_for_pca = params_dict['max_spikes_for_pca']
    recompute_info = params_dict['recompute_info']
    save_property_or_features = params_dict['save_property_or_features']
    verbose = params_dict['verbose']
    grouping_property = params_dict['grouping_property']
    ms_before = params_dict['ms_before']
    ms_after = params_dict['ms_after']
    dtype = params_dict['dtype']
    memmap = params_dict['memmap']
    n_jobs = params_dict['n_jobs']
github SpikeInterface / spiketoolkit / spiketoolkit / validation / metric_calculator.py View on Github external
sorting.threshold_sorting(0, "less_or_equal")

        if unit_ids is None:
            unit_ids = sorting.get_unit_ids()
        else:
            unit_ids = set(unit_ids)
            unit_ids = list(unit_ids.intersection(sorting.get_unit_ids()))

        if len(unit_ids) == 0:
            raise ValueError("No units found.")

        spike_times, spike_clusters = get_spike_times_metrics_data(
            sorting, self._sampling_frequency
        )
        assert isinstance(
            sorting, SortingExtractor
        ), "'sorting' must be  a SortingExtractor object"
        self._sorting = sorting
        self._set_unit_ids(unit_ids)
        self._set_epochs(epoch_tuples, epoch_names)
        self._spike_times = spike_times
        self._spike_clusters = spike_clusters
        self._total_units = len(unit_ids)
        self._unit_indices = _get_unit_indices(self._sorting, unit_ids)
        # To compute this data, need to call all metric data
        self._amplitudes = None
        self._pc_features = None
        self._pc_feature_ind = None
        self._spike_clusters_pca = None
        self._spike_clusters_amps = None
        self._spike_times_pca = None
        self._spike_times_amps = None