Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def verify(self, ts):
    """Check how ``tsinfer.infer`` numbers sample nodes with and without simplify.

    Re-infers ``ts`` from its own sample data and asserts that:

    * with ``simplify=True`` the samples are nodes ``0 .. n - 1`` and, in
      every tree, the number of samples equals the number of leaves;
    * with ``simplify=False`` and ``path_compression=False`` the samples are
      the last ``n`` nodes, ``N - n .. N - 1`` (``N`` = total node count);
    * simplifying the unsimplified output with ``keep_unary=True`` yields the
      same tables (ignoring provenance) as inferring with ``simplify=True``.

    :param ts: Tree sequence to re-infer; must have more than two sites.
    """
    n = ts.num_samples
    self.assertGreater(ts.num_sites, 2)
    sd = tsinfer.SampleData.from_tree_sequence(ts)

    ts_simplified = tsinfer.infer(sd, simplify=True)
    # When simplify is true the samples should be zero to n.
    self.assertEqual(list(ts_simplified.samples()), list(range(n)))
    for tree in ts_simplified.trees():
        self.assertEqual(tree.num_samples(), len(list(tree.leaves())))

    # When simplify is false and there is no path compression,
    # the samples should be the last n nodes: N - n up to N.
    ts_raw = tsinfer.infer(sd, simplify=False, path_compression=False)
    num_nodes = ts_raw.num_nodes
    self.assertEqual(
        list(ts_raw.samples()), list(range(num_nodes - n, num_nodes))
    )

    # Check that we're calling simplify with the correct arguments: simplifying
    # the unsimplified output ourselves must reproduce the simplify=True result.
    ts_resimplified = tsinfer.infer(sd, simplify=False).simplify(keep_unary=True)
    t1 = ts_simplified.dump_tables()
    t2 = ts_resimplified.dump_tables()
    # Provenance is cleared so that only the inferred content is compared.
    t1.provenances.clear()
    t2.provenances.clear()
    self.assertEqual(t1, t2)
def test_inferred_random_data(self):
    """Infer from random 0/1 genotypes (40 sites x 8 samples) and verify it."""
    # Seeds NumPy's global RNG so the genotype matrix is reproducible.
    np.random.seed(10)
    num_sites, num_samples = 40, 8
    genotypes = np.random.randint(2, size=(num_sites, num_samples)).astype(np.int8)
    with tsinfer.SampleData() as sample_data:
        # Each row of the matrix is one site; its row index is its position.
        for position, site_genotypes in enumerate(genotypes):
            sample_data.add_site(position, site_genotypes)
    self.verify(tsinfer.infer(sample_data))
def test_infer(self):
    """Infer from a small msprime simulation and pass it to ``self.validate_ts``."""
    simulated = msprime.simulate(10, mutation_rate=1, random_seed=1)
    # The simulation must carry at least two sites to be worth inferring.
    self.assertGreater(simulated.num_sites, 1)
    sample_data = tsinfer.SampleData.from_tree_sequence(simulated)
    self.validate_ts(tsinfer.infer(sample_data))
def verify(self, sample_data):
    """Infer from ``sample_data`` and check the detailed-site-info log message.

    Asserts that inference logs the message "Inserting detailed site
    information" at INFO level or above on the ``tsinfer.inference`` logger,
    then runs ``tsinfer.verify`` on the inferred tree sequence.

    :returns: The inferred tree sequence.
    """
    with self.assertLogs("tsinfer.inference", level="INFO") as captured:
        inferred = tsinfer.infer(sample_data)
    # Compare against the raw (unformatted) message of each captured record.
    logged_messages = [entry.msg for entry in captured.records]
    self.assertIn("Inserting detailed site information", logged_messages)
    tsinfer.verify(sample_data, inferred)
    return inferred
# NOTE(review): the enclosing test method's header is not visible here; ``u``
# and ``self`` are defined outside this view. ``u`` is the value placed in
# sample 0 (column 0) at every site, and inference is expected to output 0 there.
sites_by_samples = np.array(
    [
        [u, 1, 1, 1, 0],  # Site 0
        [u, 1, 1, 0, 0],  # Site 1
        [u, 0, 0, 1, 0],  # Site 2
        [u, 0, 1, 1, 0],  # Site 3
    ],
    dtype=np.int8,
)
# Expected output: identical to the input except sample 0 is 0 at every site.
expected = sites_by_samples.copy()
expected[:, 0] = [0, 0, 0, 0]
with tsinfer.SampleData() as sample_data:
    # Each row is one site; its row index is used as its position.
    for position, genotypes in enumerate(sites_by_samples):
        sample_data.add_site(position, genotypes)
# Each engine must independently produce two trees and the expected genotypes.
for engine in [tsinfer.PY_ENGINE, tsinfer.C_ENGINE]:
    ts = tsinfer.infer(sample_data, engine=engine)
    # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
    self.assertEqual(ts.num_trees, 2)
    # array_equal also checks shapes, unlike np.all(a == b), which may broadcast.
    self.assertTrue(np.array_equal(expected, ts.genotype_matrix()))
def verify(self, ts):
    """Check how ``tsinfer.infer`` numbers sample nodes with and without simplify.

    Re-infers ``ts`` from its own sample data and asserts that:

    * with ``simplify=True`` the samples are nodes ``0 .. n - 1`` and, in
      every tree, the number of samples equals the number of leaves;
    * with ``simplify=False`` and ``path_compression=False`` the samples are
      the last ``n`` nodes, ``N - n .. N - 1`` (``N`` = total node count);
    * simplifying the unsimplified output with ``keep_unary=True`` yields the
      same tables (ignoring provenance) as inferring with ``simplify=True``.

    :param ts: Tree sequence to re-infer; must have more than two sites.
    """
    n = ts.num_samples
    self.assertGreater(ts.num_sites, 2)
    sd = tsinfer.SampleData.from_tree_sequence(ts)

    ts_simplified = tsinfer.infer(sd, simplify=True)
    # When simplify is true the samples should be zero to n.
    self.assertEqual(list(ts_simplified.samples()), list(range(n)))
    for tree in ts_simplified.trees():
        self.assertEqual(tree.num_samples(), len(list(tree.leaves())))

    # When simplify is false and there is no path compression,
    # the samples should be the last n nodes: N - n up to N.
    ts_raw = tsinfer.infer(sd, simplify=False, path_compression=False)
    num_nodes = ts_raw.num_nodes
    self.assertEqual(
        list(ts_raw.samples()), list(range(num_nodes - n, num_nodes))
    )

    # Check that we're calling simplify with the correct arguments: simplifying
    # the unsimplified output ourselves must reproduce the simplify=True result.
    # (Previously the tables were dumped here but never compared.)
    ts_resimplified = tsinfer.infer(sd, simplify=False).simplify(keep_unary=True)
    t1 = ts_simplified.dump_tables()
    t2 = ts_resimplified.dump_tables()
    # Provenance is cleared so that only the inferred content is compared.
    t1.provenances.clear()
    t2.provenances.clear()
    self.assertEqual(t1, t2)
def verify(self, sample_data, position_subset):
    """Check sample matching against ancestors restricted to a subset of sites.

    Infers from the full ``sample_data`` and restricts the result with
    ``self.subset_sites``. Separately, it builds and matches ancestors on all
    sites, restricts them the same way, then minimises and simplifies them.
    The restricted samples are matched against those ancestors, and the
    resulting genotype matrix must equal the restricted full inference's.

    :param sample_data: Sample data accepted by ``tsinfer.infer``.
    :param position_subset: Forwarded to ``self.subset_sites`` (defined
        elsewhere) to restrict both tree sequences.
    """
    # Reference: full inference, then restricted to the chosen subset.
    reference_ts = self.subset_sites(tsinfer.infer(sample_data), position_subset)

    # Ancestors built on all sites, then restricted to the same subset.
    ancestors = tsinfer.generate_ancestors(sample_data)
    matched_ancestors = tsinfer.match_ancestors(sample_data, ancestors)
    reduced = self.subset_sites(matched_ancestors, position_subset)
    reduced_ancestors = tsinfer.minimise(reduced).simplify()

    # Re-match the reference's samples against the reduced ancestors.
    reference_samples = tsinfer.SampleData.from_tree_sequence(reference_ts)
    result_ts = tsinfer.match_samples(reference_samples, reduced_ancestors)
    genotypes_match = np.array_equal(
        result_ts.genotype_matrix(), reference_ts.genotype_matrix()
    )
    self.assertTrue(genotypes_match)
def test_partial_samples(self):
    """Match only some samples and compare with full inference simplified to them.

    For each subset of sample indexes, ``match_samples(..., indexes=subset)``
    followed by table simplification must give the same tables as full
    inference simplified to ``subset`` (provenance ignored).
    """
    source_ts = msprime.simulate(
        10, mutation_rate=2, recombination_rate=2, random_seed=233
    )
    sample_data = tsinfer.SampleData.from_tree_sequence(source_ts, use_times=False)
    full_ts = tsinfer.infer(sample_data)
    ancestors_ts = tsinfer.match_ancestors(
        sample_data, tsinfer.generate_ancestors(sample_data)
    )
    # Subsets omit indexes at the start, at the end, and in the middle.
    subsets = (np.arange(8), np.arange(2, 10), np.arange(5) * 2)
    for subset in subsets:
        expected_tables = full_ts.simplify(subset).dump_tables()
        expected_tables.provenances.clear()
        partial_ts = tsinfer.match_samples(sample_data, ancestors_ts, indexes=subset)
        actual_tables = partial_ts.dump_tables()
        actual_tables.simplify()
        actual_tables.provenances.clear()
        self.assertEqual(expected_tables, actual_tables)
# NOTE(review): the enclosing test method's header is not visible here; ``ts``,
# ``rng``, ``sample_data``, ``random_string`` and ``self`` are defined outside
# this view.
# Add one population per sample, each with 0-4 random string-valued entries.
all_metadata = []
for _ in range(ts.num_samples):
    # ``k`` is distinct from the outer loop variable to avoid confusing shadowing.
    metadata = {str(k): random_string(rng) for k in range(rng.randint(0, 5))}
    sample_data.add_population(metadata=metadata)
    all_metadata.append(metadata)
# Individual j is assigned to population j.
for j in range(ts.num_samples):
    sample_data.add_individual(population=j)
for variant in ts.variants():
    sample_data.add_site(
        variant.site.position, variant.genotypes, variant.alleles
    )
sample_data.finalise()

# The stored population metadata must round-trip exactly. The length check
# catches missing populations, which the element-wise loop alone would miss.
stored_metadata = sample_data.populations_metadata[:]
self.assertEqual(len(stored_metadata), len(all_metadata))
for j, metadata in enumerate(stored_metadata):
    self.assertEqual(all_metadata[j], metadata)

# Inference must carry the metadata into the output populations (JSON-encoded)
# and place sample j in population j.
output_ts = tsinfer.infer(sample_data)
output_metadata = [
    json.loads(population.metadata.decode())
    for population in output_ts.populations()
]
self.assertEqual(all_metadata, output_metadata)
for j, sample in enumerate(output_ts.samples()):
    node = output_ts.node(sample)
    self.assertEqual(node.population, j)