Using PPC for model evaluation#

What is PPC?#

Posterior predictive checks (PPC) validate the goodness of fit of a generative model by computing metrics on both the raw counts and the reconstructed counts, then comparing the results. The reconstructed counts are samples from the posterior predictive distribution \(p(\hat{x} \mid x)\).

You can build a better intuition for it by reading more here and here.
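To make the idea concrete, here is a minimal, self-contained sketch of a posterior predictive check on toy data. It does not use scvi-tools or scvi-criticism: the "model" is a hypothetical per-gene Poisson whose rate is a plug-in estimate (the observed gene mean), standing in for a trained generative model, and all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "raw counts": 200 cells x 5 genes, overdispersed (negative binomial).
x = rng.negative_binomial(n=2, p=0.2, size=(200, 5))

# Hypothetical stand-in for a fitted model: one Poisson rate per gene,
# set to the observed gene mean (a plug-in estimate, not a real posterior).
rate = x.mean(axis=0)

# Draw replicated datasets x_hat, each with the same shape as x.
n_samples = 50
x_hat = rng.poisson(lam=rate, size=(n_samples, *x.shape))

# Check statistic: per-gene variance, on the raw data and on the replicates.
var_raw = x.var(axis=0)
var_rep = x_hat.var(axis=1).mean(axis=0)  # averaged over the replicates

print("raw variance:       ", np.round(var_raw, 1))
print("replicated variance:", np.round(var_rep, 1))
```

In this toy setup the Poisson stand-in cannot reproduce the variance of the overdispersed data, which is the kind of misfit a PPC is meant to reveal.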

Imports#

import scvi
import anndata
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
from scvi_criticism import run_ppc, PPC, PPCPlot

Get the data and model#

Here we use the data and pre-trained model obtained from running this scvi-tools tutorial.

The dataset used is a subset of the heart cell atlas dataset:
Litviňuková, M., Talavera-López, C., Maatz, H., Reichart, D., Worth, C. L., Lindberg, E. L., … & Teichmann, S. A. (2020). Cells of the adult human heart. Nature, 588(7838), 466-472.

If you have not yet pre-trained the model, uncomment and run the code below to train it and save it locally:

# adata = scvi.data.heart_cell_atlas_subsampled()
# sc.pp.filter_genes(adata, min_counts=3)
# adata.layers["counts"] = adata.X.copy()
# sc.pp.normalize_total(adata, target_sum=1e4)
# sc.pp.log1p(adata)
# adata.raw = adata
# sc.pp.highly_variable_genes(
#     adata,
#     n_top_genes=1200,
#     subset=True,
#     layer="counts",
#     flavor="seurat_v3",
#     batch_key="cell_source"
# )
# scvi.model.SCVI.setup_anndata(
#     adata,
#     layer="counts",
#     categorical_covariate_keys=["cell_source", "donor"],
#     continuous_covariate_keys=["percent_mito", "percent_ribo"]
# )
# model = scvi.model.SCVI(adata)
# model.train()
# model.save("local/hca/")

Run this to load the model:

model_path = "local/hca"
model = scvi.model.SCVI.load(model_path)
INFO     File local/hca/model.pt already downloaded                                                                
model
SCVI Model with the following params: 
n_hidden: 128, n_latent: 10, n_layers: 1, dropout_rate: 0.1, dispersion: gene, gene_likelihood: zinb, 
latent_distribution: normal
Training status: Trained

model.adata
AnnData object with n_obs × n_vars = 18641 × 1200
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used', '_scvi_batch', '_scvi_labels'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45', 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    uns: '_scvi_manager_uuid', '_scvi_uuid', 'cell_type_colors', 'hvg', 'log1p'
    obsm: '_scvi_extra_categorical_covs', '_scvi_extra_continuous_covs'
    layers: 'counts'

Overview of scvi-criticism#

scvi-criticism provides two classes: the PPC class and the PPCPlot class.

scvi-criticism.PPC#

The PPC class can be used to compute posterior predictive samples and compute various metrics on them.
The figure below gives an overview of the input/outputs of this class.

Currently, three metrics are implemented:

We will see each of these in more detail below where we compute and plot the output of each of them.

The scvi-criticism package also provides a convenience function, run_ppc, that computes a metric of your choice for a given model. It takes care of much of the boilerplate and calls the right PPC methods so you don’t have to.

def run_ppc(
    adata: AnnData,
    model,
    metric: str,
    n_samples: int,
    layer: Optional[str] = None,
    custom_indices: Optional[Union[int, Sequence[int]]] = None,
    **metric_specific_kwargs,
)

scvi-criticism.PPCPlot#

As its name indicates, the scvi-criticism.PPCPlot class can be used to draw various plots displaying metrics computed by the PPC class.
The PPCPlot class takes a single required argument which is the instance of PPC that holds your computed metrics.

We’ll see concrete examples of how to use PPCPlot below where we compute different metrics on our data.

PPC + coefficient of variation#

First, let’s run PPC and use the coefficient of variation as the metric.

Below is the code to do so, with step by step explanations. We’ll later see that we can skip most of this with a simple call to run_ppc.

# pick the indices you want to use. here we use all observations
indices = np.arange(model.adata.n_obs)

# get the raw (counts) data from .X or a layer if applicable
raw_data = model.adata[indices].layers["counts"]

# create PPC instance
n_samples = 5
ppc = PPC(n_samples=n_samples, raw_counts=raw_data)

# define your models dictionary. here we use only one model
model_name = f"{model.__class__.__name__}"
models_dict = {model_name: model}

# compute posterior predictive samples for your models
ppc.store_posterior_predictive_samples(models_dict, indices=indices)

# calculate the CV metric
ppc.coefficient_of_variation(cell_wise=True)

Instead of the steps above, you can simply call run_ppc, which does exactly the same thing.

n_samples = 5
ppc = run_ppc(model.adata, model, "cv_cell", n_samples=n_samples, layer="counts")
ppc
--- Posterior Predictive Checks ---
n_samples = 5
raw_counts shape = (18641, 1200)
models: ['SCVI']
metrics: 
{
    "cv_cell": "Pandas DataFrame with shape=(18641, 2), columns=['SCVI', 'Raw']"
}

Let’s see what the cv_cell metric contains:

ppc.metrics["cv_cell"]
SCVI Raw
0 6.510572 9.820683
1 7.702765 8.686443
2 5.373195 4.587027
3 6.202974 6.262758
4 4.776725 5.975578
... ... ...
18636 6.560287 8.613652
18637 6.063679 5.561197
18638 9.950002 10.452272
18639 6.241374 7.699229
18640 7.458881 7.511846

18641 rows × 2 columns

This is a pandas DataFrame where each row is a cell. The “Raw” column contains the cell-wise coefficient of variation computed on the raw data. The “SCVI” column contains the same computed on the posterior predictive samples (averaged over the N posterior predictive samples). If our model fits the data well, according to this metric, we’d ideally want these values to be “similar”.
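The sketch below illustrates this computation on random toy data. It assumes the coefficient of variation is the standard deviation divided by the mean, computed for each cell across its genes and then averaged over the samples. The actual scvi-criticism implementation may differ in details (for example, the degrees of freedom used for the standard deviation), and the helper name is hypothetical.

```python
import numpy as np

def cell_wise_cv(counts: np.ndarray) -> np.ndarray:
    """Hypothetical helper: CV (std / mean) of each cell across its genes.

    counts has shape (..., n_cells, n_genes); the genes axis is reduced.
    """
    return counts.std(axis=-1) / counts.mean(axis=-1)

rng = np.random.default_rng(0)
n_samples, n_cells, n_genes = 5, 4, 100

# Toy raw counts and toy posterior predictive samples (random, not from a model).
raw = rng.poisson(lam=3.0, size=(n_cells, n_genes))
samples = rng.poisson(lam=3.0, size=(n_samples, n_cells, n_genes))

cv_raw = cell_wise_cv(raw)                     # shape (n_cells,)
cv_model = cell_wise_cv(samples).mean(axis=0)  # averaged over the samples

# One row per cell; columns: model, raw.
print(np.column_stack([cv_model, cv_raw]))
```

Under the same assumption, reducing over the cells axis instead of the genes axis would give the gene-wise variant.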

Let’s evaluate this similarity by:

  • plotting a scatterplot that displays these values, and comparing it (visually) to the identity line

  • computing the correlation of these vectors

This is where we can use the PPCPlot class.

ppc_plt = PPCPlot(ppc)
ppc_plt.plot_cv(model_name="SCVI", cell_wise=True)
INFO     model=SCVI | metric=cv_cell | n_cells=18641:                                                              
         Mean Absolute Error=0.87,                                                                                 
         Mean Squared Error=1.62                                                                                   
         Pearson correlation=0.80                                                                                  
         Spearman correlation=0.82                                                                                 
                                                                                                                   
[Figure: scatterplot of the cell-wise coefficient of variation, posterior predictive samples (x axis) vs. raw data (y axis)]

On the scatterplot above, each dot is a cell. The y axis represents the cell-wise coefficient of variation values computed for each cell on the raw data. The x axis represents the cell-wise coefficient of variation values computed for each cell on the posterior predictive samples (averaged over the N samples).

We also compute a few agreement metrics between the two vectors: the mean absolute and mean squared errors, and the Pearson and Spearman correlations.
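For reference, these four summary statistics can be computed directly from the two columns of the metric DataFrame. The sketch below uses only NumPy and the first five rows of the table shown earlier; the helper name is hypothetical, and the simple ranking used for the Spearman correlation assumes there are no ties.

```python
import numpy as np

def agreement_metrics(approx: np.ndarray, raw: np.ndarray) -> dict:
    """Hypothetical helper: MAE, MSE, Pearson and Spearman between two vectors."""

    def ranks(a: np.ndarray) -> np.ndarray:
        # Rank of each element (0 = smallest); valid when there are no ties.
        return np.argsort(np.argsort(a))

    diff = approx - raw
    return {
        "mae": float(np.mean(np.abs(diff))),
        "mse": float(np.mean(diff**2)),
        "pearson": float(np.corrcoef(approx, raw)[0, 1]),
        # Spearman is the Pearson correlation of the ranks.
        "spearman": float(np.corrcoef(ranks(approx), ranks(raw))[0, 1]),
    }

# First five rows of the cv_cell table (SCVI and Raw columns).
scvi_cv = np.array([6.510572, 7.702765, 5.373195, 6.202974, 4.776725])
raw_cv = np.array([9.820683, 8.686443, 4.587027, 6.262758, 5.975578])

print(agreement_metrics(scvi_cv, raw_cv))
```

Because only five cells are used here, the values will not match the ones logged for the full dataset.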

Let’s do the same but gene-wise:

ppc = run_ppc(model.adata, model, "cv_gene", n_samples=n_samples, layer="counts")
ppc
--- Posterior Predictive Checks ---
n_samples = 5
raw_counts shape = (18641, 1200)
models: ['SCVI']
metrics: 
{
    "cv_gene": "Pandas DataFrame with shape=(1200, 2), columns=['SCVI', 'Raw']"
}
ppc_plt = PPCPlot(ppc)
ppc_plt.plot_cv(model_name="SCVI", cell_wise=False)
INFO     model=SCVI | metric=cv_gene | n_cells=18641:                                                              
         Mean Absolute Error=1.36,                                                                                 
         Mean Squared Error=14.29                                                                                  
         Pearson correlation=0.95                                                                                  
         Spearman correlation=0.99                                                                                 
                                                                                                                   
[Figure: scatterplot of the gene-wise coefficient of variation, posterior predictive samples vs. raw data]

This time each dot on the scatterplot is a gene. As expected, there are fewer of them in this case.

PPC + DE#

In this section, we evaluate the goodness of fit of the model by running differential expression (DE) on the raw data, then on the reconstructed counts, and comparing the two sets of DE results.
Once again, we can use run_ppc to compute the metric for us in one call:

n_samples = 1
ppc = run_ppc(model.adata, model, "diff_exp", n_samples=n_samples, layer="counts", de_groupby="cell_type")
ppc
--- Posterior Predictive Checks ---
n_samples = 1
raw_counts shape = (18641, 1200)
models: ['SCVI']
metrics: 
{
    "diff_exp": {
        "adata_raw": "AnnData object with n_obs=18641, n_vars=1200",
        "var_names": {
            "Adipocytes": [
                "GPAM",
                "MGST1"
            ],
            "Atrial_Cardiomyocyte": [
                "MYL7",
                "RYR2"
            ],
            "Endothelial": [
                "VWF",
                "B2M"
            ],
            "Fibroblast": [
                "DCN",
                "NEGR1"
            ],
            "Lymphoid": [
                "PTPRC",
                "B2M"
            ],
            "Mesothelial": [
                "ITLN1",
                "PLA2G2A"
            ],
            "Myeloid": [
                "CD163",
                "CTSB"
            ],
            "Neuronal": [
                "NRXN1",
                "CDH19"
            ],
            "Pericytes": [
                "RGS5",
                "PLA2G5"
            ],
            "Smooth_muscle_cells": [
                "ACTA2",
                "MYH11"
            ],
            "Ventricular_Cardiomyocyte": [
                "RYR2",
                "SORBS2"
            ]
        },
        "SCVI": {
            "adata_approx": "AnnData object with n_obs=18641, n_vars=1200",
            "lfc_df_approx": "Pandas DataFrame with shape=(11, 22), first 5 columns=['GPAM', 'MGST1', 'MYL7', 'RYR2', 'VWF']",
            "lfc_mae": "Pandas Series with n_rows=11",
            "lfc_mae_mean": 0.7108566101240225,
            "lfc_pearson": "Pandas Series with n_rows=11",
            "lfc_pearson_mean": 0.9310277325971511,
            "lfc_spearman": "Pandas Series with n_rows=11",
            "lfc_spearman_mean": 0.9382091592617907,
            "fraction_df_approx": "Pandas DataFrame with shape=(11, 22), first 5 columns=['GPAM', 'MGST1', 'MYL7', 'RYR2', 'VWF']",
            "fraction_mae": "Pandas Series with n_rows=11",
            "fraction_mae_mean": 0.02348371559580257,
            "fraction_pearson": "Pandas Series with n_rows=11",
            "fraction_pearson_mean": 0.9893390784300053,
            "fraction_spearman": "Pandas Series with n_rows=11",
            "fraction_spearman_mean": 0.9751844878784215,
            "gene_comparisons": "Pandas DataFrame with shape=(11, 3), columns=['precision', 'recall', 'f1']"
        },
        "lfc_df_raw": "Pandas DataFrame with shape=(11, 22), first 5 columns=['GPAM', 'MGST1', 'MYL7', 'RYR2', 'VWF']",
        "fraction_df_raw": "Pandas DataFrame with shape=(11, 22), first 5 columns=['GPAM', 'MGST1', 'MYL7', 'RYR2', 'VWF']"
    }
}

There is quite a lot of information in the metrics this time, so let’s take a closer look:

  • ppc.metrics["diff_exp"]["adata_raw"] is simply the raw counts data.

  • ppc.metrics["diff_exp"]["var_names"] is a dictionary where keys are DE groups (for example cell types), and values are arrays containing the N top differentially expressed genes in that group. N can be specified via n_top_genes in the call to run_ppc and defaults to 2.

  • ppc.metrics["diff_exp"]["lfc_df_raw"] is a pandas DataFrame where rows are groups, and columns are the N top genes for all groups, i.e., n_cols = n_top_genes_per_group * n_groups.

    Each DataFrame cell (group,gene) contains: log2(mean_gene_expression_in_group / mean_gene_expression_not_in_group) for that gene, where:

    • mean_gene_expression_in_group is the average gene expression of that gene for all cells in that group

    • mean_gene_expression_not_in_group is the average gene expression of that gene for all other cells (not in that group)

    This is also called “1 vs all” logfoldchange (LFC).

  • ppc.metrics["diff_exp"]["fraction_df_raw"] is a pandas DataFrame with the same rows and columns as ppc.metrics["diff_exp"]["lfc_df_raw"]. Each DataFrame cell (group,gene) contains the fraction of cells in that group that express that gene.

  • ppc.metrics["diff_exp"]["SCVI"] is a dictionary that contains the outputs of computing different metrics that are specific to the model called “SCVI”.

    • ppc.metrics["diff_exp"]["SCVI"]["adata_approx"] is the reconstructed count data.

    • ppc.metrics["diff_exp"]["SCVI"]["lfc_df_approx"] is the same as lfc_df_raw but for the estimated count data.

    • ppc.metrics["diff_exp"]["SCVI"]["lfc_mae"] is the row-wise mean absolute error between lfc_df_raw and lfc_df_approx.

    • ppc.metrics["diff_exp"]["SCVI"]["lfc_mae_mean"] is the mean of lfc_mae.

    • ppc.metrics["diff_exp"]["SCVI"]["lfc_pearson"] is the row-wise Pearson correlation between lfc_df_raw and lfc_df_approx.

    • ppc.metrics["diff_exp"]["SCVI"]["lfc_pearson_mean"] is the mean of lfc_pearson.

    • ppc.metrics["diff_exp"]["SCVI"]["lfc_spearman"] is the row-wise Spearman correlation between lfc_df_raw and lfc_df_approx.

    • ppc.metrics["diff_exp"]["SCVI"]["lfc_spearman_mean"] is the mean of lfc_spearman.

    • ppc.metrics["diff_exp"]["SCVI"]["fraction_df_approx"] is the same as fraction_df_raw but for the estimated count data.

    • ppc.metrics["diff_exp"]["SCVI"]["fraction_mae"] is the same as lfc_mae but for gene expression fractions instead of lfc.

    • ppc.metrics["diff_exp"]["SCVI"]["fraction_mae_mean"] is the same as lfc_mae_mean but for gene expression fractions.

    • ppc.metrics["diff_exp"]["SCVI"]["fraction_pearson"] is the same as lfc_pearson but for gene expression fractions.

    • ppc.metrics["diff_exp"]["SCVI"]["fraction_pearson_mean"] is the same as lfc_pearson_mean but for gene expression fractions.

    • ppc.metrics["diff_exp"]["SCVI"]["fraction_spearman"] is the same as lfc_spearman but for gene expression fractions.

    • ppc.metrics["diff_exp"]["SCVI"]["fraction_spearman_mean"] is the same as lfc_spearman_mean but for gene expression fractions.

    • ppc.metrics["diff_exp"]["SCVI"]["gene_comparisons"] is a pandas DataFrame where rows are groups. Each row reports a precision, recall and F1 score that represents the overlap between the top N differentially expressed genes in that group on the raw data and the top N differentially expressed genes in that group on the estimated count data. The N can be specified via n_top_genes_overlap in the call to run_ppc and defaults to: min(adata_raw.n_vars, DEFAULT_DE_N_TOP_GENES_OVERLAP).

      This is how this overlap is calculated: for each group, we consider the unordered set of top ranked genes between raw and approx DE results. To do that we “binarize” the gene selections: we create two binary vectors (one for raw, one for approx) where a 1 in the vector means gene was selected. We then compute the precision, recall and F1 score between these two vectors. Below is a depiction of those notions in this context:
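The binarization step can be sketched as follows. This is an illustrative NumPy version with made-up gene names; treating the raw selection as the ground truth and the approx selection as the prediction is an assumption, and the helper is hypothetical rather than part of scvi-criticism.

```python
import numpy as np

def overlap_scores(all_genes, top_raw, top_approx):
    """Hypothetical helper: precision, recall and F1 between two gene selections."""
    # Binarize: True where the gene was selected, False otherwise.
    raw = np.isin(all_genes, top_raw)        # treated as ground truth (assumption)
    approx = np.isin(all_genes, top_approx)  # treated as prediction (assumption)

    tp = np.sum(raw & approx)   # selected in both
    fp = np.sum(~raw & approx)  # selected only on the approx data
    fn = np.sum(raw & ~approx)  # selected only on the raw data

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return float(precision), float(recall), float(f1)

all_genes = np.array(["g1", "g2", "g3", "g4", "g5", "g6"])
scores = overlap_scores(
    all_genes,
    top_raw=["g1", "g2", "g3"],
    top_approx=["g2", "g3", "g4"],
)
print(scores)  # two of the three selected genes are shared
```

When both selections contain the same number of genes, as in this example, precision and recall coincide, so the F1 score equals both.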

That’s a lot of text. Let’s visualize some of these metrics.

First, let’s look at the “1 vs all” LFC dotplot for the raw and approximated count data:

ppc_plt = PPCPlot(ppc)
ppc_plt.plot_diff_exp("SCVI")
[Figures: “1 vs all” LFC dotplots for the raw and the approximated count data]
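As a reminder of the quantities being compared, the sketch below recomputes the “1 vs all” LFC and the fraction of expressing cells for a single gene on toy data. It is plain NumPy, not the scvi-criticism implementation: the function names are hypothetical, it assumes a cell “expresses” a gene when its count is greater than zero, and it omits the small pseudocount that DE tools typically add to avoid dividing by zero.

```python
import numpy as np

def one_vs_all_lfc(expr: np.ndarray, in_group: np.ndarray) -> float:
    """log2(mean expression in the group / mean expression in all other cells)."""
    return float(np.log2(expr[in_group].mean() / expr[~in_group].mean()))

def fraction_expressing(expr: np.ndarray, in_group: np.ndarray) -> float:
    """Fraction of cells in the group with a non-zero count (assumed threshold)."""
    return float((expr[in_group] > 0).mean())

# Toy data: counts of one gene in six cells, and the group of each cell.
expr = np.array([8, 0, 4, 1, 0, 2])
groups = np.array(["A", "A", "A", "B", "B", "B"])
in_a = groups == "A"

lfc_a = one_vs_all_lfc(expr, in_a)        # log2(4 / 1) = 2.0
frac_a = fraction_expressing(expr, in_a)  # 2 of the 3 cells in group A
print(lfc_a, frac_a)
```

Repeating this for every (group, gene) pair would fill one DataFrame of LFC values and one of fractions, with groups as rows and genes as columns.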

We can also look at a subset of the genes:

ppc_plt.plot_diff_exp("SCVI", var_names_subset=["Ventricular_Cardiomyocyte"])
[Figures: “1 vs all” LFC dotplots for the raw and the approximated count data, restricted to the Ventricular_Cardiomyocyte genes]

These all look very similar, but that is hard to judge visually. Let’s look at the mean absolute error (MAE) and the Pearson and Spearman correlations between the two:

ppc_plt.plot_diff_exp("SCVI", plot_kind="lfc_comparisons")
INFO     LFC (1 vs all) gene expressions across groups:                                                            
         Mean Absolute Error=0.71,                                                                                 
         Pearson correlation=0.93                                                                                  
         Spearman correlation=0.94                                                                                 
[Figure: comparison of “1 vs all” LFC values between the raw and the approximated count data]
ppc_plt.plot_diff_exp("SCVI", plot_kind="fraction_comparisons")
INFO     fractions of genes expressed per group across groups:                                                     
         Mean Absolute Error=0.02,                                                                                 
         Pearson correlation=0.99                                                                                  
         Spearman correlation=0.98                                                                                 
[Figure: comparison of gene expression fractions between the raw and the approximated count data]

That looks pretty good.

Next, let’s look at the F1 score for the top N genes that are differentially expressed in each group:

ppc_plt.plot_diff_exp("SCVI", plot_kind="gene_overlaps")
[Figure: F1 scores for the overlap of the top N differentially expressed genes in each group]

Last but not least, let’s plot a summary. In the box plot below, dots are groups and the x axis represents the various metrics we computed and plotted separately earlier. The box plot shows the distribution of the metrics across groups.

ppc_plt.plot_diff_exp("SCVI", plot_kind="summary")
[Figure: box plot summarizing the distribution of each metric across groups]