4. Align new slices to the ABA brain atlas

Finally, we align a new Visium slice into the existing 3D ST brain atlas. We first identify the atlas slices closest to the new slice based on their semantic distances, and then use attention-weighted averaging to predict the z-coordinate of the new slice. The atlas slice with the closest z-coordinate is then selected as a spatial template to scale and align the x- and y-axes. At this stage, the 3D coordinates of every spot in the new slice have been obtained, allowing annotated atlas information (e.g., standard anatomical region labels) to be accurately propagated to the newly integrated slice.

Import packages

[1]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt

from STAIR.emb_alignment import Emb_Align
from STAIR.loc_alignment import Loc_Align
from STAIR.utils import *

Load data

[2]:
# Read the two input datasets from disk.
# NOTE(review): judging by the file names, adata1 holds the ST atlas and
# adata2 the 10x slice -- confirm.
adata1 = sc.read('./data/brain_ST_all.h5ad')
adata2 = sc.read('./data/brain_10X.h5ad')

# Make the observation and variable names unique within each dataset.
for _dataset in (adata1, adata2):
    _dataset.obs_names_make_unique()
    _dataset.var_names_make_unique()
[3]:
# Drop the spots of adata1 whose 'ABA_id' is 0 or 997.
# NOTE(review): presumably these ids mark unannotated / root regions -- confirm.
adata1 = adata1[~adata1.obs['ABA_id'].isin([0, 997])].copy()

# Prefix the section ids of adata1 with 'ST_' and store them as a categorical.
adata1.obs['section_index'] = ('ST_' + adata1.obs['section_index'].astype('str')).astype('category')

# Stack both datasets along the observation axis.
adata = ad.concat([adata1, adata2], join='outer')

# Keep only the genes present in both datasets. The genes are taken in the
# order of adata1.var_names: going through a plain set() would yield a column
# order that can change from one Python session to the next.
genes_in_adata2 = set(adata2.var_names)
gene_use = [gene for gene in adata1.var_names if gene in genes_in_adata2]
adata = adata[:, gene_use].copy()

# Observations with a missing 'section_index' are labelled 'Visium'; the
# category has to exist before fillna() can use it.
adata.obs['section_index'] = adata.obs['section_index'].cat.add_categories(['Visium'])
adata.obs['section_index'] = adata.obs['section_index'].fillna('Visium')

# Expose the stereo_* coordinate columns as 2D (ML, DV) and 3D (ML, DV, AP) arrays.
adata.obsm['spatial_ccf_2d'] = adata.obs[['stereo_ML', 'stereo_DV']].values
adata.obsm['spatial_ccf_3d'] = adata.obs[['stereo_ML', 'stereo_DV', 'stereo_AP']].values

# Sample label derived from 'animal'; a missing animal (the string 'nan' after
# astype(str)) is mapped to 'Visium'.
adata.obs['sample'] = adata.obs['animal'].astype(str).replace({'A1':'ST_A1', 'A2':'ST_A2', 'A3':'ST_A3', 'nan':'Visium'})

adata
[3]:
AnnData object with n_obs × n_vars = 35141 × 14182
    obs: 'section_index', 'stereo_ML', 'stereo_DV', 'stereo_AP', 'HE_X', 'HE_Y', 'ABA_acronym', 'ABA_name', 'ABA_parent', 'nuclei_segmented', 'spot_radius', 'passed_QC', 'cluster_id', 'cluster_name', 'animal', 'ABA_id', 'ABA_name_level2', 'ABA_name_level1', 'ABA_color_level2', 'ABA_color_level1', 'in_tissue', 'array_row', 'array_col', 'sample'
    obsm: 'spatial', 'spatial_3d', 'spatial_ccf_2d', 'spatial_ccf_3d'
    layers: 'counts'
[4]:
# construct result path
result_path = construct_folder('10X_brian')

# List the section labels ordered by decreasing 'stereo_AP': take the distinct
# (stereo_AP, section_index) pairs, sort them, then keep only the labels.
section_positions = adata.obs[['stereo_AP', 'section_index']].drop_duplicates()
section_positions = section_positions.sort_values('stereo_AP', ascending=False)
keys_use = section_positions['section_index'].tolist()

Preprocess

[5]:
# Build the STAIR embedding-alignment object, using 'section_index' as the
# batch key and writing into result_path.
# NOTE(review): device is fixed to the first CUDA GPU -- adjust on other hardware.
emb_align = Emb_Align(
    adata,
    batch_key='section_index',
    hvg=3000,
    result_path=result_path,
    device='cuda:0',
)

# Run the three setup stages in order; raw counts are read from the 'counts' key.
emb_align.prepare(count_key='counts')
emb_align.preprocess()
emb_align.latent()
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [07:40<00:00,  4.61s/it]

Aligning SE

[6]:
# Prepare the HGAT input with the slices ordered as in keys_use, then train.
hgat_options = dict(slice_order=keys_use, n_neigh_hom=8, c_neigh_het=0.9)
emb_align.prepare_hgat(**hgat_options)
emb_align.data_hgat
emb_align.train_hgat(gamma=0.8)

# predict_hgat() returns the updated AnnData together with the attention
# table; the table is also written to disk as CSV.
adata, atte = emb_align.predict_hgat()
attention_csv = f'{result_path}/embedding/attention.csv'
atte.to_csv(attention_csv)
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [25:50<00:00, 10.34s/it]
[7]:
from matplotlib.pyplot import rc_context

# Neighbour graph and UMAP computed on the 'STAIR' representation.
sc.pp.neighbors(adata, use_rep='STAIR')
sc.tl.umap(adata, min_dist=0.2)

# Two UMAP panels side by side: one coloured by sample, one by ABA parent region.
umap_colors = ['sample', 'ABA_parent']
with rc_context({'figure.figsize': (4, 4)}):
    sc.pl.umap(adata, color=umap_colors, ncols=2, s=6, frameon=False, show=True)
_images/4._Align_new_slices_to_the_ABA_brain_atlas_12_0.png
[8]:
import seaborn as sns

# Choose the colour range from the entries that differ from 1, so that
# entries equal to 1 do not stretch the colour scale.
non_unit = atte[atte != 1]
vmin = non_unit.min().min()
vmax = non_unit.max().max()

# Draw the attention heatmap, save it as PDF, then display and release it.
plt.figure(figsize=(8, 6.6))
sns.heatmap(atte, vmin=vmin, vmax=vmax)
plt.savefig(f'{result_path}/embedding/attention.pdf', bbox_inches='tight')
plt.show()
plt.close()
_images/4._Align_new_slices_to_the_ABA_brain_atlas_13_0.png

CCF coordinate prediction

[9]:
from STAIR.loc_prediction import Loc_Pred

# Rebuild the 2D (ML, DV) and 3D (ML, DV, AP) coordinate arrays from the obs columns.
ccf_columns = ['stereo_ML', 'stereo_DV', 'stereo_AP']
adata.obsm['spatial_ccf_2d'] = adata.obs[ccf_columns[:2]].values
adata.obsm['spatial_ccf_3d'] = adata.obs[ccf_columns].values

# The 'Visium' section is the query; sections are told apart by 'section_index'.
loc_pred = Loc_Pred(
    adata,
    atte,
    batch_key='section_index',
    querys=['Visium'],
    result_path=result_path,
)

# predicting the AP coordinate and the nearest slice of new data
preds, nearest_slices = loc_pred.pred_z(loc_key='stereo_AP', num_mnn=20)
preds, nearest_slices
[9]:
([-1.8182837328385228], ['ST_20A'])
[10]:
# Predict the ML and DV coordinates of the new data.
xy_options = dict(
    spatial_key_query='spatial',
    spatial_key_ref='spatial_ccf_2d',
    spatial_key_3d='spatial_ccf_3d',
    emb_key='STAIR',
    alpha_query=2,
    alpha_ref=1,
)
adata_aligned_query = loc_pred.pred_xy(**xy_options)

# Copy the 3D coordinates stored in obsm back into the stereo_* obs columns.
ccf_3d = adata_aligned_query.obsm['spatial_ccf_3d']
adata_aligned_query.obs[['stereo_ML', 'stereo_DV', 'stereo_AP']] = ccf_3d

# Every observation that is not part of the 'Visium' section is the reference.
is_reference = adata.obs['section_index'] != 'Visium'
adata_ref = adata[is_reference]
Performing fine alignment of the 1 pair of data...

ABA region annotation

[11]:
from STAIR.ABA_annotation import ABA_anno, plot_spatial_ABA

# get ABA annotation in 2 levels
adata_aligned_query = ABA_anno(adata_aligned_query, 'stereo_ML', 'stereo_DV', 'stereo_AP', spatial_key='spatial')

# Plot the query next to the slice that pred_z() reported as nearest, instead
# of a hard-coded section name, so the cell stays correct for other inputs.
template_slice = nearest_slices[0]

# One figure per annotation level, saved as spatial1.png and spatial2.png.
for file_no, level in enumerate(['level1', 'level2'], start=1):
    plot_spatial_ABA([adata[adata.obs['section_index'] == template_slice], adata_aligned_query],
                     spatial_key='spatial_ccf_2d', level=level,
                     title_key='section_index', s=[26, 13], figsize=(5, 3),
                     save=f'{result_path}/spatial{file_no}.png')
_images/4._Align_new_slices_to_the_ABA_brain_atlas_18_0.png
_images/4._Align_new_slices_to_the_ABA_brain_atlas_18_1.png
[12]:
# ABA region was supported by marker genes

# NOTE(review): sc.pp.scale rewrites adata_aligned_query.X in place, and this
# object is later concatenated and written to disk -- confirm that storing the
# scaled values is intended.
sc.pp.scale(adata_aligned_query)

from matplotlib.pyplot import rc_context
with rc_context({'figure.figsize': (4,4.7)}):
    # show=False keeps the figure open so that savefig() below captures it;
    # with show=True the figure is already shown and gone by the time savefig()
    # runs, and an empty image is written.
    sc.pl.embedding(adata_aligned_query, basis='spatial_ccf_2d', color=['Cabp7', 'Gpr88', 'Rora',
                                                                        'Mbp', 'Gpx3', 'Hpca'],
                    frameon=False, show=False, ncols=6, s=70, vmin=-1, vmax=2.5)
    plt.savefig(f'{result_path}/embedding/gene.png', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()

_images/4._Align_new_slices_to_the_ABA_brain_atlas_19_0.png

3D Visualization

[13]:
# `os` is imported explicitly so this cell does not depend on the name leaking
# in through `from STAIR.utils import *`.
import os

import plotly.express as px
import matplotlib
import pandas as pd
import STAIR
package_path = os.path.dirname(STAIR.__file__)

# Combine the reference slices with the aligned query slice.
adata_all = sc.concat([adata_ref, adata_aligned_query])
df = adata_all.obs
df['ABA_name_level1'] = df['ABA_name_level1'].astype('category')

# Distinct region names per level; missing values are skipped because they
# have no entry in the colour table.
domains_level1 = df['ABA_name_level1'].dropna().drop_duplicates().values.tolist()
domains_level2 = df['ABA_name_level2'].dropna().drop_duplicates().values.tolist()

# Region name -> colour, read from the ontology table shipped inside the STAIR package.
ontology_csv = os.path.join(package_path, 'ABAanno', 'ontology.csv')
name_color = pd.read_csv(ontology_csv, index_col=0).set_index('name')['color_hex_triplet'].to_dict()

## level 1 regions
fig = px.scatter_3d(df,
                    x='stereo_AP',
                    y='stereo_ML',
                    z='stereo_DV',
                    color='ABA_name_level1',
                    opacity=1,
                    color_discrete_map={c:name_color[c] for c in domains_level1})
fig.update_traces(marker_size = 5)
fig.update_layout(
    height=1000,
    width=1000,
    scene=dict(
        aspectratio=dict(x=1.6, y=1, z=1.4)
        ),
    margin=dict(r=0, l=0, b=0, t=0))
fig.write_html(f'{result_path}/3d_align_level1.html', auto_open=False)
fig
[14]:
## level 2 regions
# Both branches of the former per-row conditional produced 0.1, so the column
# is a constant. It is not passed to the figure below; it is kept only so that
# df carries the same 'plot_size' column as before.
df['plot_size'] = 0.1

fig = px.scatter_3d(df,
                    x = 'stereo_AP',
                    y = 'stereo_ML',
                    z = 'stereo_DV',
                    color ='ABA_name_level2',
                    opacity = 1,
                    color_discrete_map={c:name_color[c] for c in domains_level2})
fig.update_traces(marker_size = 5)
fig.update_layout(
    height=1000,
    width=1000,
    scene=dict(
        aspectratio=dict(x=1.6, y=1, z=1.4)
        ),
    margin=dict(r=0, l=0, b=0, t=0))
fig.write_html(f'{result_path}/3d_align_level2.html', auto_open=False)
fig
[15]:
## slices visualization
slices = adata_all.obs[['stereo_AP', 'section_index']].drop_duplicates()['section_index'].tolist()

# All atlas slices share one tab20 colour and the 'Visium' slice gets another.
# plt.get_cmap() is used because plt.cm.get_cmap() no longer exists in recent
# Matplotlib releases.
tab20 = plt.get_cmap('tab20')
slices_to_color_map = {name: tab20(6) if name == 'Visium' else tab20(15) for name in slices}

# Larger markers for the 'Visium' spots; computed column-wise instead of
# looking rows up by integer position on a label-indexed Series.
df['plot_size'] = np.where(df['section_index'] == 'Visium', 0.5, 0.1)

fig = px.scatter_3d(df,
                    x='stereo_AP',
                    y='stereo_ML',
                    z='stereo_DV',
                    color='section_index',
                    size = 'plot_size',
                    opacity=0.3,
                    # Colours are bound to section names, so they do not depend
                    # on the order in which plotly encounters the sections.
                    color_discrete_map={name: matplotlib.colors.to_hex(rgba)
                                        for name, rgba in slices_to_color_map.items()})
fig.update_layout(
    height=1000,
    width=1000,
    scene=dict(
        aspectratio=dict(x=1.6, y=1, z=1.4)  # scene aspect ratio, x:y:z = 1.6:1:1.4
        ),
    margin=dict(r=0, l=0, b=0, t=0))
fig.write_html(f'{result_path}/3d_align_slices.html', auto_open=False)
fig

Save

[16]:
# Write the combined AnnData object into the result folder.
output_h5ad = f'{result_path}/adata.h5ad'
adata_all.write(output_h5ad)