Training a DeepDive Model

This tutorial demonstrates how to train a DeepDive model on single-cell ATAC-seq data, monitor training progress, and perform reconstruction and counterfactual prediction.


Imports

!pip install scanpy seaborn DeepDive
[1]:
import scanpy as sc
import DeepDive
[2]:
from utils import reads_to_fragments

1. Load and preprocess the dataset

We start with an AnnData object containing single-cell chromatin accessibility profiles. Here we use a liver dataset (sciatac3_liver_10k.h5ad) containing 10k cells sampled from the sciATAC-seq3 dataset.

[3]:
adata = sc.read_h5ad('data/sciatac3_liver_10k.h5ad')
[4]:
min_cells = int(adata.shape[0] * 0.01)
sc.pp.filter_genes(adata, min_cells=min_cells)
[5]:
adata.obs_names_make_unique()
reads_to_fragments(adata)
adata.X = adata.layers['fragments']

  • Gene filtering: Remove features observed in fewer than 1% of cells.

  • Read conversion: Convert raw read counts into approximate fragment counts using the provided utility function reads_to_fragments.

At this point, adata contains a filtered, fragment-based representation suitable for training.
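
The two preprocessing steps above can be illustrated on a toy count matrix. This is a minimal NumPy sketch, not the code of sc.pp.filter_genes or of the tutorial's reads_to_fragments. In particular, halving the read counts and rounding up is an assumed conversion chosen for illustration, and the function names filter_features and approx_fragments are hypothetical.

```python
import numpy as np

def filter_features(counts, min_fraction=0.01):
    """Keep features (columns) detected in at least min_fraction of cells (rows)."""
    min_cells = int(counts.shape[0] * min_fraction)
    cells_per_feature = (counts > 0).sum(axis=0)
    keep = cells_per_feature >= min_cells
    return counts[:, keep], keep

def approx_fragments(reads):
    """Assumed conversion (not taken from the tutorial): halve reads, rounding up."""
    return np.ceil(reads / 2).astype(int)

# Toy matrix: 200 cells x 3 features, so the 1% threshold is 2 cells.
counts = np.zeros((200, 3), dtype=int)
counts[:50, 0] = 2   # feature 0: detected in 50 cells, kept
counts[0, 1] = 1     # feature 1: detected in 1 cell, removed
counts[:10, 2] = 3   # feature 2: detected in 10 cells, kept

filtered, keep = filter_features(counts)
fragments = approx_fragments(filtered)
print(filtered.shape, keep)
```

The filter only removes features; the number of cells is unchanged, which is why the 1% threshold is computed from adata.shape[0] before filtering.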

2. Define model and training parameters

DeepDive consists of a conditional variational autoencoder with adversarial disentanglement.

Model parameters

  • n_decoders: Number of decoders used for reconstruction.

  • n_epochs_pretrain_ae: Number of epochs to pretrain the autoencoder before adversarial training begins.

Training parameters

  • max_epoch: Maximum number of training epochs.

  • batch_size: Number of cells per training batch.

  • shuffle: Whether to shuffle the training dataset each epoch.

[6]:
n_decoders = 1
model_params = {
    'n_epochs_pretrain_ae' : 200*n_decoders,
    'n_decoders' : n_decoders,
}
train_params = {
    'max_epoch' : 300*n_decoders,
    'batch_size' : 1024,
    'shuffle' : True
}
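
As a quick sanity check on these settings, the number of batches processed per epoch follows from the dataset size and batch_size. The sketch below assumes that all 10,000 cells are used for training and that the final, smaller batch is kept; under those assumptions it gives the 10 batches per epoch that appear in the training progress bar in section 4. batches_per_epoch is a hypothetical helper, not part of DeepDive.

```python
import math

def batches_per_epoch(n_cells, batch_size):
    """Number of batches needed to see every cell once, keeping the last partial batch."""
    return math.ceil(n_cells / batch_size)

n_decoders = 1
train_params = {
    'max_epoch': 300 * n_decoders,
    'batch_size': 1024,
    'shuffle': True,
}

n_batches = batches_per_epoch(10_000, train_params['batch_size'])
total_batches = n_batches * train_params['max_epoch']
print(n_batches, total_batches)
```

Because both epoch counts are multiplied by n_decoders, increasing the number of decoders lengthens pretraining and total training proportionally.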

3. Specify covariates

DeepDive disentangles known covariates (both discrete and continuous) from biological variation. These are specified as column names in adata.obs.

[7]:
discrete_covriate_keys = ['sample_name', 'sex', 'batch', 'cell_type']
continuous_covriate_keys = ['day_of_pregnancy']
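
Because the covariates are looked up by column name, a misspelled or missing column is an easy mistake to make. A small check before constructing the model can catch it early. missing_covariates is a hypothetical helper written for illustration, not a DeepDive function, and the column list below is a stand-in for adata.obs.columns.

```python
def missing_covariates(obs_columns, discrete_keys, continuous_keys):
    """Return the requested covariate names that are absent from obs_columns, in order."""
    available = set(obs_columns)
    requested = [*discrete_keys, *continuous_keys]
    return [key for key in requested if key not in available]

# Stand-in for adata.obs.columns; with a real AnnData object, pass adata.obs.columns.
obs_columns = ['sample_name', 'sex', 'batch', 'cell_type', 'day_of_pregnancy']

missing = missing_covariates(
    obs_columns,
    ['sample_name', 'sex', 'batch', 'cell_type'],
    ['day_of_pregnancy'],
)
if missing:
    raise KeyError(f"Columns not found in adata.obs: {missing}")
```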

4. Initialize and train the model

The model is initialized with the AnnData object and covariate specifications, then trained using the defined parameters.

[8]:
model = DeepDive.DeepDive(adata = adata,
                        discrete_covariate_names = discrete_covriate_keys,
                        continuous_covariate_names = continuous_covriate_keys,
                        **model_params
                       )
[9]:
model.train_model(adata, None,
                  **train_params)
Epoch Train [300 / 300]: 100%|██████████| 10/10 [00:00<00:00, 10.08it/s, ETA=01d:00h:04:m22s|01d:00h:04:m22s, kl_loss=3.04, recon_loss=3.19e+3]

Training history (loss curves) can be visualized with:

[10]:
DeepDive.plot_history(model)
[Figure: training history (loss curves)]

5. Save and reload the trained model

Trained models can be saved to disk and reloaded for downstream analysis.

[11]:
model.save('model')
DeepDIVE model saved at: model
[12]:
model = DeepDive.DeepDive(adata = adata,
                        discrete_covariate_names = discrete_covriate_keys,
                        continuous_covariate_names = continuous_covriate_keys,
                        **model_params
                       )
model = model.load(adata, 'model')

6. Reconstruction and counterfactual prediction

DeepDive can reconstruct the input data from latent representations and predict counterfactuals by altering covariates.

Reconstruction (with all covariates):

[13]:
recon = model.predict(adata)

Counterfactual reconstruction (e.g., including only cell_type and omitting all other covariates):

[14]:
recon_subset = model.predict(adata, covars_to_add = ['cell_type'])
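
One way to picture the covars_to_add argument is as a selection of which covariate effects are put back on top of a covariate-free residual. The toy sketch below assumes that covariate effects enter as additive embeddings in latent space; this is an assumption made for illustration and is not confirmed by this tutorial as DeepDive's implementation. compose_latent and the random embeddings are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_latent = 4, 3

# Covariate-free part of the latent space (toy values).
residual = rng.normal(size=(n_cells, n_latent))

# One toy embedding per covariate, already looked up for each cell.
covariate_embeddings = {
    'cell_type': rng.normal(size=(n_cells, n_latent)),
    'sex': rng.normal(size=(n_cells, n_latent)),
}

def compose_latent(residual, embeddings, covars_to_add=None):
    """Add the selected covariate embeddings to the residual; None selects all."""
    names = list(embeddings) if covars_to_add is None else covars_to_add
    latent = residual.copy()
    for name in names:
        latent = latent + embeddings[name]
    return latent

full = compose_latent(residual, covariate_embeddings)                           # all covariates
only_cell_type = compose_latent(residual, covariate_embeddings, ['cell_type'])  # subset
residual_only = compose_latent(residual, covariate_embeddings, [])              # no covariates
```

In this picture, the difference between full and only_cell_type is exactly the contribution of the omitted covariates, and an empty list leaves only the residual.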

7. Getting latent representation

Latent representations from DeepDive can be computed using the get_latent method. To compute latent representations of only the residuals, an empty list can be passed to covars_to_add.

[15]:
residual = model.get_latent(adata, covars_to_add = [])
[16]:
sc.pp.neighbors(residual)
sc.tl.umap(residual)
[17]:
sc.pl.umap(residual, color = ['sample_name', 'cell_type'])
[Figure: UMAP of the residual latent representation, colored by sample_name and cell_type]

Summary

This notebook covers the complete workflow of:

  • Preparing a dataset for DeepDive

  • Defining model and training parameters

  • Specifying covariates to disentangle

  • Training and monitoring the model

  • Saving and reloading models

  • Performing reconstruction and counterfactual prediction
