How to use the spikeextractors.SubRecordingExtractor class in spikeextractors

To help you get started, we’ve selected a few spikeextractors examples based on popular ways SubRecordingExtractor is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: SpikeInterface/spiketoolkit — spiketoolkit/preprocessing/remove_bad_channels.py (view on GitHub)
def _initialize_subrecording_extractor(self):
        """Build ``self._subrecording``: ``self._recording`` without its bad channels.

        * If ``self._bad_channel_ids`` is a list or numpy array, every channel id
          found in it is dropped.
        * If it is ``None``, bad channels are detected automatically: the standard
          deviation of each row of a trace window is compared against
          ``self._bad_threshold`` times the median of those deviations, and rows
          above that level are dropped.
        * Any other value goes to the trailing ``else`` branch, whose body is not
          part of this excerpt.
        """
        if isinstance(self._bad_channel_ids, (list, np.ndarray)):
            # Explicit ids: keep every channel that is not listed as bad, in the
            # order returned by get_channel_ids().
            active_channels = []
            for chan in self._recording.get_channel_ids():
                if chan not in self._bad_channel_ids:
                    active_channels.append(chan)
            self._subrecording = SubRecordingExtractor(self._recording, channel_ids=active_channels)
        elif self._bad_channel_ids is None:
            # Analysis window: starts at the middle of the recording and spans
            # self._seconds seconds, clipped to the total number of frames.
            start_frame = self._recording.get_num_frames() // 2
            end_frame = int(start_frame + self._seconds * self._recording.get_sampling_frequency())
            if end_frame > self._recording.get_num_frames():
                end_frame = self._recording.get_num_frames()
            traces = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame)
            # NOTE(review): assumes traces is (n_channels, n_frames), so axis=1 is time — TODO confirm.
            stds = np.std(traces, axis=1)
            # NOTE(review): enumerate() yields row positions, yet they are compared with
            # channel ids below; this only lines up when the channel ids are 0..N-1 in
            # order. Consider mapping positions through get_channel_ids(). Also,
            # np.median(stds) is re-evaluated for every row inside the comprehension.
            bad_channel_ids = [ch for ch, std in enumerate(stds) if std > self._bad_threshold * np.median(stds)]
            if self.verbose:
                print('Automatically removing channels:', bad_channel_ids)
            active_channels = []
            for chan in self._recording.get_channel_ids():
                if chan not in bad_channel_ids:
                    active_channels.append(chan)
            self._subrecording = SubRecordingExtractor(self._recording, channel_ids=active_channels)
        else:
GitHub: SpikeInterface/spikeextractors — tests/test_extractors.py (view on GitHub)
def test_multi_sub_recording_extractor(self):
        """Check sub-recordings carved out of multi-recording extractors.

        Part 1: three copies of ``self.RX`` concatenated in time as epochs
        'A', 'B', 'C'; epoch 'C' must equal ``self.RX`` and expose 4 channels.

        Part 2: ``self.RX``, ``self.RX2`` and ``self.RX3`` stacked along channels
        with groups 1, 2, 3; channels 4-7 (renamed 0-3) must equal ``self.RX2``,
        all carry group 2, and the stacked extractor exposes 12 channels.
        """
        # --- concatenation in time -------------------------------------------
        epoch_labels = ['A', 'B', 'C']
        stacked_in_time = se.MultiRecordingTimeExtractor(
            recordings=[self.RX] * 3,
            epoch_names=epoch_labels,
        )
        last_epoch = stacked_in_time.get_epoch(epoch_labels[-1])
        self._check_recordings_equal(self.RX, last_epoch)
        self.assertEqual(4, len(last_epoch.get_channel_ids()))

        # --- concatenation across channels -----------------------------------
        stacked_channels = se.MultiRecordingChannelExtractor(
            recordings=[self.RX, self.RX2, self.RX3],
            groups=[1, 2, 3],
        )
        print(stacked_channels.get_channel_groups())
        second_block = se.SubRecordingExtractor(
            stacked_channels,
            channel_ids=list(range(4, 8)),
            renamed_channel_ids=list(range(4)),
        )
        self._check_recordings_equal(self.RX2, second_block)
        self.assertEqual([2] * 4, second_block.get_channel_groups())
        self.assertEqual(12, len(stacked_channels.get_channel_ids()))
GitHub: SpikeInterface/spiketoolkit — spiketoolkit/postprocessing/utils.py (view on GitHub)
def get_max_channels_per_waveforms(recording, grouping_property, channel_ids, max_channels_per_waveforms):
    """Return how many channels to keep per waveform.

    Without a grouping property the natural limit is the number of requested
    channels; with one, it is the size of the largest channel group among
    ``channel_ids`` (groups are read from the ``grouping_property`` channel
    property). ``max_channels_per_waveforms``, when given, further caps it.

    Parameters
    ----------
    recording
        Extractor the channels belong to; only used when ``grouping_property``
        is set (presumably a RecordingExtractor — verify against callers).
    grouping_property : str or None
        Channel property used to group channels, or None for no grouping.
    channel_ids : sequence
        Channel ids under consideration.
    max_channels_per_waveforms : int or None
        Upper bound on the result; None means no extra bound.

    Returns
    -------
    int
        ``min(max_channels_per_waveforms, limit)``, or ``limit`` when the cap is
        None, where ``limit`` is ``len(channel_ids)`` (ungrouped) or the largest
        group size (grouped).
    """
    if grouping_property is None:
        limit = len(channel_ids)
    else:
        rec = se.SubRecordingExtractor(recording, channel_ids=channel_ids)
        rec_groups = np.array([rec.get_channel_property(ch, grouping_property) for ch in rec.get_channel_ids()])
        _, counts = np.unique(rec_groups, return_counts=True)
        # Cast to a plain int so both branches return the same type.
        limit = int(np.max(counts))
    if max_channels_per_waveforms is None:
        return limit
    return min(max_channels_per_waveforms, limit)
GitHub: SpikeInterface/spiketoolkit — spiketoolkit/postprocessing/postprocessing_tools.py (view on GitHub)
memmap_array[:] = wf
                                else:
                                    # some channels are missing - re-instantiate object
                                    memmap_file = memmap_array.filename
                                    del memmap_array
                                    memmap_array = np.memmap(memmap_file, mode='w+', shape=wf.shape, dtype=wf.dtype)
                                    memmap_array[:] = wf
                                waveforms = memmap_array
                            return waveforms, list(indexes), list(max_channel_idxs)
        else:
            for i, unit_id in enumerate(unit_ids):
                if unit == unit_id:
                    if channel_ids is None:
                        channel_ids = recording.get_channel_ids()

                    rec = se.SubRecordingExtractor(recording, channel_ids=channel_ids)
                    rec_groups = np.array(rec.get_channel_groups())
                    groups, count = np.unique(rec_groups, return_counts=True)
                    if max_channels_per_waveforms is None:
                        max_channels_per_waveforms = np.max(count)
                    elif max_channels_per_waveforms >= np.max(count):
                        max_channels_per_waveforms = np.max(count)

                    if max_spikes_per_unit is None:
                        max_spikes = len(sorting.get_unit_spike_train(unit_id))
                    else:
                        max_spikes = max_spikes_per_unit

                    if verbose:
                        print('Waveform ' + str(i + 1) + '/' + str(len(unit_ids)))
                    wf, indexes = _get_random_spike_waveforms(recording=recording,
                                                              sorting=sorting,