How to use the tsinfer.SampleData class in tsinfer

To help you get started, we’ve selected a few tsinfer examples based on popular ways tsinfer.SampleData is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: tskit-dev/tsinfer — tests/test_inference.py (view on GitHub)
def verify_round_trip(self, genotypes, exclude_sites):
        """
        Check that ``exclude_positions`` controls how each site is inferred.

        One site is added per row of ``genotypes`` (row ``j`` at position ``j``).
        The positions selected by ``exclude_sites`` are passed to
        ``tsinfer.infer`` as ``exclude_positions``. This is done for both
        ``simplify=False`` and ``simplify=True``. Each site in the output must
        record ``tsinfer.INFERENCE_FITCH_PARSIMONY`` as its ``inference_type``
        metadata if it was excluded, and ``tsinfer.INFERENCE_FULL`` otherwise.

        :param genotypes: Genotype array indexed by site along its first axis;
            each row is passed to ``add_site`` as that site's genotypes.
        :param exclude_sites: Array with one entry per site, used to select the
            excluded positions and tested for truthiness per site; presumably a
            boolean mask -- TODO confirm.
        """
        self.assertEqual(genotypes.shape[0], exclude_sites.shape[0])
        with tsinfer.SampleData() as sample_data:
            for j in range(genotypes.shape[0]):
                sample_data.add_site(j, genotypes[j])
        # Read every site position into memory, then keep only the excluded ones.
        exclude_positions = sample_data.sites_position[:][exclude_sites]
        for simplify in [False, True]:
            output_ts = tsinfer.infer(
                sample_data, simplify=simplify, exclude_positions=exclude_positions
            )
            for tree in output_ts.trees():
                for site in tree.sites():
                    # Site metadata is a JSON document holding the inference type.
                    inf_type = json.loads(site.metadata)["inference_type"]
                    # NOTE(review): indexing input arrays by ``site.id`` assumes the
                    # output keeps every input site in input order -- verify.
                    if exclude_sites[site.id]:
                        self.assertEqual(inf_type, tsinfer.INFERENCE_FITCH_PARSIMONY)
                    else:
                        self.assertEqual(inf_type, tsinfer.INFERENCE_FULL)
                    # Sum of this site's genotypes (the derived-allele count if
                    # genotypes are 0/1 -- assumption).
                    f = np.sum(genotypes[site.id])
                    if f == 0:
                        # NOTE(review): the snippet is truncated here; the body of
                        # this branch and the rest of the method are not shown.
GitHub: tskit-dev/tsinfer — tests/test_formats.py (view on GitHub)
def test_zero_sequence_length(self):
        """
        A sample data file whose ``sequence_length`` is 0 must be rejected.

        The test writes a valid sample data file, then overwrites its
        ``sequence_length`` attribute with 0 directly through zarr, and reloads
        the file. ``tsinfer.generate_ancestors`` is expected to raise
        ``ValueError`` for it.
        """
        # Mangle a sample data file to force a zero sequence length.
        ts = msprime.simulate(10, mutation_rate=2, random_seed=5)
        with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir:
            filename = os.path.join(tempdir, "samples.tmp")
            with tsinfer.SampleData(path=filename) as sample_data:
                for var in ts.variants():
                    sample_data.add_site(var.site.position, var.genotypes)
            store = zarr.LMDBStore(filename, subdir=False)
            try:
                # "r+" is zarr's documented mode for read/write access to data
                # that must already exist, which is exactly the case here.
                data = zarr.open(store=store, mode="r+")
                data.attrs["sequence_length"] = 0
            finally:
                # Release the LMDB handle even if the edit fails, so the file is
                # not left open when it is reloaded or the tempdir is removed.
                store.close()
            sample_data = tsinfer.load(filename)
            self.assertEqual(sample_data.sequence_length, 0)
            self.assertRaises(ValueError, tsinfer.generate_ancestors, sample_data)
GitHub: tskit-dev/tsinfer — tests/test_inference.py (view on GitHub)
def test_match_ancestors_samples(self):
        """``tsinfer.match_ancestors`` must reject sample data that is unfinalised."""
        with tsinfer.SampleData(sequence_length=2) as finalised:
            finalised.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"])
        ancestors = tsinfer.generate_ancestors(finalised)
        # Same single site, but built without the ``with`` block used above, so
        # the sample data is left unfinalised.
        not_finalised = tsinfer.SampleData(sequence_length=2)
        not_finalised.add_site(1, genotypes=[0, 1, 1, 0], alleles=["G", "C"])
        with self.assertRaises(ValueError):
            tsinfer.match_ancestors(not_finalised, ancestors)
GitHub: tskit-dev/tsinfer — tests/test_provenance.py (view on GitHub)
def test_infer(self):
        """The output of ``tsinfer.infer`` on simulated data passes ``validate_ts``."""
        simulated = msprime.simulate(10, mutation_rate=1, random_seed=1)
        # Guard against a degenerate simulation with too few sites.
        self.assertGreater(simulated.num_sites, 1)
        self.validate_ts(
            tsinfer.infer(tsinfer.SampleData.from_tree_sequence(simulated))
        )
GitHub: tskit-dev/tsinfer — tests/test_inference.py (view on GitHub)
def test_large_random_data(self):
        """Build sample data from ``get_random_data_example(100, 30)`` and verify it."""
        num_samples, num_sites = 100, 30
        genotype_rows, site_positions = get_random_data_example(num_samples, num_sites)
        with tsinfer.SampleData(sequence_length=num_sites) as sample_data:
            for row, pos in zip(genotype_rows, site_positions):
                sample_data.add_site(pos, row)
        self.verify(sample_data)
GitHub: tskit-dev/tsinfer — evaluation.py (view on GitHub)
def generate_samples(ts, error_param=0):
    """
    Build a ``tsinfer.SampleData`` from the sites of a simulated tree sequence,
    optionally adding genotyping error.

    ``error_param`` is interpreted in one of two ways:

    * If ``float(error_param)`` succeeds, the float is passed to
      ``make_errors`` for every site. An argument equal to ``0`` keeps the
      true genotypes.
    * Otherwise ``error_param`` is treated as the path of a CSV error-matrix
      file read with ``pandas.read_csv``. For each site, the row whose
      ``freq`` column is closest to the site's allele frequency is passed to
      ``make_errors_genotype_model``.

    NOTE(review): rejecting variants that end up as a fixed column is not
    visible in this excerpt -- confirm where (or whether) that happens.

    :param ts: Simulated tree sequence; must have at least one site.
    :param error_param: Numeric error value, or path of an error-matrix CSV.
    :raises AssertionError: If ``ts`` has no sites (only when assertions are
        enabled).
    """
    # NOTE(review): ``assert`` is stripped under ``python -O``; raise an
    # explicit exception if this check must always run.
    assert ts.num_sites != 0
    sd = tsinfer.SampleData(sequence_length=ts.sequence_length)
    # NOTE(review): this ``try`` also wraps the site loop, so a ValueError raised
    # by ``make_errors`` or ``sd.add_site`` is caught too and misread as
    # "error_param is a file path"; sites already added stay in ``sd``.
    # Narrowing the ``try`` to the ``float()`` call would avoid that.
    try:
        e = float(error_param)
        for v in ts.variants():
            # Compares the raw argument, so the string "0" does not take this
            # shortcut and is passed to ``make_errors`` with e == 0.0 instead.
            g = v.genotypes if error_param == 0 else make_errors(v.genotypes, e)
            sd.add_site(position=v.site.position, alleles=v.alleles, genotypes=g)
    except ValueError:
        error_matrix = pd.read_csv(error_param)
        # error_param is not a number, so it is treated as an error-matrix file.
        # First record the allele frequency (mean genotype; assumes 0/1 values).
        for v in ts.variants():
            m = v.genotypes.shape[0]
            frequency = np.sum(v.genotypes) / m
            # Find the closest row in the error matrix file: argsort the absolute
            # differences and keep the position of the smallest one.
            closest_row = (error_matrix["freq"] - frequency).abs().argsort()[:1]
            # Positional selection, so this is a one-row DataFrame.
            closest_freq = error_matrix.iloc[closest_row]
            g = make_errors_genotype_model(v.genotypes, closest_freq)
            # NOTE(review): the snippet is truncated here; adding the site to
            # ``sd`` and the function's return value are not shown.
GitHub: tskit-dev/tsinfer — convert_1kg.py (view on GitHub)
def convert(
        vcf_file, pedigree_file, output_file, max_variants=None, show_progress=False):
    """
    Convert a VCF file plus a pedigree file into a tsinfer sample data file.

    :param vcf_file: Path of the VCF. Its sample names define the individual
        order passed to ``add_samples``, and its sites are read through
        ``variants``.
    :param pedigree_file: Path of the pedigree file. It is opened as text and
        handed to ``add_samples``.
    :param output_file: Path of the sample data file to write.
    :param max_variants: Maximum number of sites to add (non-negative).
        ``None`` (the default) means no limit.
    :param show_progress: Forwarded to ``variants``.
    """
    import itertools

    with tsinfer.SampleData(path=output_file, num_flush_threads=2) as sample_data:
        pop_id_map = add_populations(sample_data)

        # The VCF is opened here only to get its sample names; the sites come
        # from ``variants(vcf_file, ...)`` below. Close it even if reading fails.
        vcf = cyvcf2.VCF(vcf_file)
        try:
            individual_names = list(vcf.samples)
        finally:
            vcf.close()

        with open(pedigree_file, "r") as ped_file:
            add_samples(ped_file, pop_id_map, individual_names, sample_data)

        # islice yields at most ``max_variants`` sites (all of them for None).
        # Checking the index only after ``add_site`` would write one extra site.
        limited_sites = itertools.islice(
            variants(vcf_file, show_progress), max_variants)
        for site in limited_sites:
            sample_data.add_site(
                position=site.position, genotypes=site.genotypes,
                alleles=site.alleles, metadata=site.metadata)
        sample_data.record_provenance(command=sys.argv[0], args=sys.argv[1:])