# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def make_serialized_dict(self, include_properties=None, include_features=None):
    '''
    Makes a nested serialized dictionary out of the extractor. The dictionary can be used to re-initialize an
    extractor with spikeextractors.load_extractor_from_dict(dump_dict)

    Parameters
    ----------
    include_properties: list or None
        List of properties to include in the dictionary
    include_features: list or None
        List of features to include in the dictionary

    Returns
    -------
    dump_dict: dict
        Serialized dictionary
    '''
    # NOTE(review): include_properties / include_features are accepted for API
    # compatibility with the base extractor but are not used in this override.

    # Strip the "<class '...'>" wrapper so class_name is the dotted import path
    # (e.g. "spikeextractors.extractors...BinDatRecordingExtractor").
    # The previous `.replace("", '')` was a no-op that left the wrapper in
    # place, breaking the module lookup below.
    class_name = str(BinDatRecordingExtractor).replace("<class '", '').replace("'>", '')
    module = class_name.split('.')[0]
    imported_module = importlib.import_module(module)

    if self._is_tmp:
        # A temp-backed cache file will not survive this session; warn the user.
        print("Warning: dumping a CacheRecordingExtractor. The path to the tmp binary file will be lost in "
              "further sessions. To prevent this, use the 'CacheRecordingExtractor.move_to('path-to-file)' "
              "function")

    dump_dict = {'class': class_name, 'module': module, 'kwargs': self._bindat_kwargs,
                 'key_properties': self._key_properties, 'version': imported_module.__version__, 'dumpable': True}
    return dump_dict
from spikeextractors import SortingExtractor
from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor
from spikeextractors.extraction_tools import read_python
import numpy as np
from pathlib import Path
# Optional-dependency guard: the Klusta extractors need h5py at runtime.
try:
    import h5py
except ImportError:
    HAVE_KLSX = False
else:
    HAVE_KLSX = True
# noinspection SpellCheckingInspection
class KlustaRecordingExtractor(BinDatRecordingExtractor):
    """RecordingExtractor reading a Klusta output folder (.prm config + .dat raw binary)."""
    extractor_name = 'KlustaRecordingExtractor'
    has_default_locations = False
    installed = HAVE_KLSX  # check at class level if installed or not
    is_writable = True
    mode = 'folder'
    installation_mesg = "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n"  # error message when not installed

    def __init__(self, folder_path):
        # Use the shared class-level message instead of duplicating the string;
        # keeps the assert consistent with installation_mesg above (and with
        # the other extractors in this package).
        assert HAVE_KLSX, self.installation_mesg
        klustafolder = Path(folder_path).absolute()
        # First .prm is the config file, first .dat the raw traces.
        # NOTE(review): [0] raises IndexError before the assert below can fire
        # when either file is absent, so "Not a valid klusta folder" may never
        # be shown -- confirm whether a friendlier error is wanted.
        config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
        dat_file = [f for f in klustafolder.iterdir() if f.suffix == '.dat'][0]
        assert config_file.is_file() and dat_file.is_file(), "Not a valid klusta folder"
        config = read_python(str(config_file))
        sampling_frequency = config['traces']['sample_rate']
        n_channels = config['traces']['n_channels']
from spikeextractors import SortingExtractor
from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor
from spikeextractors.extraction_tools import read_python, check_valid_unit_id
import numpy as np
from pathlib import Path
# Optional-dependency guard: h5py is required by the Klusta extractors.
try:
    import h5py
except ImportError:
    HAVE_KLSX = False
else:
    HAVE_KLSX = True
# noinspection SpellCheckingInspection
class KlustaRecordingExtractor(BinDatRecordingExtractor):
    """RecordingExtractor reading a Klusta output folder (.prm config + .dat raw binary)."""
    # NOTE(review): this class duplicates an earlier definition in this file;
    # the two copies should be reconciled.
    extractor_name = 'KlustaRecordingExtractor'
    has_default_locations = False
    installed = HAVE_KLSX  # check at class level if installed or not
    is_writable = True
    mode = 'folder'
    installation_mesg = "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n"  # error message when not installed

    def __init__(self, folder_path):
        # Fail fast with the shared install hint if h5py is unavailable.
        assert HAVE_KLSX, self.installation_mesg
        klustafolder = Path(folder_path).absolute()
        # First .prm is the config file, first .dat the raw traces.
        # NOTE(review): [0] raises IndexError before the assert below can fire
        # when either file is missing.
        config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]
        dat_file = [f for f in klustafolder.iterdir() if f.suffix == '.dat'][0]
        assert config_file.is_file() and dat_file.is_file(), "Not a valid klusta folder"
        config = read_python(str(config_file))
        sampling_frequency = config['traces']['sample_rate']
        n_channels = config['traces']['n_channels']
# NOTE(review): fragment of a Phy-style sorting-extractor __init__; its
# enclosing `def` and several referenced names (included_units, spike_clusters,
# spike_times, amplitudes, pc_features, load_waveforms, phy_folder, verbose)
# are defined above this excerpt.
original_units = self._unit_ids
# Restrict the extractor to the units that survived filtering.
self._unit_ids = included_units
# set features
self._spiketrains = []
for clust in self._unit_ids:
    # Indices of all spikes assigned to this cluster.
    idx = np.where(spike_clusters == clust)[0]
    self._spiketrains.append(spike_times[idx])
    self.set_unit_spike_features(clust, 'amplitudes', amplitudes[idx])
    if pc_features is not None:
        self.set_unit_spike_features(clust, 'pc_features', pc_features[idx])
if load_waveforms:
    # The raw data lives next to the Phy output as a .dat or .bin file.
    datfile = [x for x in phy_folder.iterdir() if x.suffix == '.dat' or x.suffix == '.bin']
    recording = BinDatRecordingExtractor(datfile[0], sampling_frequency=float(self.params['sample_rate']),
                                         dtype=self.params['dtype'], numchan=self.params['n_channels_dat'])
    # if channel groups are present, compute waveforms by group
    if (phy_folder / 'channel_groups.npy').is_file():
        channel_groups = np.load(phy_folder / 'channel_groups.npy')
        # One group label per recording channel is required.
        assert len(channel_groups) == recording.get_num_channels()
        for (ch, cg) in zip(recording.get_channel_ids(), channel_groups):
            recording.set_channel_property(ch, 'group', cg)
        for u_i, u in enumerate(self.get_unit_ids()):
            if verbose:
                print('Computing waveform by group for unit', u)
            # Snippet window: 0.5 ms before / 2 ms after each spike, in frames.
            frames_before = int(0.5 / 1000. * recording.get_sampling_frequency())
            frames_after = int(2 / 1000. * recording.get_sampling_frequency())
            spiketrain = self.get_unit_spike_train(u)
            if 'group' in self.get_unit_property_names(u):
                group_idx = np.where(channel_groups == int(self.get_unit_property(u, 'group')))[0]
                wf = recording.get_snippets(reference_frames=spiketrain,
from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor
from spikeextractors.extractors.npzsortingextractor import NpzSortingExtractor
from spikeextractors import RecordingExtractor, SortingExtractor
import tempfile
from pathlib import Path
from copy import deepcopy
import importlib
import os
import shutil
class CacheRecordingExtractor(BinDatRecordingExtractor, RecordingExtractor):
    """Caches another RecordingExtractor into a binary .dat file on disk."""

    def __init__(self, recording, chunk_size=None, save_path=None):
        # recording: source RecordingExtractor to cache
        # chunk_size: frames per chunk when writing (None = writer default)
        # save_path: target .dat/.bin path; None caches to a temporary file
        RecordingExtractor.__init__(self)  # init tmp folder before constructing BinDatRecordingExtractor
        tmp_folder = self.get_tmp_folder()
        self._recording = recording
        if save_path is None:
            self._is_tmp = True
            # NOTE(review): only the *name* of the NamedTemporaryFile is kept;
            # the file object is discarded immediately, so the OS-level file is
            # removed and the name is reused as a fresh path -- confirm this
            # name-reuse pattern is acceptable (it is not race-free).
            self._tmp_file = tempfile.NamedTemporaryFile(suffix=".dat", dir=tmp_folder).name
        else:
            save_path = Path(save_path)
            # Normalize the extension to .dat unless the caller chose .bin.
            if save_path.suffix != '.dat' and save_path.suffix != '.bin':
                save_path = save_path.with_suffix('.dat')
            if not save_path.parent.is_dir():
                os.makedirs(save_path.parent)
            self._is_tmp = False
            self._tmp_file = save_path
        self._dtype = recording.get_dtype()
import os
from pathlib import Path
import numpy as np
from spikeextractors import SortingExtractor
from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor
from spikeextractors.extraction_tools import save_to_probe_file, load_probe_file, check_valid_unit_id
try:
import hybridizer.io as sbio
import hybridizer.probes as sbprb
HAVE_SBEX = True
except ImportError:
HAVE_SBEX = False
class SHYBRIDRecordingExtractor(BinDatRecordingExtractor):
    """RecordingExtractor for a SHYBRID recording (raw binary described by a .yml params file)."""
    extractor_name = 'SHYBRIDRecording'
    installed = HAVE_SBEX  # requires the optional `hybridizer` (shybrid) package
    is_writable = True
    mode = 'file'
    installation_mesg = "To use the SHYBRID extractors, install SHYBRID: \n\n pip install shybrid\n\n"

    def __init__(self, file_path):
        # load params file related to the given shybrid recording
        assert HAVE_SBEX, self.installation_mesg
        params = sbio.get_params(file_path)['data']
        # create a shybrid probe object
        probe = sbprb.Probe(params['probe'])
        nb_channels = probe.total_nb_channels
        # translate the byte ordering
        dtype: dtype
            Type of the saved data. Default float32.
        """
        # NOTE(review): this is the tail of a `write_recording`-style method
        # whose signature and docstring head lie above this excerpt;
        # `recording`, `save_path`, `initial_sorting_fn`, `params_template`
        # and `GeometryNotLoadedError` are defined there / elsewhere.
        assert HAVE_SBEX, SHYBRIDRecordingExtractor.installation_mesg

        # Fixed file names that make up a SHYBRID recording folder.
        RECORDING_NAME = 'recording.bin'
        PROBE_NAME = 'probe.prb'
        PARAMETERS_NAME = 'recording.yml'

        # location information has to be present in order for shybrid to
        # be able to operate on the recording
        if 'location' not in recording.get_shared_channel_property_names():
            raise GeometryNotLoadedError("Channel locations were not found")

        # write recording
        recording_fn = os.path.join(save_path, RECORDING_NAME)
        BinDatRecordingExtractor.write_recording(recording, recording_fn,
                                                 time_axis=0, dtype=dtype)

        # write probe file
        probe_fn = os.path.join(save_path, PROBE_NAME)
        save_to_probe_file(recording, probe_fn)

        # create parameters file
        parameters = params_template.format(initial_sorting_fn=initial_sorting_fn,
                                            data_type=dtype,
                                            sampling_frequency=str(recording.get_sampling_frequency()),
                                            byte_ordering='F',
                                            probe_fn=probe_fn)

        # write parameters file
        parameters_fn = os.path.join(save_path, PARAMETERS_NAME)
        with open(parameters_fn, 'w') as fp: