Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
##############################################################################
# plot_agreement_matrix()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
w_agr_MS4 = sw.plot_agreement_matrix(
    comp_MS4,  # sorting comparison whose agreement matrix is drawn
    count_text=False,  # count text disabled for this figure
)
##############################################################################
# plot_sorting_performance()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can also plot a performance metric (e.g. accuracy, recall, precision) with respect to a quality metric, for
# example the signal-to-noise ratio (SNR). Quality metrics can be computed using the :code:`toolkit.validation` submodule:
import spikeinterface.toolkit as st
snrs = st.validation.compute_snrs(
    sorting_true,
    recording,
    save_as_property=True,  # presumably stores SNRs as the 'snr' property read via property_name below — confirm
)
w_perf = sw.plot_sorting_performance(
    comp_MS4,
    property_name='snr',  # quality metric on the x-axis
    metric='accuracy',  # performance metric being plotted
)
##############################################################################
# Widgets using MultiSortingComparison
# -------------------------------------
#
# We can also compare all three SortingExtractor objects, obtaining a :code:`MultiSortingComparison` object.
multicomp = sc.compare_multiple_sorters(
    sorting_list=[sorting_true, sorting_MS4, sorting_KL],  # ground truth plus both sorter outputs
)
##############################################################################
# plot_multicomp_graph()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# The :code:`get_metrics_dict` and :code:`get_metrics_df` return all metrics as a dictionary or a pandas dataframe:
print(mc.get_metrics_dict())  # every computed metric, as a dictionary
print(mc.get_metrics_df())  # the same metrics, as a pandas DataFrame
# NOTE(review): `mc` is only created further down (MetricCalculator section), and this cell sits under the
# plot_multicomp_graph() heading — it looks misplaced; running the file top to bottom would hit a NameError here.
##############################################################################
# If you don't need to compute all metrics, you can either pass a :code:`metric_names` list to :code:`compute_metrics` or
# call separate functions that compute single metrics.
# This computes only the signal-to-noise ratio (SNR):
mc.compute_metrics(
    metric_names=['snr'],  # restrict the computation to the SNR metric
)
print(mc.get_metrics_df()['snr'])
# compute_snrs is the standalone counterpart: it computes only the SNRs and returns them
snrs = st.validation.compute_snrs(sorting, recording)
print(snrs)
##############################################################################
# Once we have paired :code:`RecordingExtractor` and :code:`SortingExtractor` objects, we can post-process, validate, and curate the
# results. With the :code:`toolkit.postprocessing` submodule, one can, for example, get waveforms, templates, maximum
# channels, PCA scores, or export the data to Phy. `Phy <https://github.com/cortex-lab/phy>`_ is a GUI for manual curation of the spike sorting output.
# To export to Phy, run:
st.postprocessing.export_to_phy(
    recording,
    sorting_KL,
    output_folder='phy',  # phy/params.py in this folder is what `phy template-gui` opens
)
##############################################################################
# Then you can run the template-gui with: :code:`phy template-gui phy/params.py` and manually curate the results.
#
# Validation of spike sorting output is very important. The :code:`toolkit.validation` module implements several quality
# metrics to assess the goodness of sorted units. Among those, for example, are signal-to-noise ratio, ISI violation
# ratio, isolation distance, and many more.
snrs = st.validation.compute_snrs(sorting_KL, recording_cmr)  # signal-to-noise ratios
isi_violations = st.validation.compute_isi_violations(sorting_KL)  # needs only the sorting, no recording
isolations = st.validation.compute_isolation_distances(
    sorting_KL,
    recording,  # NOTE(review): raw `recording` here vs `recording_cmr` for SNRs — confirm intended
)
print('SNR', snrs)
print('ISI violation ratios', isi_violations)
print('Isolation distances', isolations)
##############################################################################
# Quality metrics can also be used to automatically curate the spike sorting output. For example, you can select
# sorted units with an SNR above a certain threshold:
sorting_curated_snr = st.curation.threshold_snr(sorting_KL, recording_cmr, threshold=5, threshold_sign='less')
# Threshold and report on the same recording (recording_cmr, also used for the SNRs printed above), so the
# curated SNRs below are consistent with the criterion that selected the units.
snrs_above = st.validation.compute_snrs(sorting_curated_snr, recording_cmr)
print('Curated SNR', snrs_above)
# Curation steps can be chained: first drop units whose firing rate is below 2.3, then drop the
# remaining units whose SNR is below 10.
sorting_fr = st.curation.threshold_firing_rate(sorting_KL, threshold=2.3, threshold_sign='less')
print('Units after FR threshold:', sorting_fr.get_unit_ids())
print('Number of units after FR threshold:', len(sorting_fr.get_unit_ids()))
sorting_snr = st.curation.threshold_snr(sorting_fr, recording, threshold=10, threshold_sign='less')
print('Units after SNR threshold:', sorting_snr.get_unit_ids())
print('Number of units after SNR threshold:', len(sorting_snr.get_unit_ids()))
##############################################################################
# Let's now check with the :code:`toolkit.validation` submodule that all remaining units passed both thresholds (firing rate 2.3, SNR 10):
fr = st.validation.compute_firing_rates(
    sorting_snr,
)
snrs = st.validation.compute_snrs(
    sorting_snr,
    recording,  # same recording threshold_snr used for the curation above
)
print('Firing rates:', fr)
print('SNR:', snrs)
# Quality metrics include, among others, the signal-to-noise ratio, ISI violation ratio, isolation distance, and many more.
snrs = st.validation.compute_snrs(
    sorting_KL,
    recording_cmr,
)
isi_violations = st.validation.compute_isi_violations(sorting_KL)  # needs only the sorting, no recording
isolations = st.validation.compute_isolation_distances(
    sorting_KL,
    recording,  # NOTE(review): raw `recording` here vs `recording_cmr` for SNRs — confirm intended
)
print('SNR', snrs)
print('ISI violation ratios', isi_violations)
print('Isolation distances', isolations)
##############################################################################
# Quality metrics can also be used to automatically curate the spike sorting output. For example, you can select
# sorted units with an SNR above a certain threshold:
sorting_curated_snr = st.curation.threshold_snr(sorting_KL, recording_cmr, threshold=5, threshold_sign='less')
# Threshold and report on the same recording (recording_cmr, also used for the SNRs printed above), so the
# curated SNRs below are consistent with the criterion that selected the units.
snrs_above = st.validation.compute_snrs(sorting_curated_snr, recording_cmr)
print('Curated SNR', snrs_above)
##############################################################################
# The final part of this tutorial deals with comparing spike sorting outputs.
# We can either (1) compare the spike sorting results with the ground-truth sorting :code:`sorting_true`, (2) compare
# the outputs of two sorters (Klusta and Mountainsort4), or (3) compare the outputs of multiple sorters:
comp_gt_KL = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_KL)
comp_KL_MS4 = sc.compare_two_sorters(sorting1=sorting_KL, sorting2=sorting_MS4)
# Each label in name_list names the sorting at the same position in sorting_list, so the labels must
# follow the order of the sortings (MountainSort4 first, then Klusta).
comp_multi = sc.compare_multiple_sorters(sorting_list=[sorting_MS4, sorting_KL],
                                         name_list=['ms4', 'klusta'])
##############################################################################
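# When comparing with a ground-truth sorting extractor (1), you can get the sorting performance and plot a confusion
# matrix.
#
# Curation methods can also be chained: below, units with a firing rate below 2.3 are removed first, then units
# with an SNR below 10 are removed, and only the remaining units are kept: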
sorting_fr = st.curation.threshold_firing_rate(sorting_KL, threshold=2.3, threshold_sign='less')  # drop FR < 2.3
print('Units after FR threshold:', sorting_fr.get_unit_ids())
print('Number of units after FR threshold:', len(sorting_fr.get_unit_ids()))
# Second curation step, applied to the already FR-curated sorting.
sorting_snr = st.curation.threshold_snr(sorting_fr, recording, threshold=10, threshold_sign='less')  # drop SNR < 10
print('Units after SNR threshold:', sorting_snr.get_unit_ids())
print('Number of units after SNR threshold:', len(sorting_snr.get_unit_ids()))
##############################################################################
# Let's now check with the :code:`toolkit.validation` submodule that all remaining units passed both thresholds (firing rate 2.3, SNR 10):
fr = st.validation.compute_firing_rates(
    sorting_snr,
)
snrs = st.validation.compute_snrs(
    sorting_snr,
    recording,  # same recording threshold_snr used for the curation above
)
print('Firing rates:', fr)
print('SNR:', snrs)
"""
import spikeinterface.extractors as se
import spikeinterface.toolkit as st
##############################################################################
# First, let's create a toy example:
recording, sorting = se.example_datasets.toy_example(
    num_channels=4,
    duration=10,  # NOTE(review): presumably seconds — confirm
    seed=0,  # fixed seed for a reproducible toy dataset
)
##############################################################################
# The :code:`toolkit.validation` submodule has a :code:`MetricCalculator` class that lets you compute metrics in a
# compact and easy way. You first need to instantiate a :code:`MetricCalculator` object with the
# :code:`SortingExtractor` and :code:`RecordingExtractor` objects.
mc = st.validation.MetricCalculator(
    sorting,  # SortingExtractor
    recording,  # RecordingExtractor
)
##############################################################################
# You can then compute metrics as follows:
mc.compute_metrics()  # no metric_names given: every metric is computed
##############################################################################
# This is the list of the computed metrics:
print(list(mc.get_metrics_dict()))  # iterating a dict yields its keys, i.e. the metric names
##############################################################################
# The :code:`get_metrics_dict` and :code:`get_metrics_df` return all metrics as a dictionary or a pandas dataframe:
print(mc.get_metrics_dict())  # all metrics as a dictionary
print(mc.get_metrics_df())  # the same metrics as a pandas DataFrame