How to use the tsinfer.formats.SampleData.from_tree_sequence function in tsinfer

To help you get started, we’ve selected a few tsinfer examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github tskit-dev / tsinfer / tests / test_formats.py View on Github external
def test_from_historical_tree_sequence(self):
        """
        Compare two ways of building SampleData from a tree sequence produced
        by get_example_historical_sampled_ts: filling an empty instance through
        self.verify_data_round_trip (helper defined elsewhere) and calling
        SampleData.from_tree_sequence directly. Both must hold equal data.
        """
        times = np.arange(10)
        historical_ts = get_example_historical_sampled_ts(times, 10)
        # Reference copy: an empty container populated by the round-trip helper.
        round_tripped = formats.SampleData(
            sequence_length=historical_ts.sequence_length
        )
        self.verify_data_round_trip(historical_ts, round_tripped)
        # Copy under test: built in one step by the classmethod.
        converted = formats.SampleData.from_tree_sequence(historical_ts)
        self.assertTrue(round_tripped.data_equal(converted))
github tskit-dev / tsinfer / tests / test_formats.py View on Github external
def test_no_arguments(self):
        """
        subset() called with no arguments, with every individual, or with
        every site must return data equal to the original sample data.
        """
        original = formats.SampleData.from_tree_sequence(get_example_ts(10, 10, 1))

        # No arguments gives the same data
        unfiltered = original.subset()
        original.assert_data_equal(unfiltered)

        # Explicitly selecting every individual gives the same data
        every_individual = np.arange(original.num_individuals)
        by_individuals = original.subset(individuals=every_individual)
        original.assert_data_equal(by_individuals)

        # Explicitly selecting every site gives the same data
        every_site = np.arange(original.num_sites)
        by_sites = original.subset(sites=every_site)
        original.assert_data_equal(by_sites)

        # The subset carries exactly one more provenance record than its source.
        self.assertEqual(by_sites.num_provenances, original.num_provenances + 1)
github tskit-dev / tsinfer / tests / test_formats.py View on Github external
def test_merge_identical(self):
        """
        Merging a SampleData instance with itself keeps the number of sites,
        doubles the number of samples, and repeats each site's genotypes in
        both halves of the merged genotype array.
        """
        num_samples = 10
        source = formats.SampleData.from_tree_sequence(
            get_example_ts(num_samples, 10, 1)
        )
        doubled = source.merge(source)

        self.assertEqual(doubled.num_sites, source.num_sites)
        self.assertEqual(doubled.num_samples, 2 * source.num_samples)

        for source_var, doubled_var in zip(source.variants(), doubled.variants()):
            self.assertEqual(source_var.site, doubled_var.site)
            # First half and second half must each reproduce the source genotypes.
            first_half = doubled_var.genotypes[:num_samples]
            self.assertTrue(np.array_equal(source_var.genotypes, first_half))
            second_half = doubled_var.genotypes[num_samples:]
            self.assertTrue(np.array_equal(source_var.genotypes, second_half))

        self.verify(source, source)
        self.verify(doubled, source)
github tskit-dev / tsinfer / tests / test_formats.py View on Github external
def test_file_kwargs(self):
        """
        Make sure merge() passes keyword arguments on to the SampleData
        constructor as required: a ``path`` argument must create a file that
        loads back as data equal to the merge result.
        """
        source = formats.SampleData.from_tree_sequence(get_example_ts(10, 10, 1))
        with tempfile.TemporaryDirectory() as scratch_dir:
            target = os.path.join(scratch_dir, "sample-data")
            merged = source.merge(source, path=target)
            # The merge itself must have written the file.
            self.assertTrue(os.path.exists(target))
            reloaded = formats.SampleData.load(target)
            self.assertTrue(merged.data_equal(reloaded))
github tskit-dev / tsinfer / tests / test_formats.py View on Github external
tables.mutations.add_row(site=0, node=node, derived_state=str(i + 2))
        # Create < 2 alleles by adding a non-variable site at the end
        extra_last_pos = (ts.site(ts.num_sites - 1).position + ts.sequence_length) / 2
        tables.sites.add_row(position=extra_last_pos, ancestral_state="0")
        tables.sort()
        tables.build_index()
        tables.compute_mutation_parents()
        ts = tables.tree_sequence()
        self.assertGreater(len(ts.site(0).mutations), 1)
        self.assertEqual(len(ts.site(ts.num_sites - 1).mutations), 0)
        sd1 = formats.SampleData(sequence_length=ts.sequence_length)
        self.verify_data_round_trip(ts, sd1)
        num_alleles = sd1.num_alleles()
        for var in ts.variants():
            self.assertEqual(len(var.alleles), num_alleles[var.site.id])
        sd2 = formats.SampleData.from_tree_sequence(ts)
        self.assertTrue(sd1.data_equal(sd2))
github tskit-dev / tsinfer / tests / test_formats.py View on Github external
def test_errors(self):
        """
        subset() must raise ValueError for empty selections and for site or
        individual IDs that are out of bounds.
        """
        ts = get_example_ts(10, 10, 1)
        sd1 = formats.SampleData.from_tree_sequence(ts)
        # Each entry is a set of keyword arguments that subset() must reject;
        # they are tried in order and the first unexpected success fails the test.
        rejected_selections = [
            # Empty selections
            {"sites": []},
            {"individuals": []},
            # Individual IDs out of bounds
            {"individuals": [-1, 0, 1]},
            {"individuals": [10, 0, 1]},
            # Site IDs out of bounds
            {"sites": [-1, 0, 1]},
            {"sites": [ts.num_sites, 0, 1]},
        ]
        for selection in rejected_selections:
            with self.assertRaises(ValueError):
                sd1.subset(**selection)
        # Duplicate IDs
github tskit-dev / tsinfer / tests / test_formats.py View on Github external
def test_different_alleles_same_sites(self):
        """
        Merge two sample data sets built from the same tree sequence, where
        the second has every mutation's derived state shifted by one. Every
        site in the merged result must list the alleles ("0", "1", "2").
        """
        base_ts = get_example_individuals_ts_with_metadata(5, 2, 10, 1)
        base_data = formats.SampleData.from_tree_sequence(base_ts)

        # Build a second data set whose sites match but whose derived
        # states are all incremented by one.
        altered_tables = base_ts.dump_tables()
        altered_tables.mutations.derived_state += 1
        altered_data = formats.SampleData.from_tree_sequence(
            altered_tables.tree_sequence()
        )

        self.verify(base_data, altered_data)
        self.verify(altered_data, base_data)

        merged = base_data.merge(altered_data)
        for variant in merged.variants():
            self.assertEqual(variant.site.alleles, ("0", "1", "2"))
github tskit-dev / tsinfer / tests / test_formats.py View on Github external
def test_mismatch_ancestral_state(self):
        """
        Merging two sample data sets whose sites have different ancestral
        states must raise ValueError.
        """
        base_ts = get_example_ts(2, 2, 1)
        base_data = formats.SampleData.from_tree_sequence(base_ts)

        # Same tree sequence, but with every ancestral state shifted by two.
        altered_tables = base_ts.dump_tables()
        altered_tables.sites.ancestral_state += 2
        altered_data = formats.SampleData.from_tree_sequence(
            altered_tables.tree_sequence()
        )

        with self.assertRaises(ValueError):
            base_data.merge(altered_data)
github tskit-dev / tsinfer / tests / test_formats.py View on Github external
def test_merge_overlapping_sites(self):
        """
        Verify two sample data sets that share all site positions except the
        first and last, which are nudged inwards by 1e-8 in the second set.
        """
        base_ts = get_example_ts(4, 10, 1, random_seed=1)
        base_data = formats.SampleData.from_tree_sequence(base_ts)

        # Change the position of the first and last sites so we have
        # overhangs at either side.
        shifted_tables = base_ts.dump_tables()
        nudged = shifted_tables.sites.position
        nudged[0] += 1e-8
        nudged[-1] -= 1e-8
        shifted_tables.sites.position = nudged
        shifted_data = formats.SampleData.from_tree_sequence(
            shifted_tables.tree_sequence()
        )

        # Every position except the two nudged ones should still be shared.
        base_positions = set(base_data.sites_position)
        shifted_positions = set(shifted_data.sites_position)
        shared_positions = base_positions & shifted_positions
        self.assertEqual(len(shared_positions), base_data.num_sites - 2)

        self.verify(base_data, shifted_data)
        self.verify(shifted_data, base_data)
github tskit-dev / tsinfer / tsinfer / eval_util.py View on Github external
def run_perfect_inference(
    base_ts,
    num_threads=1,
    path_compression=False,
    extended_checks=True,
    time_chunking=True,
    progress_monitor=None,
    use_ts=False,
    engine=constants.C_ENGINE,
):
    """
    Runs the perfect inference process on the specified tree sequence.
    """
    ts = insert_perfect_mutations(base_ts)
    sample_data = formats.SampleData.from_tree_sequence(ts)

    if use_ts:
        # Use the actual tree sequence that was provided as the basis for copying.
        ancestors_ts = make_ancestors_ts(sample_data, ts, remove_leaves=True)
    else:
        ancestor_data = formats.AncestorData(sample_data)
        build_simulated_ancestors(
            sample_data, ancestor_data, ts, time_chunking=time_chunking
        )
        ancestor_data.finalise()

        ancestors_ts = inference.match_ancestors(
            sample_data,
            ancestor_data,
            engine=engine,
            path_compression=path_compression,