import numba
import numpy as np
import strax


def test_sum_waveform(records, peak_left, peak_length):
    # Make a single big peak to contain all the records
    n_ch = 100
    peaks = np.zeros(1, strax.peak_dtype(n_ch, n_sum_wv_samples=200))
    p = peaks[0]
    p['time'] = peak_left
    p['length'] = peak_length
    p['dt'] = 1  # 1 ns per sample; dt = 0 would give a zero-duration peak
    strax.sum_waveform(peaks, records, np.ones(n_ch))

    # Area measures must be consistent
    # (use isclose: float32 sums in different orders need not match exactly)
    area = p['area']
    assert area >= 0
    assert np.isclose(p['data'].sum(), area)
    assert np.isclose(p['area_per_channel'].sum(), area)
    # Create a simple sum waveform by hand and check it against p['data']
    if not len(records):
        max_sample = 3  # Whatever
    else:
        max_sample = int((records['time'] + records['length']).max())
    max_sample = max(max_sample, peak_left + peak_length)
    w = np.zeros(max_sample + 1, dtype=np.float32)
    for r in records:
        w[r['time']:r['time'] + r['length']] += r['data'][:r['length']]
    assert np.allclose(p['data'][:p['length']],
                       w[peak_left:peak_left + peak_length])
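
# A minimal, self-contained way to drive the check above by hand (without a
# hypothesis strategy). The record contents and the all-ones gains are
# illustrative assumptions, not fixtures from this suite.
def example_sum_waveform_by_hand():
    records = np.zeros(2, dtype=strax.record_dtype())
    records['time'] = [0, 2]
    records['length'] = 2
    records['dt'] = 1
    records['data'][:, :2] = 1
    test_sum_waveform(records, peak_left=0, peak_length=5)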
class SomeCrash(Exception):
pass
@strax.takes_config(
strax.Option('base_area', default=0),
strax.Option('give_wrong_dtype', default=False),
strax.Option('bonus_area', default_by_run=[(0, 0), (1, 1)]))
class Peaks(strax.Plugin):
provides = 'peaks'
data_kind = 'peaks'
depends_on = ('records',)
dtype = strax.peak_dtype()
parallel = True
def compute(self, records):
if self.config['give_wrong_dtype']:
            # np.int / np.float were removed from numpy; use explicit widths
            return np.zeros(5, [('a', np.int64), ('b', np.float64)])
p = np.zeros(len(records), self.dtype)
p['time'] = records['time']
p['length'] = p['dt'] = 1
p['area'] = self.config['base_area'] + self.config['bonus_area']
return p
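
# How the options declared via strax.takes_config reach the plugin: the
# Context fills self.config, so values can be set when the Context is built.
# A hedged sketch; the run id and option value here are illustrative.
def example_configured_context():
    st = strax.Context(storage=[], register=[Records, Peaks],
                       config=dict(base_area=5))
    return st.get_array(run_id='0', targets='peaks')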
# Another peak-kind plugin, to test time_range selection
# with unaligned chunks
class PeakClassification(strax.Plugin):
    provides = 'peak_classification'
    depends_on = ('peaks',)


# Trivial source plugin: emits n_chunks chunks of single-sample records.
# The chunk counts and run id below are illustrative test constants.
recs_per_chunk = 10
n_chunks = 10
run_id = '0'


class Records(strax.Plugin):
    provides = 'records'
    depends_on = tuple()
    dtype = strax.record_dtype()

    def compute(self, chunk_i):
        r = np.zeros(recs_per_chunk, self.dtype)
        r['time'] = chunk_i
        r['length'] = 1
        r['dt'] = 1
        r['channel'] = np.arange(len(r))
        return r

    def source_finished(self):
        return True

    def is_ready(self, chunk_i):
        return chunk_i < n_chunks
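
# strax drives a source plugin like Records by polling is_ready() for each
# chunk and stopping once source_finished() is True. A rough sketch of that
# loop (illustrative, not strax's actual scheduler):
def example_drive_source(plugin):
    chunks = []
    chunk_i = 0
    while plugin.is_ready(chunk_i):
        chunks.append(plugin.compute(chunk_i))
        chunk_i += 1
    return np.concatenate(chunks) if chunks else np.zeros(0, plugin.dtype)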
class Peaks(strax.Plugin):
parallel = True
provides = 'peaks'
depends_on = ('records',)
dtype = strax.peak_dtype()
def compute(self, records):
        assert isinstance(records, np.ndarray), \
            f"Received {type(records)} instead of numpy array!"
p = np.zeros(len(records), self.dtype)
p['time'] = records['time']
return p
def test_processing():
    """Test ParallelSource plugin under several conditions"""
    # It's always harder with a small mailbox:
    strax.Mailbox.DEFAULT_MAX_MESSAGES = 2
    for request_peaks in (True, False):
        for peaks_parallel in (True, False):
            for max_workers in (1, 2):
                Peaks.parallel = peaks_parallel
                print(f"\nTesting with request_peaks {request_peaks}, "
                      f"peaks_parallel {peaks_parallel}, "
                      f"max_workers {max_workers}")

                mystrax = strax.Context(storage=[],
                                        register=[Records, Peaks])
                bla = mystrax.get_array(
                    run_id=run_id,
                    targets='peaks' if request_peaks else 'records',
                    max_workers=max_workers)
                assert len(bla) == recs_per_chunk * n_chunks
                assert bla.dtype == (
                    strax.peak_dtype() if request_peaks
                    else strax.record_dtype())
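
# The time_range selection that PeakClassification above is meant to
# exercise can be requested directly from the Context. A hedged sketch;
# the particular time window here is illustrative.
def example_time_range_selection():
    st = strax.Context(storage=[], register=[Records, Peaks])
    return st.get_array(run_id, targets='peaks', time_range=(0, 5))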
@strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
@numba.jit(nopython=True, nogil=True)
def split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
args_options,
_result_buffer=None, result_dtype=None):
# TODO NEEDS TESTS!
new_peaks = _result_buffer
offset = 0
for p_i, p in enumerate(peaks):
if p['area'] < min_area:
continue
prev_split_i = 0
w = p['data'][:p['length']]
        # Ask the split finder for split points in this waveform. The loop
        # body (writing each fragment into new_peaks and advancing offset)
        # is truncated in this snippet; the argument list is a plausible
        # completion, not verified against the full source.
        for split_i, bonus_output in split_finder(
                w, p['dt'], p_i, *args_options):
            pass
    yield offset
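
# strax.growing_result (used above) wraps a numba generator that fills the
# injected _result_buffer and yields the count of valid entries whenever the
# buffer fills up, plus once at the end; the wrapper concatenates the chunks
# into one array. A minimal sketch of the same pattern with a hypothetical
# filter function:
@strax.growing_result(dtype=np.int64, chunk_size=1000)
@numba.jit(nopython=True, nogil=True)
def keep_positive(xs, _result_buffer=None, result_dtype=None):
    buffer = _result_buffer
    offset = 0
    for x in xs:
        if x > 0:
            buffer[offset] = x
            offset += 1
            if offset == len(buffer):
                yield offset  # buffer full: hand it to the wrapper
                offset = 0
    yield offset  # hand over the final partial buffer

# e.g. keep_positive(np.array([-1, 2, 3])) -> array([2, 3])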
# Record-cleaning fragment: strip pulse tails, find hits, zero out data
# outside the hits. The wrapper name and the gains are illustrative; in the
# original this body sits inside a plugin's compute method.
to_pe = np.ones(100)
def clean_records(raw_records):
    r = strax.exclude_tails(raw_records, to_pe)
    hits = strax.find_hits(r)
    strax.cut_outside_hits(r, hits)
    return r
@export
@strax.takes_config(
strax.Option('diagnose_sorting', track=False, default=False,
help="Enable runtime checks for sorting and disjointness"))
class Peaks(strax.Plugin):
depends_on = ('records',)
data_kind = 'peaks'
parallel = True
rechunk_on_save = True
dtype = strax.peak_dtype(n_channels=len(to_pe))
def compute(self, records):
r = records
hits = strax.find_hits(r) # TODO: Duplicate work
hits = strax.sort_by_time(hits)
peaks = strax.find_peaks(hits, to_pe,
result_dtype=self.dtype)
strax.sum_waveform(peaks, r, to_pe)
peaks = strax.split_peaks(peaks, r, to_pe)
strax.compute_widths(peaks)
        if self.config['diagnose_sorting']:
            assert np.diff(r['time']).min() >= 0, "Records not sorted"
        return peaks
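
# A hedged sketch of running the plugin above with its diagnostic option
# enabled; the storage choice and target are illustrative.
def example_diagnosed_processing():
    st = strax.Context(storage=[],
                       register=[Records, Peaks],
                       config=dict(diagnose_sorting=True))
    return st.get_array(run_id, targets='peaks')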