How to use the spikeextractors.load_extractor_from_dict function in spikeextractors

To help you get started, we've selected a few spikeextractors examples based on popular ways the function is used in public projects.
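All of the snippets below follow the same pattern: an extractor is serialized to a plain dictionary and later rebuilt with se.load_extractor_from_dict. A minimal round-trip might look like the following sketch. It assumes a dumpable, file-backed extractor that exposes a dump_to_dict method (available on dumpable extractors in recent spikeextractors versions); the file name and recording parameters are placeholders, not values taken from the examples.

import spikeextractors as se

# Placeholder binary recording on disk; any dumpable (file-backed) extractor works the same way.
recording = se.BinDatRecordingExtractor('recording.dat', sampling_frequency=30000,
                                        numchan=4, dtype='int16')

# Serialize the extractor to a plain dict (class name plus the kwargs needed to rebuild it) ...
rec_dict = recording.dump_to_dict()  # assumes the extractor is dumpable

# ... and reconstruct an equivalent extractor from that dict, e.g. inside a worker process.
recording2 = se.load_extractor_from_dict(rec_dict)
assert recording2.get_num_channels() == recording.get_num_channels()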

github SpikeInterface / spiketoolkit / spiketoolkit / postprocessing / postprocessing_tools.py
def _extract_amplitudes_one_unit(unit, rec_arg, sort_arg, channel_ids, max_spikes_per_unit, frames_before, frames_after,
                                 peak, method, seed, memmap_array=None):
    if isinstance(rec_arg, dict):
        recording = se.load_extractor_from_dict(rec_arg)
    else:
        recording = rec_arg
    if isinstance(sort_arg, dict):
        sorting = se.load_extractor_from_dict(sort_arg)
    else:
        sorting = sort_arg

    spike_train = sorting.get_unit_spike_train(unit)
    if max_spikes_per_unit < len(spike_train):
        indexes = np.sort(np.random.RandomState(seed=seed).permutation(len(spike_train))[:max_spikes_per_unit])
    else:
        indexes = np.arange(len(spike_train))
    spike_train = spike_train[indexes]

    snippets = recording.get_snippets(reference_frames=spike_train,
                                      snippet_len=[frames_before, frames_after], channel_ids=channel_ids)
    if peak == 'both':
        amps = np.max(np.abs(snippets), axis=-1)
        if len(amps.shape) > 1:
            amps = np.max(amps, axis=-1)
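The worker above accepts either an extractor object or its dumped dict because extractor objects often cannot be pickled and sent to subprocesses, while the dicts produced by dump_to_dict can. The sketch below shows that calling pattern with a simplified worker; the names _firing_rate_one_unit and firing_rates_parallel are hypothetical helpers for illustration, not part of spiketoolkit, and a dumpable recording/sorting pair is assumed.

import multiprocessing as mp
import spikeextractors as se

def _firing_rate_one_unit(unit, rec_arg, sort_arg):
    # Rebuild the extractors inside the subprocess, mirroring the snippet above.
    if isinstance(rec_arg, dict):
        recording = se.load_extractor_from_dict(rec_arg)
    else:
        recording = rec_arg
    if isinstance(sort_arg, dict):
        sorting = se.load_extractor_from_dict(sort_arg)
    else:
        sorting = sort_arg
    duration = recording.get_num_frames() / recording.get_sampling_frequency()
    return unit, len(sorting.get_unit_spike_train(unit_id=unit)) / duration

def firing_rates_parallel(recording, sorting, n_jobs=4):
    # Dumped dicts are picklable even when the extractor objects themselves are not,
    # so the dicts are what gets shipped to the worker processes.
    rec_dict = recording.dump_to_dict()   # assumes dumpable (e.g. file-backed) extractors
    sort_dict = sorting.dump_to_dict()
    args = [(u, rec_dict, sort_dict) for u in sorting.get_unit_ids()]
    with mp.Pool(n_jobs) as pool:
        return dict(pool.starmap(_firing_rate_one_unit, args))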

github SpikeInterface / spiketoolkit / spiketoolkit / postprocessing / postprocessing_tools.py
def _extract_waveforms_one_unit(unit, rec_arg, sort_arg, channel_ids, unit_ids, grouping_property,
                                compute_property_from_recording, max_channels_per_waveforms, max_spikes_per_unit,
                                n_pad, dtype, seed, verbose, memmap_array=None):
    if isinstance(rec_arg, dict):
        recording = se.load_extractor_from_dict(rec_arg)
    else:
        recording = rec_arg
    if isinstance(sort_arg, dict):
        sorting = se.load_extractor_from_dict(sort_arg)
    else:
        sorting = sort_arg

    if grouping_property is not None:
        if grouping_property not in recording.get_shared_channel_property_names():
            raise ValueError("'grouping_property' should be a property of recording extractors")
        if compute_property_from_recording:
            compute_sorting_group = True
        elif grouping_property not in sorting.get_shared_unit_property_names():
            warnings.warn('Grouping property not in sorting extractor. Computing it from the recording extractor')
            compute_sorting_group = True
        else:
            compute_sorting_group = False

        if not compute_sorting_group:
            rec_list, rec_props = recording.get_sub_extractors_by_property(grouping_property,
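When grouping_property is set (typically 'group' for shanks or tetrodes), the recording is split into one sub-extractor per group before the per-unit work is dispatched. A small illustration of that split is sketched below; it assumes the channels carry a 'group' property and that get_sub_extractors_by_property accepts a return_property_list flag, as the truncated call above suggests.

def split_by_group(recording):
    # One sub-recording per distinct value of the 'group' channel property;
    # return_property_list=True also returns the group value of each sub-recording.
    sub_recordings, groups = recording.get_sub_extractors_by_property('group',
                                                                      return_property_list=True)
    for sub_rec, group in zip(sub_recordings, groups):
        print(f"group {group}: channels {sub_rec.get_channel_ids()}")
    return sub_recordings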

github SpikeInterface / spiketoolkit / spiketoolkit / sortingcomponents / detection.py
def _detect_and_align_peaks_single_channel(rec_arg, channel, n_std, detect_sign, n_pad, upsample, min_diff_samples,
                                           align, verbose):
    if verbose:
        print(f'Detecting spikes on channel {channel}')
    if isinstance(rec_arg, dict):
        recording = se.load_extractor_from_dict(rec_arg)
    else:
        recording = rec_arg
    trace = np.squeeze(recording.get_traces(channel_ids=channel))
    if detect_sign == -1:
        thresh = -n_std * np.median(np.abs(trace) / 0.6745)
        idx_spikes = np.where(trace < thresh)[0]
    elif detect_sign == 1:
        thresh = n_std * np.median(np.abs(trace) / 0.6745)
        idx_spikes = np.where(trace > thresh)[0]
    else:
        thresh = n_std * np.median(np.abs(trace) / 0.6745)
        idx_spikes = np.where((trace > thresh) | (trace < -thresh))[0]
    intervals = np.diff(idx_spikes)
    sp_times = []

    for i_t, diff in enumerate(intervals):

github SpikeInterface / spiketoolkit / spiketoolkit / postprocessing / postprocessing_tools.py
def _extract_activity_one_channel(rec_arg, ch, detect_sign, detect_threshold, verbose):
    if isinstance(rec_arg, dict):
        recording = se.load_extractor_from_dict(rec_arg)
    else:
        recording = rec_arg
    if verbose:
        print(f'Detecting spikes on channel {ch}')
    trace = np.squeeze(recording.get_traces(channel_ids=ch))
    if detect_sign == -1:
        thresh = -detect_threshold * np.median(np.abs(trace) / 0.6745)
        idx_spikes = np.where(trace < thresh)
    elif detect_sign == 1:
        thresh = detect_threshold * np.median(np.abs(trace) / 0.6745)
        idx_spikes = np.where(trace > thresh)
    else:
        thresh = detect_threshold * np.median(np.abs(trace) / 0.6745)
        idx_spikes = np.where((trace > thresh) | (trace < -thresh))
    if len(idx_spikes) > 0:
        activity = len(idx_spikes[0])
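Both detection snippets derive their threshold from a robust noise estimate: the median absolute deviation of the trace divided by 0.6745 approximates the standard deviation under Gaussian noise, and this estimate is then scaled by n_std or detect_threshold. A standalone illustration of that computation is shown below; the synthetic trace is only for demonstration.

import numpy as np

rng = np.random.RandomState(0)
trace = rng.normal(0, 10, size=30000)          # synthetic zero-mean noise trace, std = 10

noise_std = np.median(np.abs(trace) / 0.6745)  # robust MAD-based std estimate, as in the snippets
thresh = 5 * noise_std                         # e.g. detect_threshold (or n_std) of 5
idx_spikes = np.where(np.abs(trace) > thresh)[0]
print(f'estimated noise std: {noise_std:.2f}, threshold: {thresh:.2f}, crossings: {len(idx_spikes)}')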