Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_from_historical_tree_sequence(self):
    # A SampleData filled through verify_data_round_trip must hold the same
    # data as one produced directly by SampleData.from_tree_sequence, for a
    # tree sequence whose samples are taken at times 0..9.
    historical_times = np.arange(10)
    ts = get_example_historical_sampled_ts(historical_times, 10)

    round_tripped = formats.SampleData(sequence_length=ts.sequence_length)
    self.verify_data_round_trip(ts, round_tripped)

    converted = formats.SampleData.from_tree_sequence(ts)
    same_data = round_tripped.data_equal(converted)
    self.assertTrue(same_data)
def test_no_arguments(self):
    # Subsetting with no arguments, with every individual, or with every
    # site must each give back the same data as the original.
    ts = get_example_ts(10, 10, 1)
    full = formats.SampleData.from_tree_sequence(ts)
    selections = (
        {},
        {"individuals": np.arange(full.num_individuals)},
        {"sites": np.arange(full.num_sites)},
    )
    for selection in selections:
        subset = full.subset(**selection)
        full.assert_data_equal(subset)
    # Checked on the last subset taken (all sites): it carries exactly one
    # more provenance record than its source.
    expected_provenances = full.num_provenances + 1
    self.assertEqual(subset.num_provenances, expected_provenances)
def test_merge_identical(self):
    # Merging a data set with itself keeps the sites and doubles the samples;
    # each half of the merged genotypes must equal the original genotypes.
    n = 10
    ts = get_example_ts(n, 10, 1)
    single = formats.SampleData.from_tree_sequence(ts)
    doubled = single.merge(single)

    self.assertEqual(doubled.num_sites, single.num_sites)
    self.assertEqual(doubled.num_samples, 2 * single.num_samples)

    paired_variants = zip(single.variants(), doubled.variants())
    for original_var, merged_var in paired_variants:
        self.assertEqual(original_var.site, merged_var.site)
        first_copy = merged_var.genotypes[:n]
        self.assertTrue(np.array_equal(original_var.genotypes, first_copy))
        second_copy = merged_var.genotypes[n:]
        self.assertTrue(np.array_equal(original_var.genotypes, second_copy))

    for left, right in ((single, single), (doubled, single)):
        self.verify(left, right)
def test_file_kwargs(self):
    # Make sure keyword arguments given to merge() (here ``path``) are passed
    # on to the SampleData constructor: the merged data must be written to
    # the requested path and load back with equal content.
    ts = get_example_ts(10, 10, 1)
    source = formats.SampleData.from_tree_sequence(ts)
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "sample-data")
        merged = source.merge(source, path=path)
        file_was_written = os.path.exists(path)
        self.assertTrue(file_was_written)
        reloaded = formats.SampleData.load(path)
        self.assertTrue(merged.data_equal(reloaded))
# NOTE(review): this run of statements has no visible ``def`` header; the
# names ``tables``, ``node``, ``i`` and the initial ``ts`` are bound in code
# that is not part of this view.
tables.mutations.add_row(site=0, node=node, derived_state=str(i + 2))
# Create < 2 alleles by adding a non-variable site at the end
# (placed halfway between the last existing site and the sequence end).
extra_last_pos = (ts.site(ts.num_sites - 1).position + ts.sequence_length) / 2
tables.sites.add_row(position=extra_last_pos, ancestral_state="0")
# Rows were appended out of order, so sort, index and compute mutation
# parents before building the tree sequence.
tables.sort()
tables.build_index()
tables.compute_mutation_parents()
ts = tables.tree_sequence()
# Sanity-check the fixture: site 0 carries more than one mutation and the
# appended last site carries none.
self.assertGreater(len(ts.site(0).mutations), 1)
self.assertEqual(len(ts.site(ts.num_sites - 1).mutations), 0)
sd1 = formats.SampleData(sequence_length=ts.sequence_length)
self.verify_data_round_trip(ts, sd1)
# The allele count SampleData reports for each site must match the number
# of alleles in the corresponding tree sequence variant.
num_alleles = sd1.num_alleles()
for var in ts.variants():
self.assertEqual(len(var.alleles), num_alleles[var.site.id])
# The same data must result from direct conversion of the tree sequence.
sd2 = formats.SampleData.from_tree_sequence(ts)
self.assertTrue(sd1.data_equal(sd2))
def test_errors(self):
    # SampleData.subset() must raise ValueError for empty selections and for
    # individual or site IDs that fall outside the valid range.
    ts = get_example_ts(10, 10, 1)
    sd1 = formats.SampleData.from_tree_sequence(ts)
    invalid_selections = [
        # Empty selections
        {"sites": []},
        {"individuals": []},
        # Individual IDs out of bounds. The upper bound is taken from the
        # data itself rather than hard-coded, mirroring the site check below.
        {"individuals": [-1, 0, 1]},
        {"individuals": [sd1.num_individuals, 0, 1]},
        # Site IDs out of bounds
        {"sites": [-1, 0, 1]},
        {"sites": [ts.num_sites, 0, 1]},
    ]
    for selection in invalid_selections:
        # subTest reports which selection failed and lets the rest still run.
        with self.subTest(**selection):
            with self.assertRaises(ValueError):
                sd1.subset(**selection)
    # Duplicate IDs
    # NOTE(review): no duplicate-ID cases follow this comment in the visible
    # source; confirm whether they were cut off and restore them if so.
def test_different_alleles_same_sites(self):
    # Two data sets at the same sites but with different derived states
    # must merge into sites that list all three alleles.
    ts = get_example_individuals_ts_with_metadata(5, 2, 10, 1)
    original = formats.SampleData.from_tree_sequence(ts)

    # Shift every derived-state code by one (presumably "1" -> "2", given
    # the alleles asserted below).
    shifted_tables = ts.dump_tables()
    shifted_tables.mutations.derived_state += 1
    shifted_ts = shifted_tables.tree_sequence()
    shifted = formats.SampleData.from_tree_sequence(shifted_ts)

    for left, right in ((original, shifted), (shifted, original)):
        self.verify(left, right)

    merged = original.merge(shifted)
    expected_alleles = ("0", "1", "2")
    for variant in merged.variants():
        self.assertEqual(variant.site.alleles, expected_alleles)
def test_mismatch_ancestral_state(self):
    # Different ancestral states at the same sites: merge() must refuse
    # with a ValueError.
    ts = get_example_ts(2, 2, 1)
    base = formats.SampleData.from_tree_sequence(ts)

    altered_tables = ts.dump_tables()
    altered_tables.sites.ancestral_state += 2
    altered_ts = altered_tables.tree_sequence()
    altered = formats.SampleData.from_tree_sequence(altered_ts)

    self.assertRaises(ValueError, base.merge, altered)
def test_merge_overlapping_sites(self):
    # Merge two data sets whose site positions agree everywhere except at
    # the first and last site.
    ts = get_example_ts(4, 10, 1, random_seed=1)
    reference = formats.SampleData.from_tree_sequence(ts)

    # Change the position of the first and last sites so we have
    # overhangs at either side.
    nudged_tables = ts.dump_tables()
    nudged_positions = nudged_tables.sites.position
    nudged_positions[0] += 1e-8
    nudged_positions[-1] -= 1e-8
    nudged_tables.sites.position = nudged_positions
    ts = nudged_tables.tree_sequence()
    nudged = formats.SampleData.from_tree_sequence(ts)

    # All positions except the two that were moved are still shared.
    shared_positions = set(reference.sites_position) & set(nudged.sites_position)
    self.assertEqual(len(shared_positions), reference.num_sites - 2)

    for left, right in ((reference, nudged), (nudged, reference)):
        self.verify(left, right)
def run_perfect_inference(
base_ts,
num_threads=1,
path_compression=False,
extended_checks=True,
time_chunking=True,
progress_monitor=None,
use_ts=False,
engine=constants.C_ENGINE,
):
"""
Runs the perfect inference process on the specified tree sequence.
"""
ts = insert_perfect_mutations(base_ts)
sample_data = formats.SampleData.from_tree_sequence(ts)
if use_ts:
# Use the actual tree sequence that was provided as the basis for copying.
ancestors_ts = make_ancestors_ts(sample_data, ts, remove_leaves=True)
else:
ancestor_data = formats.AncestorData(sample_data)
build_simulated_ancestors(
sample_data, ancestor_data, ts, time_chunking=time_chunking
)
ancestor_data.finalise()
ancestors_ts = inference.match_ancestors(
sample_data,
ancestor_data,
engine=engine,
path_compression=path_compression,