How to use the anndata.AnnData class in anndata

To help you get started, we’ve selected a few anndata examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github theislab / diffxpy / diffxpy / unit_test / test_data_types.py View on Github external
def _test_anndata_raw(self, sparse):
    """Run the test suite against an AnnData object that has ``.raw`` set.

    Simulates a data matrix with ``self.simulate()``, optionally converts it
    to a ``scipy.sparse.csr_matrix``, wraps it in ``anndata.AnnData`` with
    gene names ``gene0, gene1, ...`` and assigns the object to its own
    ``.raw`` attribute. The Wald and likelihood-ratio tests are then run on
    ``.raw``; the t-test and rank test are run on the AnnData object itself.

    :param sparse: if truthy, store the simulated matrix in CSR sparse format.
    """
    counts, sample_description = self.simulate()
    n_genes = counts.shape[1]
    if sparse:
        counts = scipy.sparse.csr_matrix(counts)

    adata = anndata.AnnData(counts)
    adata.var_names = [f"gene{idx}" for idx in range(n_genes)]
    # Assign the object to its own `.raw` so the `.raw` input path is exercised.
    adata.raw = adata

    # These two tests receive the `.raw` view ...
    self._test_wald(data=adata.raw, sample_description=sample_description)
    self._test_lrt(data=adata.raw, sample_description=sample_description)
    # ... while these two receive the AnnData object directly.
    self._test_t_test(data=adata, sample_description=sample_description)
    self._test_rank(data=adata, sample_description=sample_description)
github theislab / scgen / tests / test_plotting.py View on Github external
def test_reg_mean_plot():
    """Smoke-test ``scgen.plotting.reg_mean_plot`` on ``VAEArith`` predictions.

    Reads the training set (``backup_url`` is passed to ``sc.read``), builds a
    ``VAEArith`` network and calls ``train`` with ``n_epochs=0``. It then
    predicts the response of control CD4T cells and concatenates the
    predictions with all CD4T cells. Finally it writes four regression-mean
    plots: with and without a highlighted gene list, and with and without the
    extra "stimulated" axis.

    ``network.sess`` is closed in a ``finally`` block. The session is
    therefore released even when training, prediction or plotting raises.
    """
    train = sc.read("./tests/data/train.h5ad", backup_url="https://goo.gl/33HtVh")
    network = scgen.VAEArith(x_dimension=train.shape[1], model_path="../models/test")
    try:
        network.train(train_data=train, n_epochs=0)
        unperturbed_data = train[((train.obs["cell_type"] == "CD4T") & (train.obs["condition"] == "control"))]
        condition = {"ctrl": "control", "stim": "stimulated"}
        pred, delta = network.predict(adata=train, adata_to_predict=unperturbed_data, conditions=condition,
                                      condition_key="condition", cell_type_key="cell_type")
        pred_adata = anndata.AnnData(pred, obs={"condition": ["pred"] * len(pred)},
                                     var={"var_names": train.var_names})
        cd4t = train[train.obs["cell_type"] == "CD4T"]
        all_adata = cd4t.concatenate(pred_adata)

        two_axes = {"x": "control", "y": "pred"}
        three_axes = {"x": "control", "y": "pred", "y1": "stimulated"}
        genes = ["ISG15", "CD3D"]
        # (axis_keys, extra keyword arguments) for each plot, in the original order.
        plot_configs = [
            (two_axes, {"path_to_save": "tests/reg_mean1.pdf"}),
            (two_axes, {"path_to_save": "tests/reg_mean2.pdf", "gene_list": genes}),
            (three_axes, {"path_to_save": "tests/reg_mean3.pdf"}),
            (three_axes, {"gene_list": genes, "path_to_save": "tests/reg_mean.pdf"}),
        ]
        for axis_keys, extra_kwargs in plot_configs:
            scgen.plotting.reg_mean_plot(all_adata, condition_key="condition", axis_keys=axis_keys,
                                         **extra_kwargs)
    finally:
        # Previously only reached on success; any earlier exception leaked the session.
        network.sess.close()
github theislab / trVAE / tests / test_trVAEMulti.py View on Github external
if adata_source.shape[0] == 0:
        adata_source = pred_adatas.copy()[pred_adatas.obs[condition_key] == source_condition]

    if adata_target.shape[0] == 0:
        adata_target = pred_adatas.copy()[pred_adatas.obs[condition_key] == target_condition]

    source_labels = np.zeros(adata_source.shape[0]) + source_label
    target_labels = np.zeros(adata_source.shape[0]) + target_label

    pred_target = network.predict(adata_source,
                                  encoder_labels=source_labels,
                                  decoder_labels=target_labels,
                                  size_factor=adata_source.obs['size_factors'].values
                                  )

    pred_adata = anndata.AnnData(X=pred_target)
    pred_adata.obs[condition_key] = [name] * pred_target.shape[0]
    pred_adata.var_names = adata.var_names

    if sparse.issparse(adata_source.X):
        adata_source.X = adata_source.X.A

    if sparse.issparse(adata_target.X):
        adata_target.X = adata_target.X.A

    if sparse.issparse(pred_adata.X):
        pred_adata.X = pred_adata.X.A

    # adata_to_plot = pred_adata.concatenate(adata_target)

    # trvae.plotting.reg_mean_plot(adata_to_plot,
    #                              top_100_genes=top_100_genes,
github theislab / scgen / tests / test_mmd_cvae.py View on Github external
show=False)

        decoded_latent_with_true_labels = network.predict(data=latent_with_true_labels, encoder_labels=true_labels,
                                                          decoder_labels=true_labels, data_space='latent')

        cell_type_data = train[train.obs[cell_type_key] == cell_type]
        unperturbed_data = train[((train.obs[cell_type_key] == cell_type) & (train.obs[condition_key] == ctrl_key))]
        true_labels = np.zeros((len(unperturbed_data), 1))
        fake_labels = np.ones((len(unperturbed_data), 1))

        sc.tl.rank_genes_groups(cell_type_data, groupby=condition_key, n_genes=100)
        diff_genes = cell_type_data.uns["rank_genes_groups"]["names"][stim_key]
        # cell_type_data = cell_type_data.copy()[:, diff_genes.tolist()]

        pred = network.predict(data=unperturbed_data, encoder_labels=true_labels, decoder_labels=fake_labels)
        pred_adata = anndata.AnnData(pred, obs={condition_key: ["pred"] * len(pred)},
                                     var={"var_names": cell_type_data.var_names})
        all_adata = cell_type_data.concatenate(pred_adata)

        scgen.plotting.reg_mean_plot(all_adata, condition_key=condition_key,
                                     axis_keys={"x": ctrl_key, "y": stim_key, "y1": "pred"},
                                     gene_list=diff_genes,
                                     path_to_save=f"./figures/reg_mean_{z_dim}.pdf")
        scgen.plotting.reg_var_plot(all_adata, condition_key=condition_key,
                                    axis_keys={"x": ctrl_key, "y": stim_key, 'y1': "pred"},
                                    gene_list=diff_genes,
                                    path_to_save=f"./figures/reg_var_{z_dim}.pdf")

        sc.pp.neighbors(all_adata)
        sc.tl.umap(all_adata)
        sc.pl.umap(all_adata, color=condition_key,
                   save="pred")
github KrishnaswamyLab / graphtools / graphtools / base.py View on Github external
self._check_data(data)
        n_pca, rank_threshold = self._parse_n_pca_threshold(data, n_pca, rank_threshold)
        try:
            if isinstance(data, pd.SparseDataFrame):
                data = data.to_coo()
            elif isinstance(data, pd.DataFrame):
                try:
                    data = data.sparse.to_coo()
                except AttributeError:
                    data = np.array(data)
        except NameError:
            # pandas not installed
            pass

        try:
            if isinstance(data, anndata.AnnData):
                data = data.X
        except NameError:
            # anndata not installed
            pass
        self.data = data
        self.n_pca = n_pca
        self.rank_threshold = rank_threshold
        self.random_state = random_state
        self.data_nu = self._reduce_data()
        super().__init__(**kwargs)
github theislab / dca / dca / api.py View on Github external
If `return_info` is true, all estimated distribution parameters are stored in AnnData such as:

    - `.obsm["X_dca_dropout"]` which is the mixture coefficient (pi) of the zero component
    in ZINB, i.e. dropout probability. (Only if ae_type is zinb or zinb-conddisp)

    - `.obsm["X_dca_dispersion"]` which is the dispersion parameter of NB.

    - `.uns["dca_loss_history"]` which stores the loss history of the training.

    Finally, the raw counts are stored as `.raw`.

    If `return_model` is given, trained model is returned. When both `copy` and `return_model`
    are true, a tuple of anndata and model is returned in that order.
    """

    assert isinstance(adata, anndata.AnnData), 'adata must be an AnnData instance'
    assert mode in ('denoise', 'latent'), '%s is not a valid mode.' % mode

    # set seed for reproducibility
    random.seed(random_state)
    np.random.seed(random_state)
    tf.set_random_seed(random_state)
    os.environ['PYTHONHASHSEED'] = '0'

    # this creates adata.raw with raw counts and copies adata if copy==True
    adata = read_dataset(adata,
                         transpose=False,
                         test_split=False,
                         copy=copy)

    # check for zero genes
    nonzero_genes, _ = sc.pp.filter_genes(adata.X, min_counts=1)
github soedinglab / prosstt / examples / generate_simN.py View on Github external
def plot_diff_map(X, pseudotime, brns):
    """Plot the first two diffusion-map components, coloured by branch and by pseudotime.

    Wraps ``X`` in an AnnData object and runs ``diffmap`` on it. The embedding
    is read back from ``.obsm["X_diffmap"]``. Two side-by-side scatter plots
    of components 0 and 1 are drawn and shown with ``plt.show()``:

    * left: one colour per branch label in ``brns`` (matplotlib's Set1 palette);
    * right: ``pseudotime`` mapped through the ``viridis`` colormap.

    Parameters
    ----------
    X : array-like
        Matrix passed to ``ad.AnnData`` (presumably cells x genes; confirm with callers).
    pseudotime : array-like
        One value per row of ``X``; used as the colour of the right-hand plot.
    brns : array-like
        Branch label per row of ``X``; any labels accepted by ``np.unique``.
        If there are more distinct branches than palette colours, the
        colours are reused cyclically instead of raising ``IndexError``.
    """
    data = ad.AnnData(X)
    diffmap(adata=data)
    diff_map = data.obsm["X_diffmap"]

    palette = np.array(list(pylab.cm.Set1.colors))
    _, branch_idx = np.unique(brns, return_inverse=True)
    # Set1 has a fixed, small number of colours; wrap around instead of
    # indexing past the end of the palette when there are many branches.
    branch_colors = palette[np.ravel(branch_idx) % len(palette)]

    fig, ax = plt.subplots(ncols=2)
    fig.set_size_inches(w=9, h=4)
    ax[0].scatter(diff_map[:, 0], diff_map[:, 1], c=branch_colors)
    ax[0].set_title("branches")
    ax[1].scatter(diff_map[:, 0], diff_map[:, 1], c=pseudotime, cmap="viridis")
    ax[1].set_title("pseudotime")
    plt.show()
github brianhie / geosketch / geosketch / sketch.py View on Github external
def louvain(X, N, resolution=1, seed=None, replace=False):
    from anndata import AnnData
    import scanpy.api as sc

    adata = AnnData(X=X)
    sc.pp.neighbors(adata, use_rep='X')
    sc.tl.louvain(adata, resolution=resolution, key_added='louvain')
    cluster_labels_full = adata.obs['louvain'].tolist()

    louv = {}
    for i, cluster in enumerate(cluster_labels_full):
        if cluster not in louv:
            louv[cluster] = []
        louv[cluster].append(i)

    lv_idx = []
    for n in range(N):
        louv_cells = list(louv.keys())
        louv_cell = louv_cells[np.random.choice(len(louv_cells))]
        samples = list(louv[louv_cell])
        sample = samples[np.random.choice(len(samples))]
github theislab / scgen / scgen / models / util.py View on Github external
corrected = anndata.AnnData(network.reconstruct(all_shared_ann.X, use_data=True))
        corrected.obs = all_shared_ann.obs.copy(deep=True)
        corrected.var_names = adata.var_names.tolist()
        corrected = corrected[adata.obs_names]
        if adata.raw is not None:
            adata_raw = anndata.AnnData(X=adata.raw.X, var=adata.raw.var)
            adata_raw.obs_names = adata.obs_names
            corrected.raw = adata_raw
        corrected.obsm["latent"] = all_shared_ann.X
        return corrected
    else:
        all_not_shared_ann = anndata.AnnData.concatenate(*not_shared_ct, batch_key="concat_batch", index_unique=None)
        all_corrected_data = anndata.AnnData.concatenate(all_shared_ann, all_not_shared_ann, batch_key="concat_batch", index_unique=None)
        if "concat_batch" in all_shared_ann.obs.columns:
            del all_corrected_data.obs["concat_batch"]
        corrected = anndata.AnnData(network.reconstruct(all_corrected_data.X, use_data=True))
        corrected.obs = pd.concat([all_shared_ann.obs, all_not_shared_ann.obs])
        corrected.var_names = adata.var_names.tolist()
        corrected = corrected[adata.obs_names]
        if adata.raw is not None:
            adata_raw = anndata.AnnData(X=adata.raw.X, var=adata.raw.var)
            adata_raw.obs_names = adata.obs_names
            corrected.raw = adata_raw
        corrected.obsm["latent"] = all_corrected_data.X
        return corrected
github theislab / trVAE / trvae / models / _trvae.py View on Github external
)
            encoder_labels, _ = trvae.utils.label_encoder(train_adata, condition_key="condition")
            decoder_labels, _ = trvae.utils.label_encoder(train_adata, condition_key="condition")
            pred_adata = network.predict(train_adata, encoder_labels, decoder_labels)
            ```
        """
        adata = remove_sparsity(adata)

        encoder_labels = to_categorical(encoder_labels, num_classes=self.n_conditions)
        decoder_labels = to_categorical(decoder_labels, num_classes=self.n_conditions)

        reconstructed = self.trvae_model.predict([adata.X, encoder_labels, decoder_labels])[0]
        reconstructed = np.nan_to_num(reconstructed)

        if return_adata:
            output = anndata.AnnData(X=reconstructed)
            output.obs = adata.obs.copy(deep=True)
            output.var_names = adata.var_names
        else:
            output = reconstructed

        return output