{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Prototypical networks for few-shot cell type annotations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is a part of the ISMB tutorial [Meta-learning for bridging labeled and unlabeled data in biomedicine](http://snap.stanford.edu/metalearning-ismb/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we will use [Prototypical networks](https://arxiv.org/abs/1703.05175) to learn cell type annotations from only a few labeled examples. We will use a meta-learning approach to learn across different tissues and then generalize to a new, previously unseen tissue given only a few labels. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading Tabula Muris dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use the [Tabula Muris](https://www.nature.com/articles/s41586-020-2496-1) mouse cell atlas single-cell data. This dataset can be downloaded from [http://snap.stanford.edu/mars/data/tms-facs-mars.tar.gz](http://snap.stanford.edu/mars/data/tms-facs-mars.tar.gz). 
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import scanpy.api as sc\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# set the path to the data\n", "dataset_root = '../tabula-muris-senis-facs_mars.h5ad'" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
"def load_sc_data(root):\n",
"    \"\"\"Load the single cell data stored at `root` and preprocess it.\n",
"\n",
"    Steps, in order: keep cells whose `age` is '3m', drop cells with fewer than\n",
"    5000 counts or fewer than 500 genes, normalize every cell to 1e4 counts,\n",
"    keep the highly variable genes, apply log1p, scale (clipping at 10) and\n",
"    replace NaN entries with 0.\n",
"\n",
"    Args:\n",
"        root: path of the .h5ad file to read.\n",
"\n",
"    Returns:\n",
"        The preprocessed data object.\n",
"    \"\"\"\n",
"    # read from the path that was passed in, not from the global `dataset_root`\n",
"    sc_data = sc.read_h5ad(root)\n",
"    # reduce file size by taking only 3months old mouse to be able to locally run the example\n",
"    sc_data = sc_data[sc_data.obs['age'] == '3m']\n",
"    sc.pp.filter_cells(sc_data, min_counts=5000)\n",
"    sc.pp.filter_cells(sc_data, min_genes=500)\n",
"\n",
"    sc.pp.normalize_per_cell(sc_data, counts_per_cell_after=1e4)\n",
"    sc_data = sc.pp.filter_genes_dispersion(sc_data, subset=False, min_disp=.5, max_disp=None,\n",
"                                            min_mean=.0125, max_mean=10, n_bins=20, n_top_genes=None,\n",
"                                            log=True, copy=True)\n",
"    sc_data = sc_data[:, sc_data.var.highly_variable]\n",
"    sc.pp.log1p(sc_data)\n",
"    sc.pp.scale(sc_data, max_value=10, zero_center=True)\n",
"    # scaling can produce NaNs; zero them out\n",
"    sc_data.X[np.isnan(sc_data.X)] = 0\n",
"\n",
"    return sc_data" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Trying to set attribute `.obs` of view, making a copy.\n" ] } ], "source": [ "tm_data = load_sc_data(dataset_root)" ] }, { "cell_type": "code", "execution_count": 676, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sc.pp.pca(tm_data)\n", "sc.tl.tsne(tm_data)\n", "sc.pl.tsne(tm_data, color = 'tissue')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting model and training parameters" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from __future__ import print_function\n", "import numpy as np\n", "import torch\n", "import os\n", "import pandas as pd\n", "import warnings\n", "from tqdm import tqdm\n", "\n", "import torch.utils.data as data\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "from torch.nn.modules import Module\n", "\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [
"# Setting model parameters\n",
"# path to the data, relative to the notebook (no machine-specific absolute path)\n",
"dataset_root = '../tabula-muris-senis-facs_mars.h5ad'\n",
"epochs = 100\n",
"iterations = 100  # number of batches (episodes) drawn per epoch\n",
"learning_rate = 0.001\n",
"lr_scheduler_step = 20\n",
"lr_scheduler_gamma = 0.5\n",
"classes_per_it_tr = 5  # number of classes to sample at each iteration\n",
"classes_per_it_val = 5\n",
"num_support_tr = 1  # support set size during training\n",
"num_query_tr = 1  # query set size\n",
"num_support_val = 1  # support set size during evaluation\n",
"num_query_val = 1\n",
"manual_seed = 25" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "np.random.seed(manual_seed)\n", "torch.manual_seed(manual_seed)\n", "torch.cuda.manual_seed(manual_seed)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [
"# choosing cross-tissue split for training, validation and test set\n",
"data_split = {'train': ['Heart', 'Aorta', 'Kidney'],\n",
"              'val': ['Lung'],\n",
"              'test': ['Liver']}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data loaders " ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ 
"class CellDataset(data.Dataset):\n",
"    \"\"\"Examples and labels for single cell dataset.\n",
"\n",
"    Keeps the cells of `adata` whose tissue is listed in `data_split[mode]`, drops cell\n",
"    types with fewer than `min_samples` cells, and stores one tensor per remaining cell\n",
"    together with the integer code of its cell type.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, adata, mode, min_samples=20):\n",
"        super(CellDataset, self).__init__()\n",
"\n",
"        selected_tissues = set(data_split[mode])\n",
"        adata = adata[adata.obs['tissue'].isin(selected_tissues)]\n",
"\n",
"        # filter cell types with less than min_samples cells\n",
"        filtered_index = (adata.obs.groupby(['cell_ontology_class_reannotated'])\n",
"                          .filter(lambda group: len(group) >= min_samples)\n",
"                          .reset_index()['index'])\n",
"        adata = adata[filtered_index]\n",
"\n",
"        # encode the cell type of every cell as an integer label\n",
"        targets = adata.obs['cell_ontology_class_reannotated'].cat.codes\n",
"        adata.obs['label'] = targets\n",
"\n",
"        self.x = adata.X\n",
"        self.y = adata.obs['label']\n",
"\n",
"        print(\"*** Dataset: Found %d items \" % self.x.shape[0])\n",
"        print(\"*** Dataset: Found %d classes\" % len(np.unique(self.y)))\n",
"\n",
"        # one flat tensor per cell\n",
"        n_features = self.x.shape[1]\n",
"        self.x = [torch.from_numpy(inst).view(n_features) for inst in self.x]\n",
"        self.y = tuple(self.y.tolist())\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.x[idx], self.y[idx]\n",
"\n",
"    def __len__(self):\n",
"        # derived from the stored examples; the old `self.nitems` attribute was never set\n",
"        return len(self.x)\n",
"\n",
"    def get_dim(self):\n",
"        \"\"\"Return the length of the first stored example.\"\"\"\n",
"        return self.x[0].shape[0]" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [
"class PrototypicalBatchSampler(object):\n",
"    \"\"\"Yield a batch of indexes at each iteration.\n",
"\n",
"    At every iteration the batch indexes correspond to 'num_support' + 'num_query' samples\n",
"    for 'classes_per_it' random classes.\n",
"\n",
"    Raises:\n",
"        ValueError: if some class has fewer than `num_samples` examples, because such a\n",
"            class could never fill its part of a batch.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, labels, classes_per_it, num_samples, iterations):\n",
"        super(PrototypicalBatchSampler, self).__init__()\n",
"        self.labels = labels\n",
"        self.classes_per_it = classes_per_it\n",
"        self.sample_per_class = num_samples\n",
"        self.iterations = iterations\n",
"\n",
"        labels_arr = np.asarray(self.labels)\n",
"        self.classes, self.counts = np.unique(labels_arr, return_counts=True)\n",
"        self.classes = torch.LongTensor(self.classes)\n",
"\n",
"        self.idxs = range(len(self.labels))\n",
"        print('Number of examples per class: ')\n",
"        print(self.counts)\n",
"\n",
"        if len(self.counts) > 0 and self.counts.min() < self.sample_per_class:\n",
"            raise ValueError('Every class needs at least {} examples (num_support + num_query), '\n",
"                             'but the smallest class has only {}.'.format(self.sample_per_class,\n",
"                                                                          self.counts.min()))\n",
"\n",
"        # indexes has one row per class: row r starts with the dataset positions of the\n",
"        # samples of classes[r] (numel_per_class[r] of them) and is padded with NaN\n",
"        self.indexes = torch.full((len(self.classes), int(self.counts.max())), float('nan'))\n",
"        self.numel_per_class = torch.zeros_like(self.classes)\n",
"        for row, cls in enumerate(self.classes.tolist()):\n",
"            members = np.flatnonzero(labels_arr == cls)\n",
"            self.indexes[row, :len(members)] = torch.from_numpy(members).float()\n",
"            self.numel_per_class[row] = len(members)\n",
"\n",
"    def __iter__(self):\n",
"        spc = self.sample_per_class\n",
"        cpi = self.classes_per_it\n",
"\n",
"        for _ in range(self.iterations):\n",
"            # rows of `indexes` of the classes used in this episode\n",
"            class_rows = torch.randperm(len(self.classes))[:cpi]\n",
"\n",
"            chunks = []\n",
"            for row in class_rows.tolist():\n",
"                picked = torch.randperm(int(self.numel_per_class[row]))[:spc]\n",
"                chunks.append(self.indexes[row][picked].long())\n",
"\n",
"            # concatenating avoids the uninitialized entries a preallocated batch would keep\n",
"            batch = torch.cat(chunks) if chunks else torch.LongTensor()\n",
"            batch = batch[torch.randperm(len(batch))]\n",
"\n",
"            yield batch\n",
"\n",
"    def __len__(self):\n",
"        return self.iterations" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [
"def init_sampler(labels, mode):\n",
"    \"\"\"Build the episodic batch sampler; any mode other than 'train' uses the validation settings.\"\"\"\n",
"    if mode == 'train':\n",
"        classes_per_it = classes_per_it_tr\n",
"        num_samples = num_support_tr + num_query_tr\n",
"    else:\n",
"        classes_per_it = classes_per_it_val\n",
"        num_samples = num_support_val + num_query_val\n",
"\n",
"    return PrototypicalBatchSampler(labels, classes_per_it, num_samples, iterations)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [
"def init_dataloader(dataset, mode):\n",
"    \"\"\"Initializing data loaders for train, val or test mode.\n",
"\n",
"    Raises:\n",
"        Exception: if the selected split has fewer classes than the number of classes\n",
"            sampled per iteration.\n",
"    \"\"\"\n",
"    dataset = CellDataset(dataset, mode)\n",
"\n",
"    # validate before building the sampler, so a too-small split fails with a clear message\n",
"    n_classes = len(np.unique(dataset.y))\n",
"    if mode == 'train' and n_classes < classes_per_it_tr:\n",
"        raise Exception('There are not enough classes in the dataset to satisfy the chosen classes_per_it. Decrease the classes_per_it_tr option.')\n",
"    if mode in ('val', 'test') and n_classes < classes_per_it_val:\n",
"        raise Exception('There are not enough classes in the dataset to satisfy the chosen classes_per_it. Decrease the classes_per_it_val option.')\n",
"\n",
"    sampler = init_sampler(dataset.y, mode)\n",
"    dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)\n",
"\n",
"    return dataloader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model definition" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For model we will use simple 2-layer fully connected neural network." 
] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [
"def full_block(in_features, out_features, p_drop):\n",
"    \"\"\"Linear layer followed by ReLU and dropout with probability `p_drop`.\"\"\"\n",
"    return nn.Sequential(\n",
"        nn.Linear(in_features, out_features, bias=True),\n",
"        nn.ReLU(),\n",
"        nn.Dropout(p=p_drop),\n",
"    )" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [
"class ProtoNet(nn.Module):\n",
"    \"\"\"Encoder made of two `full_block`s: x_dim -> hid_dim -> z_dim.\"\"\"\n",
"\n",
"    def __init__(self, x_dim, hid_dim=64, z_dim=64, p_drop=0.2):\n",
"        super(ProtoNet, self).__init__()\n",
"        self.encoder = nn.Sequential(\n",
"            full_block(x_dim, hid_dim, p_drop),\n",
"            full_block(hid_dim, z_dim, p_drop),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        x = self.encoder(x)\n",
"        # flatten everything but the batch dimension\n",
"        return x.view(x.size(0), -1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loss definition" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [
"def euclidean_dist(x, y):\n",
"    \"\"\"Compute the squared euclidean distance between the rows of two tensors.\n",
"\n",
"    Args:\n",
"        x: tensor of size (n, d).\n",
"        y: tensor of size (m, d).\n",
"\n",
"    Returns:\n",
"        Tensor of size (n, m) whose entry (i, j) is the sum of squared differences\n",
"        between x[i] and y[j] (no square root is taken).\n",
"\n",
"    Raises:\n",
"        ValueError: if x and y have a different number of columns.\n",
"    \"\"\"\n",
"    n = x.size(0)\n",
"    m = y.size(0)\n",
"    d = x.size(1)\n",
"    if d != y.size(1):\n",
"        raise ValueError('Feature dimensions differ: {} vs {}'.format(d, y.size(1)))\n",
"\n",
"    x = x.unsqueeze(1).expand(n, m, d)\n",
"    y = y.unsqueeze(0).expand(n, m, d)\n",
"\n",
"    return torch.pow(x - y, 2).sum(2)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [
"def get_idxs(n_support, target):\n",
"    \"\"\"Get indexes of support and query sets.\n",
"\n",
"    For every class (in the sorted order of torch.unique) the first `n_support`\n",
"    positions of that class in `target` form its support set and the remaining\n",
"    positions are queries.\n",
"\n",
"    Returns:\n",
"        support_idxs: list with one 1-D index tensor per class.\n",
"        query_idxs: 1-D index tensor with the queries of all classes, class by class.\n",
"    \"\"\"\n",
"    classes = torch.unique(target)\n",
"    positions = [target.eq(c).nonzero().view(-1) for c in classes]\n",
"\n",
"    support_idxs = [pos[:n_support] for pos in positions]\n",
"    query_idxs = torch.cat([pos[n_support:] for pos in positions])\n",
"\n",
"    return support_idxs, query_idxs" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [
"def prototypical_loss(input1, target, n_support):\n",
"    \"\"\"Prototypical-network loss and accuracy for one episode.\n",
"\n",
"    Args:\n",
"        input1: embeddings of the batch, one row per example.\n",
"        target: 1-D tensor with the class label of every row of `input1`.\n",
"        n_support: number of examples per class used to build the prototype.\n",
"\n",
"    Returns:\n",
"        (loss_val, acc_val): negative log-likelihood and accuracy on the query examples.\n",
"    \"\"\"\n",
"    support_idxs, query_idxs = get_idxs(n_support, target)\n",
"    # prototype of a class = mean embedding of its support examples\n",
"    prototypes = torch.stack([input1[idx_class].mean(0) for idx_class in support_idxs])\n",
"\n",
"    query_samples = input1[query_idxs]\n",
"    dists = euclidean_dist(query_samples, prototypes)\n",
"    log_p_y = F.log_softmax(-dists, dim=1)\n",
"\n",
"    # Prototypes are ordered like torch.unique(target), so the rank of a label among the\n",
"    # sorted unique labels is the row of its prototype. This stays on the tensor's device\n",
"    # and remains correct even when a class has no query example.\n",
"    _, class_rank = torch.unique(target, return_inverse=True)\n",
"    target_inds = class_rank[query_idxs]\n",
"\n",
"    loss_val = F.nll_loss(log_p_y, target_inds)\n",
"\n",
"    _, y_hat = log_p_y.max(1)\n",
"    acc_val = y_hat.eq(target_inds).float().mean()\n",
"\n",
"    return loss_val, acc_val" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [
"class PrototypicalLoss(Module):\n",
"    \"\"\"Loss class: module wrapper around `prototypical_loss` with a fixed `n_support`.\"\"\"\n",
"\n",
"    def __init__(self, n_support):\n",
"        super(PrototypicalLoss, self).__init__()\n",
"        self.n_support = n_support\n",
"\n",
"    def forward(self, input, target):\n",
"        return prototypical_loss(input, target, self.n_support)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model training and testing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Defining functions to train and test the model." 
] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [
"def train(tr_dataloader, val_dataloader, model, optim, lr_scheduler):\n",
"    '''\n",
"    Train the model with the prototypical learning algorithm.\n",
"\n",
"    Runs `epochs` epochs; each epoch trains on every batch of `tr_dataloader`, steps the\n",
"    learning-rate scheduler once and then evaluates on `val_dataloader`.\n",
"\n",
"    Returns a copy of the model state dict taken at the epoch with the best average\n",
"    validation accuracy (the initial state if no epoch is run).\n",
"    '''\n",
"    import copy  # local import: only needed to snapshot the best weights\n",
"\n",
"    train_loss = []\n",
"    train_acc = []\n",
"    val_loss = []\n",
"    val_acc = []\n",
"    best_acc = 0\n",
"    # state_dict() returns references to the live parameters, so it must be copied;\n",
"    # otherwise the \"best\" state would silently follow later optimizer steps\n",
"    best_state = copy.deepcopy(model.state_dict())\n",
"\n",
"    for epoch in range(epochs):  # outer loop over epochs\n",
"        print('*** Epoch: {} ***'.format(epoch))\n",
"\n",
"        model.train()  # model training\n",
"        for x, y in tqdm(iter(tr_dataloader)):  # inner loop over episodes\n",
"            optim.zero_grad()\n",
"\n",
"            model_output = model(x)\n",
"            loss, acc = prototypical_loss(model_output, y, num_support_tr)\n",
"            loss.backward()\n",
"            optim.step()\n",
"\n",
"            train_loss.append(loss.item())\n",
"            train_acc.append(acc.item())\n",
"        # average over the last `iterations` recorded batches\n",
"        avg_loss = np.mean(train_loss[-iterations:])\n",
"        avg_acc = np.mean(train_acc[-iterations:])\n",
"        print('Avg train loss: {}, Avg train acc: {}'.format(avg_loss, avg_acc))\n",
"        lr_scheduler.step()\n",
"\n",
"        model.eval()  # model evaluation\n",
"        with torch.no_grad():  # no gradients are needed for validation\n",
"            for x, y in tqdm(iter(val_dataloader)):\n",
"                model_output = model(x)\n",
"                loss, acc = prototypical_loss(model_output, y, num_support_val)\n",
"                val_loss.append(loss.item())\n",
"                val_acc.append(acc.item())\n",
"        avg_loss = np.mean(val_loss[-iterations:])\n",
"        avg_acc = np.mean(val_acc[-iterations:])\n",
"        info = ' (Best)' if avg_acc >= best_acc else ' (Best: {})'.format(best_acc)\n",
"        print('Avg val loss: {}, Avg val acc: {}{}'.format(avg_loss, avg_acc, info))\n",
"        if avg_acc >= best_acc:\n",
"            best_acc = avg_acc\n",
"            best_state = copy.deepcopy(model.state_dict())\n",
"\n",
"    return best_state" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [
"def test(test_dataloader, model):\n",
"    \"\"\"Evaluate the model on every batch of `test_dataloader`.\n",
"\n",
"    Switches the model to evaluation mode (so dropout is disabled), prints the mean\n",
"    accuracy over the batches and returns it.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    batch_acc = []\n",
"    with torch.no_grad():  # no gradients are needed for testing\n",
"        for x, y in iter(test_dataloader):\n",
"            model_output = model(x)\n",
"            _, acc = prototypical_loss(model_output, y, num_support_val)\n",
"            batch_acc.append(acc.item())\n",
"    avg_acc = np.mean(batch_acc)\n",
"    print('Test accuracy: {}'.format(avg_acc))\n",
"\n",
"    return avg_acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train and test the model on the Tabula Muris data" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Trying to set attribute `.obs` of view, making a copy.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "*** Dataset: Found 5577 items \n", "*** Dataset: Found 21 classes\n", "Number of examples per class: \n", "[ 21 46 189 159 164 1179 46 113 2184 25 124 55 75 69\n", " 32 490 281 56 65 158 46]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Trying to set attribute `.obs` of view, making a copy.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "*** Dataset: Found 1675 items \n", "*** Dataset: Found 17 classes\n", "Number of examples per class: \n", "[ 55 36 27 140 420 21 68 51 129 52 362 29 56 39 55 96 39]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Trying to set attribute `.obs` of view, making a copy.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "*** Dataset: Found 702 items \n", "*** Dataset: Found 5 classes\n", "Number of examples per class: \n", "[ 34 47 189 401 31]\n" ] } ], "source": [ "# main\n", "tr_dataloader = init_dataloader(tm_data, 'train')\n", "val_dataloader = init_dataloader(tm_data, 'val')\n", "test_dataloader = init_dataloader(tm_data, 'test')" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "model = ProtoNet(tr_dataloader.dataset.get_dim())\n", "optim = torch.optim.Adam(params=model.parameters(), lr=learning_rate)\n", "lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim,\n", " 
gamma=lr_scheduler_gamma,\n", " step_size=lr_scheduler_step)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 21%|██ | 21/100 [00:00<00:00, 209.03it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "*** Epoch: 0 ***\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 100/100 [00:00<00:00, 238.17it/s]\n", "100%|██████████| 100/100 [00:00<00:00, 516.85it/s]\n", " 0%| | 0/100 [00:00