How to use the spikeextractors.extraction_tools.check_get_traces_args function in spikeextractors

To help you get started, we’ve selected a few spikeextractors examples based on popular ways this function is used in public projects.

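Before diving into the project snippets below, here is a minimal sketch of how the decorator is typically applied in a custom RecordingExtractor. MyRecordingExtractor and its _traces buffer are hypothetical (not taken from the projects listed), and the exact defaults and checks performed by check_get_traces_args may vary between spikeextractors versions.

import numpy as np
from spikeextractors import RecordingExtractor
from spikeextractors.extraction_tools import check_get_traces_args


class MyRecordingExtractor(RecordingExtractor):
    """Hypothetical in-memory extractor, used only to illustrate the decorator."""

    def __init__(self, traces, sampling_frequency):
        RecordingExtractor.__init__(self)
        self._traces = np.asarray(traces)  # shape (num_channels, num_frames)
        self._sampling_frequency = sampling_frequency

    def get_channel_ids(self):
        return list(range(self._traces.shape[0]))

    def get_num_frames(self):
        return self._traces.shape[1]

    def get_sampling_frequency(self):
        return self._sampling_frequency

    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        # By the time this body runs, the decorator has replaced None arguments with
        # defaults (all channel ids, frame 0, the last frame) and validated the bounds.
        channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids])
        return self._traces[channel_idxs, start_frame:end_frame]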

github SpikeInterface / spikeextractors / spikeextractors / multirecordingtimeextractor.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        recording1, i_sec1, i_start_frame = self._find_section_for_frame(start_frame)
        _, i_sec2, i_end_frame = self._find_section_for_frame(end_frame)
        if i_sec1 == i_sec2:
            return recording1.get_traces(channel_ids=channel_ids, start_frame=i_start_frame, end_frame=i_end_frame)
        traces = []
        traces.append(
            self._recordings[i_sec1].get_traces(channel_ids=channel_ids, start_frame=i_start_frame,
                                         end_frame=self._recordings[i_sec1].get_num_frames())
        )
        for i_sec in range(i_sec1 + 1, i_sec2):
            traces.append(
                self._recordings[i_sec].get_traces(channel_ids=channel_ids)
            )
        traces.append(
            self._recordings[i_sec2].get_traces(channel_ids=channel_ids, start_frame=0, end_frame=i_end_frame)
        )
        return np.concatenate(traces, axis=1)

github SpikeInterface / spiketoolkit / spiketoolkit / preprocessing / resample.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        start_frame_not_sampled = int(start_frame / self.get_sampling_frequency() *
                                      self._recording.get_sampling_frequency())
        start_frame_sampled = start_frame
        end_frame_not_sampled = int(end_frame / self.get_sampling_frequency() *
                                    self._recording.get_sampling_frequency())
        end_frame_sampled = end_frame
        traces = self._recording.get_traces(start_frame=start_frame_not_sampled,
                                            end_frame=end_frame_not_sampled,
                                            channel_ids=channel_ids)
        if np.mod(self._recording.get_sampling_frequency(), self._resample_rate) == 0:
            traces_resampled = signal.decimate(traces,
                                               q=int(self._recording.get_sampling_frequency() / self._resample_rate),
                                               axis=1)
        else:
            traces_resampled = signal.resample(traces, int(end_frame_sampled - start_frame_sampled), axis=1)
        return traces_resampled

github SpikeInterface / spiketoolkit / spiketoolkit / preprocessing / normalize_by_quantile.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        traces = self._recording.get_traces(channel_ids=channel_ids,
                                            start_frame=start_frame,
                                            end_frame=end_frame)
        return traces * self._scalar + self._offset

github SpikeInterface / spikeextractors / spikeextractors / extractors / neoextractors / neobaseextractor.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        # in neo rawio, channels can be accessed by names/ids/indexes;
        # there is no guarantee that ids/names are unique in some formats
        raw_traces = self.neo_reader.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index,
                                                            i_start=start_frame, i_stop=end_frame,
                                                            channel_indexes=None, channel_names=None,
                                                            channel_ids=channel_ids)

        # rescale traces to natural units (can be anything)
        scaled_traces = self.neo_reader.rescale_signal_raw_to_float(raw_traces, dtype='float32',
                                                                    channel_indexes=None, channel_names=None,
                                                                    channel_ids=channel_ids)
        channel_idxs = np.array([list(channel_ids).index(ch) for ch in channel_ids])
        # and then to uV
        scaled_traces *= self.additional_gain[:, channel_idxs]
        return scaled_traces

github SpikeInterface / spikeextractors / spikeextractors / extractors / biocamrecordingextractor / biocamrecordingextractor.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        data = self._read_function(self._rf, start_frame, end_frame, self.get_num_channels())
        # transform to slice if possible
        if sorted(channel_ids) == channel_ids and np.all(np.diff(channel_ids) == 1):
            channel_ids = slice(channel_ids[0], channel_ids[0]+len(channel_ids))
        return data[:, channel_ids].T

github SpikeInterface / spikeextractors / spikeextractors / extractors / mea1krecordingextractor / mea1krecordingextractor.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        if np.array(channel_ids).size > 1:
            assert np.all([ch in self.get_channel_ids() for ch in channel_ids])
            if np.any(np.diff(channel_ids) < 0):
                sorted_idx = np.argsort(channel_ids)
                recordings = self._signals[np.sort(channel_ids), start_frame:end_frame]
                return recordings[sorted_idx].astype('float')
            else:
                return self._signals[np.array(channel_ids), start_frame:end_frame].astype('float')
        else:
            assert channel_ids in self.get_channel_ids()
            return self._signals[np.array(channel_ids), start_frame:end_frame].astype('float')

github SpikeInterface / spikeextractors / spikeextractors / extractors / openephysextractors / openephysextractors.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        if self._dtype == 'int16':
            return self._recording.analog_signals[0].signal[channel_ids, start_frame:end_frame]
        elif self._dtype == 'float':
            return self._recording.analog_signals[0].signal[channel_ids, start_frame:end_frame] * \
                   self._recording.analog_signals[0].gain

github SpikeInterface / spikeextractors / spikeextractors / extractors / mcsh5recordingextractor / mcsh5recordingextractor.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)
        if start_frame is None:
            start_frame = 0
        if end_frame is None:
            end_frame = self.get_num_frames()

        channel_idxs = []
        for m in channel_ids:
            assert m in self._channel_ids, 'channel_id {} not found'.format(m)
            channel_idxs.append(np.where(np.array(self._channel_ids) == m)[0][0])

        stream = self._rf.require_group('/Data/Recording_0/AnalogStream/Stream_' + str(self._stream_id))
        conv = self._convFact.astype(float) * (10.0 ** self._exponent)

        if np.array(channel_idxs).size > 1:

github SpikeInterface / spiketoolkit / spiketoolkit / preprocessing / common_reference.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids])
        if self._ref == 'median':
            if self._groups is None:
                if self.verbose:
                    print('Common median reference using all channels')
                traces = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame)
                traces = traces - np.median(traces, axis=0, keepdims=True)
                return traces[channel_idxs].astype(self._dtype)
            else:
                new_groups = []
                for g in self._groups:
                    new_chans = []
                    for chan in g:
                        if chan in self._recording.get_channel_ids():
                            new_chans.append(chan)

github SpikeInterface / spikeextractors / spikeextractors / extractors / nwbextractors / nwbextractors.py View on Github
    @check_get_traces_args
    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None):
        with NWBHDF5IO(self._path, 'r') as io:
            nwbfile = io.read()
            es = nwbfile.acquisition[self._electrical_series_name]
            es_channel_ids = np.array(es.electrodes.table.id[:])[es.electrodes.data[:]].tolist()
            table_ids = [es_channel_ids.index(id) for id in channel_ids]
            if np.array(channel_ids).size > 1 and np.any(np.diff(channel_ids) < 0):
                sorted_idx = np.argsort(table_ids)
                recordings = es.data[start_frame:end_frame, np.sort(table_ids)].T
                traces = recordings[sorted_idx, :]
            else:
                traces = es.data[start_frame:end_frame, table_ids].T
            # This DatasetView and lazy operations will only work within context
            # We're keeping the non-lazy version for now
            # es_view = DatasetView(es.data)  # es is an instantiated h5py dataset
            # traces = es_view.lazy_slice[start_frame:end_frame, channel_ids].lazy_transpose()
            return traces
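
Across all of these examples, the decorator is what lets callers omit any combination of the arguments. A short usage sketch, assuming spikeextractors' NumpyRecordingExtractor (whose get_traces is also wrapped by check_get_traces_args in recent releases); keyword names may differ slightly between versions:

import numpy as np
import spikeextractors as se

traces = np.random.randn(4, 30000)  # 4 channels, 1 s at 30 kHz
recording = se.NumpyRecordingExtractor(timeseries=traces, sampling_frequency=30000)

full = recording.get_traces()                                  # decorator fills in all defaults
window = recording.get_traces(start_frame=0, end_frame=1000)   # first 1000 frames, all channels
subset = recording.get_traces(channel_ids=[0, 1])              # two channels, all frames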