diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..ada5f27 Binary files /dev/null and b/.DS_Store differ diff --git a/FunctionProgramExamples/.DS_Store b/FunctionProgramExamples/.DS_Store new file mode 100644 index 0000000..65267a9 Binary files /dev/null and b/FunctionProgramExamples/.DS_Store differ diff --git a/Adam_prediction.py b/FunctionProgramExamples/Adam_prediction.py similarity index 100% rename from Adam_prediction.py rename to FunctionProgramExamples/Adam_prediction.py diff --git a/examples/MichiGAN_mean.py b/FunctionProgramExamples/examples/MichiGAN_mean.py similarity index 100% rename from examples/MichiGAN_mean.py rename to FunctionProgramExamples/examples/MichiGAN_mean.py diff --git a/examples/MichiGAN_sample.py b/FunctionProgramExamples/examples/MichiGAN_sample.py similarity index 100% rename from examples/MichiGAN_sample.py rename to FunctionProgramExamples/examples/MichiGAN_sample.py diff --git a/examples/beta_tcvae.py b/FunctionProgramExamples/examples/beta_tcvae.py similarity index 100% rename from examples/beta_tcvae.py rename to FunctionProgramExamples/examples/beta_tcvae.py diff --git a/examples/example_InfoWGAN_GP.ipynb b/FunctionProgramExamples/examples/example_InfoWGAN_GP.ipynb similarity index 100% rename from examples/example_InfoWGAN_GP.ipynb rename to FunctionProgramExamples/examples/example_InfoWGAN_GP.ipynb diff --git a/examples/example_MichiGAN_mean.ipynb b/FunctionProgramExamples/examples/example_MichiGAN_mean.ipynb similarity index 100% rename from examples/example_MichiGAN_mean.ipynb rename to FunctionProgramExamples/examples/example_MichiGAN_mean.ipynb diff --git a/examples/example_MichiGAN_sample.ipynb b/FunctionProgramExamples/examples/example_MichiGAN_sample.ipynb similarity index 100% rename from examples/example_MichiGAN_sample.ipynb rename to FunctionProgramExamples/examples/example_MichiGAN_sample.ipynb diff --git a/examples/example_VAE.ipynb b/FunctionProgramExamples/examples/example_VAE.ipynb similarity index 100% rename from examples/example_VAE.ipynb rename to FunctionProgramExamples/examples/example_VAE.ipynb diff --git a/examples/example_WGAN_GP.ipynb b/FunctionProgramExamples/examples/example_WGAN_GP.ipynb similarity index 100% rename from examples/example_WGAN_GP.ipynb rename to FunctionProgramExamples/examples/example_WGAN_GP.ipynb diff --git a/examples/example_beta_TCVAE.ipynb b/FunctionProgramExamples/examples/example_beta_TCVAE.ipynb similarity index 100% rename from examples/example_beta_TCVAE.ipynb rename to FunctionProgramExamples/examples/example_beta_TCVAE.ipynb diff --git a/examples/infowgangp.py b/FunctionProgramExamples/examples/infowgangp.py similarity index 100% rename from examples/infowgangp.py rename to FunctionProgramExamples/examples/infowgangp.py diff --git a/examples/vae.py b/FunctionProgramExamples/examples/vae.py similarity index 100% rename from examples/vae.py rename to FunctionProgramExamples/examples/vae.py diff --git a/examples/wgangp.py b/FunctionProgramExamples/examples/wgangp.py similarity index 100% rename from examples/wgangp.py rename to FunctionProgramExamples/examples/wgangp.py diff --git a/lib.py b/FunctionProgramExamples/lib.py similarity index 100% rename from lib.py rename to FunctionProgramExamples/lib.py diff --git a/nets.py b/FunctionProgramExamples/nets.py similarity index 100% rename from nets.py rename to FunctionProgramExamples/nets.py diff --git a/README.md b/README.md index 940160b..f669896 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ # MichiGAN: Learning 
disentangled representations of single-cell data for high-quality generation -## Predicting unobserved cell states from disentangled representations of single-cell data using generative adversarial networks +## Sampling from disentangled representations of single-cell data using generative adversarial networks -The current folder contains files for implementing **VAE/beta-TCVAE, WGAN-GP, InfoWGAN-GP, and MichiGAN** on single-cell RNA-seq data. See our preprint for details: -[Predicting unobserved cell states from disentangled representations of single-cell data using generative adversarial networks](https://www.biorxiv.org/content/10.1101/2021.01.15.426872v1) (Yu and Welch, 2020+). +The current folder contains files for implementing **PCA, GMM, VAE/beta-TCVAE, WGAN-GP, InfoWGAN-GP, ssInfoWGAN-GP, CWGAN-GP and MichiGAN** on single-cell RNA-seq data. See our preprint for details: +[Sampling from disentangled representations of single-cell data using generative adversarial networks](https://www.biorxiv.org/content/10.1101/2021.01.15.426872v1) (Yu and Welch, 2021+). We have a [presentation video](https://youtu.be/5tsccPMPzLQ) for [Learning Meaningful Representations of Life Workshop](https://www.lmrl.org/) at NeurIPS 2020, where we named our framework as `DRGAN` and changed the name to `MichiGAN` afterwards. ## List of Files: 1) `/data` is the folder containing the real scRNA-seq dataset of Tabula Muris heart data. Users can download the SCANPY-processed data on https://www.dropbox.com/sh/xseb0u6p01te3vr/AACuskVfswUFn5MroEFrqI-Xa?dl=0. -2) `/examples` is the folder for the experiments of\ +2) `/FunctionProgramExamples/examples` is the folder for the experiments of\ (1) `vae.py`: VAE; \ (2) `beta_tcvae.py`: beta-TCVAE;\ (3) `wgangp.py`: WGAN-GP;\ diff --git a/data/README.md b/data/README.md index 734a886..1bfa197 100644 --- a/data/README.md +++ b/data/README.md @@ -1,4 +1,4 @@ -The folder has the Tabula Muris heart data that are currently stored on dropbox and can be accessed via https://www.dropbox.com/sh/xseb0u6p01te3vr/AACuskVfswUFn5MroEFrqI-Xa?dl=0 +The folder has the Tabula Muris heart data that are currently stored on dropbox and can be accessed [here](https://www.dropbox.com/sh/xseb0u6p01te3vr/AACuskVfswUFn5MroEFrqI-Xa?dl=0). 
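+
+A minimal loading sketch (assuming the file has already been downloaded into this folder; the file name is taken from the list below):
+
+```python
+import numpy as np
+
+# processed expression matrix: 4221 cells (rows) x 4062 genes (columns)
+data = np.load("TabulaMurisHeart_Processed.npy")
+print(data.shape)
+```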
List of Files: 1) `TabulaMurisHeart_Processed.npy` is the SCANPY-processed Tabula Muris heart data (4221 cells/rows by 4062 genes/columns) diff --git a/metrics/DisentangleMetrics.py b/metrics/DisentangleMetrics.py new file mode 100644 index 0000000..4dea88d --- /dev/null +++ b/metrics/DisentangleMetrics.py @@ -0,0 +1,652 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import numpy as np +from random import sample + +import scipy +from scipy import stats +from sklearn.svm import SVC +from sklearn.neighbors import KNeighborsClassifier + +import tensorflow as tf +from tensorflow import distributions as ds + +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt + +class DisentangledRepresentation: + """ + Disentanglement performance class + """ + + def __init__(self): + super().__init__() + self.clf = SVC(kernel = 'linear') + + + def KS_test(self, z_sample_data, z_sample_data_2 = None): + """ + Kolmogorov-Smirnov test between two latent tensors on each dimension + """ + + p_value_list = [] + if z_sample_data_2 is None: + norm_data = np.random.normal(0, 1, size = z_sample_data.shape[0]) + for i in range(z_sample_data.shape[1]): + stat, p_value = stats.ks_2samp(z_sample_data[:, i], norm_data) + p_value_list.append(p_value) + else: + assert z_sample_data.shape[1] == z_sample_data_2.shape[1] + for i in range(z_sample_data.shape[1]): + stat, p_value = stats.ks_2samp(z_sample_data[:, i], z_sample_data_2[:, i]) + p_value_list.append(p_value) + + return p_value_list + + def CorrValue(self, z_data, GroundTruthVar, is_spearman = True): + """ + correlation between each latent dimension of z_data and GroundTruthVar + """ + + cor_value = np.zeros((z_data.shape[1])) + if is_spearman: + for t in range(z_data.shape[1]): + cor_value[t] = scipy.stats.spearmanr(z_data[:, t], GroundTruthVar, nan_policy = 'omit')[0] + else: + for t in range(z_data.shape[1]): + cor_value[t] = scipy.stats.pearsonr(z_data[:, t], GroundTruthVar, nan_policy = 'omit')[0] + return cor_value + + def CorrGap(self, cor_matrix): + """ + correlation gap based on a correlation matrix + """ + + cor_output = cor_matrix.copy() + cor_output = np.abs(cor_output) + + cor_output.sort(axis = 1) + metric = np.mean(cor_output[:, -1] - cor_output[:, -2]) + + return metric + + + def DataByCode(self, umap_data, z_data, path_figure_save): + """ + UMAP plots of data colored by each dimension of latent z_data + """ + + dict_use = {} + for h in range(z_data.shape[1]): + dict_use["Z" + str(h+1)] = h + 1 + + min_x, min_y = np.floor(umap_data['x-umap'].min()), np.floor(umap_data['y-umap'].min()) + max_x, max_y = np.ceil(umap_data['x-umap'].max()), np.ceil(umap_data['y-umap'].max()) + + newfig = plt.figure(figsize=[20,6]) + for m in range(len(dict_use)): + name_i = list(dict_use.keys())[m] + num_i = dict_use[name_i] + ax1 = newfig.add_subplot(2, 5, num_i) + cb1 = ax1.scatter(umap_data['x-umap'], umap_data['y-umap'], s= 1, c = z_data[:, m], cmap= "plasma") + ax1.tick_params(axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.tick_params(axis='y', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.get_yaxis().set_ticks([]) + 
ax1.set_ylim(min_y, max_y) + ax1.set_xlim(min_x, max_x) + ax1.set_title(name_i) + + newfig.savefig(path_figure_save, dpi=300, useDingbats = False) + + + def MutualInformation(self, data, metadata, dictmeta, list_GroundTruthVars, z_dim, + sess, MarginalEntropy, X_v, if_VarBalanced = True): + """ + mutual information for PCA. use self.mu and self.pca_posterior_mean(X) as X_v and data + """ + + cond_en = np.zeros((len(list_GroundTruthVars), z_dim)) + + if data.shape[0] <= 12000: + input_mig, meta_mig = data, metadata + else: + idx_mig = sample(range(data.shape[0]), 12000) + input_mig = data[idx_mig, :] + meta_mig = metadata.iloc[idx_mig, :] + + mar_en_v = sess.run(MarginalEntropy, {X_v: input_mig}) + + # with several ground truth variables + for k in range(len(list_GroundTruthVars)): + k_var = list_GroundTruthVars[k] + list_k = dictmeta[k_var] + + for k_use in list_k: + value_list_k = np.where(np.array(meta_mig[k_var]) == k_use)[0] + x_value_list_k = input_mig[value_list_k, :] + en_i = sess.run(MarginalEntropy, {X_v: x_value_list_k}) + cond_en[k, :] += en_i * float(value_list_k.shape[0])/float(meta_mig.shape[0]) + + if if_VarBalanced: + factor_entropies = np.log([ len(dictmeta[i]) for i in list_GroundTruthVars]) + else: + factor_entropies = np.array([scipy.stats.entropy(meta_mig[i].value_counts()) for i in list_GroundTruthVars]) + + MIGValue, NormMutualInfo = self.MIGMetric(mar_en_v, cond_en, factor_entropies) + + return MIGValue, NormMutualInfo + + def MutualInformationWithMissing(self, data, metadata, dictmeta, list_GroundTruthVars, z_dim, + sess, MarginalEntropy, X_v, training = None, if_VarBalanced = True): + """ + general mutual information for PCA or VAE on data with missing metadata + """ + cond_en = np.zeros((len(list_GroundTruthVars), z_dim)) + mar_en = np.zeros((len(list_GroundTruthVars), z_dim)) + + if data.shape[0] <= 12000: + input_mig, meta_mig = data, metadata + else: + idx_mig = sample(range(data.shape[0]), 12000) + input_mig = data[idx_mig, :] + meta_mig = metadata.iloc[idx_mig, :] + + for k in range(len(list_GroundTruthVars)): + k_var = list_GroundTruthVars[k] + input_obs = input_mig[-meta_mig[k_var].isna()] + if training is None: + mar_en[k, :] = sess.run(MarginalEntropy, {X_v: input_obs}) + else: + mar_en[k, :] = sess.run(MarginalEntropy, {X_v: input_obs, training: False}) + + # with several ground truth variables + for k in range(len(list_GroundTruthVars)): + k_var = list_GroundTruthVars[k] + list_k = dictmeta[k_var] + + for k_use in list_k: + value_list_k = np.where(np.array(meta_mig[k_var]) == k_use)[0] + x_value_list_k = input_mig[value_list_k, :] + + if training is None: + en_i = sess.run(MarginalEntropy, {X_v: x_value_list_k}) + else: + en_i = sess.run(MarginalEntropy, {X_v: x_value_list_k, training: False}) + + cond_en[k, :] += en_i /float(len(list_k)) + + if if_VarBalanced: + factor_entropies = np.log([ len(dictmeta[i]) for i in list_GroundTruthVars]) + else: + factor_entropies = np.array([scipy.stats.entropy(meta_mig[i].value_counts()) for i in list_GroundTruthVars]) + + MutualInfo = mar_en - cond_en + mi_normed = MutualInfo/factor_entropies[:, None] + mi_output = mi_normed.copy() + + mi_normed.sort(axis = 1) + metric = np.mean(mi_normed[:, -1] - mi_normed[:, -2]) + + return metric, mi_output + + def MutualInformationVAE(self, data, metadata, dictmeta, list_GroundTruthVars, z_dim, + sess, MarginalEntropy, X_v, training, if_VarBalanced = True): + """ + mutual information for VAEs + """ + + cond_en = np.zeros((len(list_GroundTruthVars), z_dim)) + + if data.shape[0] 
<= 12000: + input_mig, meta_mig = data, metadata + else: + idx_mig = sample(range(data.shape[0]), 12000) + input_mig = data[idx_mig, :] + meta_mig = metadata.iloc[idx_mig, :] + + mar_en_v = sess.run(MarginalEntropy, {X_v: input_mig, training: False}) + + # with several ground truth variables + for k in range(len(list_GroundTruthVars)): + k_var = list_GroundTruthVars[k] + list_k = dictmeta[k_var] + + for k_use in list_k: + value_list_k = np.where(np.array(meta_mig[k_var]) == k_use)[0] + x_value_list_k = input_mig[value_list_k, :] + en_i = sess.run(MarginalEntropy, {X_v: x_value_list_k, training: False}) + cond_en[k, :] += en_i * float(value_list_k.shape[0])/float(meta_mig.shape[0]) + + if if_VarBalanced: + factor_entropies = np.log([ len(dictmeta[i]) for i in list_GroundTruthVars]) + else: + factor_entropies = np.array([scipy.stats.entropy(meta_mig[i].value_counts()) for i in list_GroundTruthVars]) + + MIGValue, NormMutualInfo = self.MIGMetric(mar_en_v, cond_en, factor_entropies) + + return MIGValue, NormMutualInfo + + + def MIGMetric(self, marginal_entropies, con_entropies, factor_entropies): + """ + calculate mutual information gap (MIG) based on marginal and conditional entropies + """ + MutualInfo = marginal_entropies[None] - con_entropies + mi_normed = MutualInfo/factor_entropies[:, None] + mi_output = mi_normed.copy() + + mi_normed.sort(axis = 1) + metric = np.mean(mi_normed[:, -1] - mi_normed[:, -2]) + + return metric, mi_output + + def PlotBarCor(self, cor_matrix, path_figure_save = None): + """ + Plot correlation bar plots + """ + # Correlation GAP + rep_list = ["Z" + str(i + 1) for i in range(cor_matrix.shape[1])] + newfig = plt.figure(figsize=[8,6]) + ax1 = newfig.add_subplot(2, 2, 1) + cb1 = ax1.bar(rep_list, cor_matrix[0, :], color= "blue") + + ax1.tick_params( + axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.set_ylim(-1, 1) + ax1.set_title("Batch", fontsize = 18) + + ax1 = newfig.add_subplot(2,2, 2) + cb1 = ax1.bar(rep_list, cor_matrix[1, :], color= "blue") + + ax1.tick_params( + axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.set_ylim(-1, 1) + ax1.set_title("Path", fontsize = 18) + + ax1 = newfig.add_subplot(2,2, 3) + cb1 = ax1.bar(rep_list, cor_matrix[2, :], color= "blue") + + ax1.tick_params( + axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.set_ylim(-1, 1) + ax1.set_title("Step", fontsize = 18) + + ax1 = newfig.add_subplot(2,2, 4) + cb1 = ax1.bar(rep_list, cor_matrix[3, :], color= "blue") + + ax1.tick_params( + axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.set_ylim(-1, 1) + ax1.set_title("Library Size Quartile", fontsize = 18) + + newfig.text(0.5, 0.04, 'Representations', ha='center', fontsize = 18) + 
newfig.text(0.04, 0.5, 'Spearman Correlation', va='center', rotation='vertical', fontsize = 18) + + if path_figure_save is not None: + newfig.savefig(path_figure_save, dpi=300, useDingbats = False) + + def PlotBarMI(self, norm_mi, path_figure_save = None): + """ + Plot normalized mutual information bar plots + """ + # normalized mutual information + rep_list = ["Z" + str(i + 1) for i in range(norm_mi.shape[1])] + + newfig = plt.figure(figsize=[8,6]) + + ax1 = newfig.add_subplot(2, 2, 1) + cb1 = ax1.bar(rep_list, norm_mi[0, :], color= "blue") + + ax1.tick_params( + axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.set_ylim(0, 1) + ax1.set_title("Batch", fontsize = 18) + + ax1 = newfig.add_subplot(2,2, 2) + cb1 = ax1.bar(rep_list, norm_mi[1, :], color= "blue") + + ax1.tick_params( + axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.set_ylim(0, 1) + ax1.set_title("Path", fontsize = 18) + + ax1 = newfig.add_subplot(2,2, 3) + cb1 = ax1.bar(rep_list, norm_mi[2, :], color= "blue") + + ax1.tick_params( + axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.set_ylim(0, 1) + ax1.set_title("Step", fontsize = 18) + + ax1 = newfig.add_subplot(2,2, 4) + cb1 = ax1.bar(rep_list, norm_mi[3, :], color= "blue") + + ax1.tick_params( + axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + ax1.set_ylim(0, 1) + ax1.set_title("Library Size Quartile", fontsize = 18) + + + newfig.text(0.5, 0.04, 'Representations', ha='center', fontsize = 18) + newfig.text(0.04, 0.5, 'Normalized Mutual Information', va='center', rotation='vertical', fontsize = 18) + + if path_figure_save is not None: + newfig.savefig(path_figure_save, dpi=300, useDingbats = False) + + + def latent_space_entropies(self, tf_JointEntropy, tf_sess, tf_z_mu, tf_z_std, tf_z_sample, tf_training, + z_mu, z_std, z_sample): + """ + calculate latent space entropy based on latent values and given posterior distributions + """ + feed_dict = { + tf_z_mu: z_mu, tf_z_std: z_std, + tf_z_sample: z_sample, tf_training: False + } + + z_JointEntropy = tf_sess.run(tf_JointEntropy, feed_dict = feed_dict) + + return z_JointEntropy + + def FactorVAEMetric(self, input_data, data_meta, list_GroundTruthVars, dict_meta, latent_dim, input_latent, + K_samples = 10000, L_samples = 40): + """ + calculate the disentanglement metrics of beta-VAE and FactorVAE + """ + k_list = [] + + mean_list, std_list = None, None + + all_c = input_latent + # put generated c information into its standard deviation + all_c_std = all_c.std(axis = 0, keepdims = True) + + for s in range(K_samples): + + k = sample(range(len(list_GroundTruthVars)), 1) + + # FactorVAE metric + fixed_factor = list_GroundTruthVars[k[0]] + fixed_factor_value_list = dict_meta[fixed_factor] 
+            fixed_factor_value_kim = sample(fixed_factor_value_list, 1)[0]
+            # the data with this fixed value
+            list_fixed_value_kim = np.where(data_meta[fixed_factor] == fixed_factor_value_kim)[0].tolist()
+
+            # sample without replacement for FactorVAE metric
+            indexkim = np.random.choice(list_fixed_value_kim , L_samples, replace = False)
+            selectkimdata = input_data[indexkim, :]
+
+            ckim = input_latent[indexkim, :]
+            ckim_scale = ckim/all_c_std
+
+            diff_list = None
+
+            # beta-vae metric
+            for l in range(L_samples):
+
+                fixed_factor_value = sample(fixed_factor_value_list, 1)[0]
+                list_fixed_value = np.where(data_meta[fixed_factor] == fixed_factor_value)[0].tolist()
+
+                indexbeta = np.random.choice(list_fixed_value, 2, replace = False)
+
+                z1 = input_latent[[indexbeta[0]], :]
+                z2 = input_latent[[indexbeta[1]], :]
+
+                # for the categorical variable from onehot to categories
+                if diff_list is None:
+                    diff_list = np.abs(z1 - z2)
+                else:
+                    diff_list = np.append(diff_list, np.abs(z1 - z2), axis = 0)
+
+            mean_diff = np.mean(diff_list, axis = 0, keepdims = 1)
+            mean_diff_subdim = mean_diff[:, latent_dim]
+
+            std_diff = np.std(ckim_scale, axis = 0, keepdims = 1)
+            std_diff_subdim = std_diff[:, latent_dim]
+            std_diff_max = np.argmax(std_diff_subdim)
+
+            if mean_list is None:
+                mean_list = mean_diff_subdim
+            else:
+                mean_list = np.append(mean_list, mean_diff_subdim, axis = 0)
+
+            if std_list is None:
+                std_list = std_diff_max
+            else:
+                std_list = np.append(std_list, std_diff_max)
+
+            k_list.append(k[0])
+
+        # classifier (cross-validation)
+        train_index = range(int(K_samples * 0.8))
+        test_index = range(int(K_samples * 0.8), K_samples)
+
+        #k_list_n = [j for sub in k_list for j in sub] # if [[0], [1], ...] is used for k_list
+        X_train, X_test = mean_list[train_index, :], mean_list[test_index, :]
+        y_train, y_test = np.array(k_list)[train_index], np.array(k_list)[test_index]
+
+        # SVM classifier
+        self.clf.fit(X_train, y_train)
+        predictions_train = self.clf.predict(X_train)
+        predictions_test = self.clf.predict(X_test)
+
+        betaVAE_train = np.mean((predictions_train == y_train)*1)
+        betaVAE_test = np.mean((predictions_test == y_test)*1)
+
+        # FactorVAE
+        X_train, X_test = std_list[train_index], std_list[test_index]
+        y_train, y_test = np.array(k_list)[train_index], np.array(k_list)[test_index]
+
+        # majority vote classifier
+        X_list, y_list = np.unique(X_train), np.unique(y_train)
+        v = np.zeros((len(X_list), len(y_list)))
+        for ind in range(len(X_train)):
+            row, col = X_train[ind], y_train[ind]
+            row_in, col_in = list(np.where(row == X_list)[0])[0], list(np.where(col == y_list)[0])[0]
+            v[row_in, col_in] += 1
+        pre_model = np.argmax(v, axis = 1)
+
+        # majority vote predictions
+        predictions_beta_train = np.zeros(y_train.shape)
+        predictions_beta_test = np.zeros(y_test.shape)
+
+        for i in range(len(X_train)):
+            t = X_train[i]
+            if list(np.where(t == X_list)[0]) == []:
+                predictions_beta_train[i] = 1000 # an arbitrary large category
+            else:
+                predictions_beta_train[i] = pre_model[list(np.where(t == X_list)[0])[0]]
+
+        factorVAE_train = np.mean((predictions_beta_train == y_train)*1)
+
+
+        for i in range(len(X_test)):
+            t = X_test[i]
+            if list(np.where(t == X_list)[0]) == []:
+                predictions_beta_test[i] = 1000 # an arbitrary large category
+            else:
+                predictions_beta_test[i] = pre_model[list(np.where(t == X_list)[0])[0]]
+
+        factorVAE_test = np.mean((predictions_beta_test == y_test)*1)
+
+        return betaVAE_train, betaVAE_test, factorVAE_train, factorVAE_test
+
+    def KNNPredictVar(self, pca_real, pca_fake,
+                      GroundTruthVar_real):
+        """
+        predict the ground-truth variable values of fake data based on the k-nearest neighbor algorithm
+        trained on real PC values.
+        """
+
+        neigh = KNeighborsClassifier(n_neighbors = 3)
+        neigh.fit(pca_real, GroundTruthVar_real)
+
+        GroundTruthVar_fake = neigh.predict(pca_fake)
+
+        return GroundTruthVar_fake
+
+
+
+class LatentSpaceVectorArithmetic:
+    """
+    Latent space vector arithmetic algorithm
+    """
+    def __init__(self):
+        super().__init__()
+
+    def Union(self, lst1, lst2):
+        """
+        union of two lists
+        """
+        final_list = list(set(lst1) | set(lst2))
+        return final_list
+
+    def balancer(self, data, meta_data, consider_trt):
+        """
+        balance data based on treatment
+        """
+
+        list_trt = list(meta_data['treatment'])
+        class_pop = {}
+        for cls in consider_trt:
+            class_pop[cls] = meta_data.copy()[meta_data['treatment'] == cls].shape[0]
+
+        max_number = np.max(list(class_pop.values()))
+        all_data = None
+        meta_all = None
+
+        for cls in consider_trt:
+            idx = [i for i in range(len(list_trt)) if list_trt[i] == cls]
+
+            temp = data.copy()[idx, :]
+            temp_meta = meta_data.copy().iloc[idx, :]
+
+            index = np.random.choice(range(temp.shape[0]), max_number)
+            temp_x = temp[index, :]
+            meta_x = temp_meta.iloc[index, :]
+            if all_data is None:
+                all_data = temp_x
+                meta_all = meta_x
+            else:
+                all_data = np.concatenate((all_data, temp_x), axis = 0)
+                meta_all = pd.concat([meta_all, meta_x])
+
+        return all_data, meta_all
+
+    def sample_for_effect(self, data_1, data_2, meta_1, meta_2, consider_trt):
+        """
+        given two datasets with different cell types,
+        balance each dataset by treatment type and
+        balance the two datasets in size
+        """
+        data_1_b, meta_1_b = self.balancer(data_1, meta_1, consider_trt)
+        data_2_b, meta_2_b = self.balancer(data_2, meta_2, consider_trt)
+
+        balance_size = min(data_1_b.shape[0], data_2_b.shape[0])
+        idx_1 = np.random.choice(range(data_1_b.shape[0]), size = balance_size, replace = False)
+        idx_2 = np.random.choice(range(data_2_b.shape[0]), size = balance_size, replace = False)
+
+        data_use1, data_use2 = data_1_b[idx_1, :], data_2_b[idx_2, :]
+        return data_use1, data_use2
+
+
+    def latent_vector_arithmetic(self, tf_sess, tf_z_latent, tf_x_data, tf_training, x_data1, x_data2):
+        """
+        calculate the averaged latent difference between two samples
+        """
+
+        feed_dict1 = {tf_x_data: x_data1, tf_training: False}
+        z_latent1 = tf_sess.run(tf_z_latent, feed_dict = feed_dict1)
+
+        feed_dict2 = {tf_x_data: x_data2, tf_training: False}
+        z_latent2 = tf_sess.run(tf_z_latent, feed_dict = feed_dict2)
+
+        z_diff = z_latent1.mean(0) - z_latent2.mean(0)
+
+        return z_diff
+
+    def generate_from_latent(self, tf_sess, tf_z_latent, tf_x_rec_data, tf_training, z_data):
+        """
+        generate data from latent samples
+        """
+
+        feed_dict = {tf_z_latent: z_data, tf_training: False}
+
+        return tf_sess.run(tf_x_rec_data, feed_dict = feed_dict)
+
+    def estimate(self, x_data1, x_data2, meta_1, meta_2, consider_trt, data_true, data_control,
+                 tf_sess, tf_z_latent, tf_x_rec_data, tf_x_data, tf_training):
+        """
+        latent space vector arithmetic
+        1 is for control,
+        2 is for target
+        """
+
+        x_data1, x_data2 = self.sample_for_effect(x_data1, x_data2, meta_1, meta_2, consider_trt)
+        z_diff = self.latent_vector_arithmetic(tf_sess, tf_z_latent, tf_x_data, tf_training, x_data1, x_data2)
+
+        n_true, n_control = data_true.shape[0], data_control.shape[0]
+
+        # make sure the predicted data, true data and control data all have the same size
+        if n_true < n_control:
+
+            sample_indices = sample(range(n_control), n_true)
+            input_control = data_control[sample_indices]
+
+        elif n_true > n_control:
+
+            sample_indices = np.random.choice(n_control, n_true, replace = True)
+            input_control = data_control[sample_indices]
+
+        else:
+
+            input_control = data_control
+
+        feed_dict = {tf_x_data: input_control, tf_training: False}
+        z_control = tf_sess.run(tf_z_latent, feed_dict = feed_dict)
+
+        z_pred = z_control - z_diff
+        x_pred_data = self.generate_from_latent(tf_sess, tf_z_latent, tf_x_rec_data,
+                                                tf_training, z_pred)
+
+        return x_pred_data, input_control
\ No newline at end of file
diff --git a/metrics/GenerationMetrics.py b/metrics/GenerationMetrics.py
new file mode 100644
index 0000000..31dc248
--- /dev/null
+++ b/metrics/GenerationMetrics.py
@@ -0,0 +1,399 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+import os
+import numpy as np
+import matplotlib
+matplotlib.use('agg')
+import matplotlib.pyplot as plt
+import pandas as pd
+
+import umap
+from sklearn.decomposition import PCA
+from scipy.spatial.distance import pdist
+from plotnine import *
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import KFold
+from sklearn.metrics import roc_curve, auc
+from sklearn.preprocessing import label_binarize
+
+from .util import *
+
+
+class MetricVisualize:
+    """
+    Generation measure class
+    """
+    def __init__(self):
+        super().__init__()
+        self.pca_50 = PCA(n_components=50, random_state = 42)
+        self.rf = RandomForestClassifier(n_estimators = 1000, random_state=42)
+
+    def CorrelationDistance(self, data):
+        """
+        calculate correlation distance within a dataset
+        """
+        return round(np.median(pdist(data, metric='correlation')), 3)
+
+
+    def FIDScore(self, real_data, fake_data, pca_data_fit = None, if_dataPC = False):
+        """
+        calculate Frechet inception distance between real and fake data on the PC space
+        """
+
+        all_data = np.concatenate([fake_data, real_data], axis = 0)
+
+        if if_dataPC:
+            pca_all = pca_data_fit.transform(all_data)
+        else:
+            pca_all = self.pca_50.fit(real_data).transform(all_data)
+
+        pca_real, pca_fake = pca_all[fake_data.shape[0]:], pca_all[:fake_data.shape[0]]
+
+        FIDval = calculate_fid_score(pca_fake, pca_real)
+
+        return FIDval
+
+    def InceptionScore(self, real_data, real_cell_type, target_data):
+        """
+        calculate inception score of target data based on the cell type random forest classifier
+        on the real data
+        """
+
+        rf_fit = self.rf.fit(real_data, real_cell_type)
+        data_score = rf_fit.predict_proba(target_data)
+
+        meanScore, stdScore = preds2score(data_score, data_score.mean(axis = 0), splits = 3)
+
+        return meanScore, stdScore
+
+
+
+    def umapPlot(self, real_data, fake_data, path_file_save = None):
+        """
+        UMAP plot of real and fake data
+        """
+        all_data = np.concatenate([fake_data, real_data], axis = 0)
+        pca_all = self.pca_50.fit(real_data).transform(all_data)
+        pca_result_real = pca_all[fake_data.shape[0]:]
+
+        cat_t = ["1-Real"] * real_data.shape[0]
+        cat_g = ["2-Fake"] * fake_data.shape[0]
+        cat_rf_gt = np.append(cat_g, cat_t)
+
+        trans = umap.UMAP(random_state=42, min_dist = 0.5, n_neighbors=30).fit(pca_result_real)
+
+        X_embedded_pr = trans.transform(pca_all)
+        df_tsne_pr = X_embedded_pr.copy()
+        df_tsne_pr = pd.DataFrame(df_tsne_pr)
+        df_tsne_pr['x-umap'] = X_embedded_pr[:,0]
+        df_tsne_pr['y-umap'] = X_embedded_pr[:,1]
+        df_tsne_pr['category'] = cat_rf_gt
+
+        chart_pr = ggplot(df_tsne_pr, aes(x= 'x-umap', y= 'y-umap', colour = 'category') ) \
+            + geom_point(size=0.5, alpha = 0.5) \
+            + ggtitle("UMAP dimensions")
+
+        if path_file_save is not None:
+            chart_pr.save(path_file_save, width=12, height=8, dpi=144)
+
+        return chart_pr
+
+    def umapPlotByCat(self, pca_data_fit, data, data_category, path_file_save = None):
+        """
+        UMAP plot of data colored by categories. It involves a PCA procedure
+        """
+        pca_data = pca_data_fit.transform(data)
+
+        trans = umap.UMAP(random_state=42, min_dist = 0.5, n_neighbors=30).fit(pca_data)
+
+        X_embedded_pr = trans.transform(pca_data)
+        df_tsne_pr = X_embedded_pr.copy()
+        df_tsne_pr = pd.DataFrame(df_tsne_pr)
+        df_tsne_pr['x-umap'] = X_embedded_pr[:,0]
+        df_tsne_pr['y-umap'] = X_embedded_pr[:,1]
+        df_tsne_pr['category'] = data_category
+
+        chart_pr = ggplot(df_tsne_pr, aes(x= 'x-umap', y= 'y-umap', colour = 'category') ) \
+            + geom_point(size=0.5, alpha = 0.5) \
+            + ggtitle("UMAP dimensions")
+
+        if path_file_save is not None:
+            chart_pr.save(path_file_save, width=12, height=8, dpi=144)
+
+        return chart_pr
+
+
+    def umapPlotPurelyByCat(self, umap_data, data_category, path_file_save = None):
+        """
+        UMAP plot of data colored by categories. It directly takes the UMAP coordinates as input.
+        """
+        df_tsne_pr = umap_data.copy()
+        df_tsne_pr = pd.DataFrame(df_tsne_pr)
+        df_tsne_pr['x-umap'] = umap_data[:,0]
+        df_tsne_pr['y-umap'] = umap_data[:,1]
+        df_tsne_pr['category'] = data_category
+
+        chart_pr = ggplot(df_tsne_pr, aes(x= 'x-umap', y= 'y-umap', colour = 'category') ) \
+            + geom_point(size=0.5, alpha = 0.5) \
+            + ggtitle("UMAP dimensions")
+
+        if path_file_save is not None:
+            chart_pr.save(path_file_save, width=12, height=8, dpi=144)
+        return chart_pr
+
+
+    def umapPlotPurelyByCatHighQuality(self, umap_data, xlab_showname, ylab_showname, data_category,
+                                       nrowlegend = 7, size = 5, alpha = 1, legend_title = 'UMAP Plot',
+                                       path_file_save = None):
+        """
+        high-quality UMAP plot of UMAP coordinates colored by categories.
+ """ + df_tsne_pr = umap_data.copy() + df_tsne_pr = pd.DataFrame(df_tsne_pr) + df_tsne_pr['x-umap'] = umap_data[:,0] + df_tsne_pr['y-umap'] = umap_data[:,1] + df_tsne_pr['category'] = data_category + + chart_pr = ggplot(df_tsne_pr, aes(x= 'x-umap', y= 'y-umap', colour = 'category') ) \ + + geom_point(size = size, alpha = alpha) + labs(x = xlab_showname, y = ylab_showname) \ + + geom_abline(intercept = 0 , slope = 1, size=1, linetype="dashed", color="black") \ + + xlim(0, 1) + ylim(0, 1) + theme_bw() \ + + theme(panel_background = element_rect(fill='white'), + title = element_text(size = 25), + axis_title_x = element_text(size = 25), + axis_title_y = element_text(size = 25), + axis_text_x = element_text(size = 15), + axis_text_y = element_text(size = 15), + legend_title = element_text(size = 20), + legend_text = element_text(size = 20), + axis_ticks_major_y = element_blank(), + axis_ticks_major_x = element_blank(), + panel_grid = element_blank()) \ + + ggtitle(legend_title) \ + + guides(colour = guide_legend(nrow=nrowlegend, override_aes={"size": 10})) + + if path_file_save is not None: + chart_pr.save(path_file_save, width=12, height=8, dpi=144) + return chart_pr + + def latentHistPlot(self, z_data, path_file_save = None): + """ + Plot of histograms + """ + + dict_use = {} + + for h in range(z_data.shape[1]): + + dict_use["Var " + str(h+1)] = h + 1 + + newfig = plt.figure(figsize=[20,16]) + for m in range(z_data.shape[1]): + name_i = list(dict_use.keys())[m] + num_i = dict_use[name_i] + ax1 = newfig.add_subplot(4, 3, t + 1) + weights = np.ones_like(z_data[:,m])/float(len(z_data[:,m])) + ax1.hist(z_data[:,m], bins = 100, weights = weights, alpha = 0.5) + ax1.set_title(name_i) + + if path_file_save is not None: + newfig.savefig(path_file_save) + + def latentColorPlot(self, z_data, umapData, path_file_save = None): + """ + UMAP plots by latent values + """ + + dict_use = {} + for h in range(z_data.shape[1]): + dict_use["Var " + str(h+1)] = h + 1 + + # mapped + newfig = plt.figure(figsize=[20,16]) + for m in range(len(dict_use)): + name_i = list(dict_use.keys())[m] + num_i = dict_use[name_i] + ax1 = newfig.add_subplot(4, 3,num_i) + cb1 = ax1.scatter(umapData['x-umap'], umapData['y-umap'], s= 1, c = z_data[:, m], cmap= "plasma") + ax1.set_title(name_i) + + if path_file_save is not None: + newfig.savefig(path_file_save) + + + + +class RandomForestError: + """ + Random forest class + """ + + def __init__(self, n_folds = 5): + super().__init__() + self.rf = RandomForestClassifier(n_estimators = 1000, random_state=42) + self.pca_50 = PCA(n_components=50, random_state = 42) + self.n_folds = n_folds + + def PrepareIndexes(self, pca_real, pca_fake): + """ + Indices to use for random forest classifier + """ + assert pca_real.shape[0] == pca_fake.shape[0] + self.num_realize_gen = pca_real.shape[0] + self.cat_t = ["1-training"] * self.num_realize_gen + self.cat_g = ["2-generated"] * self.num_realize_gen + self.cat_rf_gt = np.append(self.cat_g, self.cat_t) + + self.index_shuffle_mo = list(range(self.num_realize_gen + self.num_realize_gen)) + np.random.shuffle(self.index_shuffle_mo) + + self.cat_rf_gt_s = self.cat_rf_gt[self.index_shuffle_mo] + + + kf = KFold(n_splits = self.n_folds, random_state = 42) + + kf_cat_gt = kf.split(self.cat_rf_gt_s) + self.train_in = np.array([]) + self.test_in = np.array([]) + self.train_cluster_in = np.array([]) + self.test_cluster_in = np.array([]) + + j = 0 + for train_index, test_index in kf_cat_gt: + self.train_in = np.append([self.train_in], [train_index]) + 
self.test_in = np.append([self.test_in], [test_index]) + self.train_cluster_in = np.append(self.train_cluster_in, np.repeat(j, len(train_index))) + self.test_cluster_in = np.append(self.test_cluster_in, np.repeat(j, len(test_index)) ) + j+=1 + + + def fit(self, real_data, fake_data, pca_data_fit = None, if_dataPC = False, output_AUC = True, path_save = "."): + """ + fit a 5-fold random forest classifier on real and fake data within the PC space + """ + all_data = np.concatenate([fake_data, real_data], axis = 0) + if if_dataPC: + pca_all = pca_data_fit.transform(all_data) + else: + pca_all = self.pca_50.fit(real_data).transform(all_data) + pca_real, pca_fake = pca_all[fake_data.shape[0]:], pca_all[:fake_data.shape[0]] + self.PrepareIndexes(pca_real, pca_fake) + + pca_gen_s = pca_all[self.index_shuffle_mo] + + vari = pca_gen_s # generated + outc = self.cat_rf_gt_s + # Binarize the output + outc_1 = label_binarize(outc, classes=['', '1-training', '2-generated']) + outc_1 = outc_1[:, 1:] + n_classes = outc_1.shape[1] + outc = np.array(outc) + errors = np.array([]) + for j in range(self.n_folds): + train_index = [int(self.train_in[self.train_cluster_in == j][k]) for k in range(self.train_in[self.train_cluster_in == j].shape[0])] + test_index = [int(self.test_in[self.test_cluster_in == j][k]) for k in range(self.test_in[self.test_cluster_in == j].shape[0])] + X_train, X_test = vari[train_index], vari[test_index] + y_train, y_test = outc[train_index], outc[test_index] + y_test_1 = outc_1[test_index] + self.rf.fit(X_train, y_train) + predictions = self.rf.predict(X_test) + errors = np.append(errors, np.mean((predictions != y_test)*1)) + + if output_AUC: + # AUC plots + y_score_tr = self.rf.fit(X_train, y_train).predict_proba(X_test) + + fpr = dict() + tpr = dict() + roc_auc = dict() + for k in range(n_classes): + fpr[k], tpr[k], _ = roc_curve(y_test_1[:, k], y_score_tr[:, k]) + roc_auc[k] = auc(fpr[k], tpr[k]) + + # Compute micro-average ROC curve and ROC area + fpr["micro"], tpr["micro"], _ = roc_curve(y_test_1.ravel(), y_score_tr.ravel()) + roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) + newfig = plt.figure() + lw = 2 + plt.plot(fpr[1], tpr[1], color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[1]) + plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--') + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('Receiver operating characteristic') + plt.legend(loc="lower right") + plt.savefig(os.path.join(path_save, "gen_" + str(j) + "_fold_result.png")) + plt.close(newfig) + + errors = np.append(errors, np.mean(errors)) + errors_pd = pd.DataFrame([errors], columns = ['1st', '2nd', '3rd' , '4th', '5th', 'avg']) + + + return errors_pd + + def fit_once(self, real_data, fake_data, pca_data_fit = None, if_dataPC = False, output_AUC = True, path_save = "."): + """ + fit a cross-validated random forest classifier on real and fake data within the PC space + """ + all_data = np.concatenate([fake_data, real_data], axis = 0) + if if_dataPC: + pca_all = pca_data_fit.transform(all_data) + else: + pca_all = self.pca_50.fit(real_data).transform(all_data) + + pca_real, pca_fake = pca_all[fake_data.shape[0]:], pca_all[:fake_data.shape[0]] + self.PrepareIndexes(pca_real, pca_fake) + + pca_gen_s = pca_all[self.index_shuffle_mo] + + vari = pca_gen_s # generated + outc = self.cat_rf_gt_s + # Binarize the output + outc_1 = label_binarize(outc, classes=['', '1-training', '2-generated']) + outc_1 = outc_1[:, 1:] + 
n_classes = outc_1.shape[1] + outc = np.array(outc) + + + j = 0 + + train_index = [int(self.train_in[self.train_cluster_in == j][k]) for k in range(self.train_in[self.train_cluster_in == j].shape[0])] + test_index = [int(self.test_in[self.test_cluster_in == j][k]) for k in range(self.test_in[self.test_cluster_in == j].shape[0])] + X_train, X_test = vari[train_index], vari[test_index] + y_train, y_test = outc[train_index], outc[test_index] + y_test_1 = outc_1[test_index] + self.rf.fit(X_train, y_train) + predictions = self.rf.predict(X_test) + errors = np.mean((predictions != y_test)*1) + + if output_AUC: + # AUC plots + y_score_tr = self.rf.fit(X_train, y_train).predict_proba(X_test) + + fpr = dict() + tpr = dict() + roc_auc = dict() + for k in range(n_classes): + fpr[k], tpr[k], _ = roc_curve(y_test_1[:, k], y_score_tr[:, k]) + roc_auc[k] = auc(fpr[k], tpr[k]) + + # Compute micro-average ROC curve and ROC area + fpr["micro"], tpr["micro"], _ = roc_curve(y_test_1.ravel(), y_score_tr.ravel()) + roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) + newfig = plt.figure() + lw = 2 + plt.plot(fpr[1], tpr[1], color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[1]) + plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--') + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('Receiver operating characteristic') + plt.legend(loc="lower right") + plt.savefig(os.path.join(path_save, "gen_" + str(j) + "_fold_result.png")) + plt.close(newfig) + + + return errors + + diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/metrics/util.py b/metrics/util.py new file mode 100644 index 0000000..175ebe5 --- /dev/null +++ b/metrics/util.py @@ -0,0 +1,174 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import numpy as np +from random import sample +import tensorflow as tf +from scipy import linalg +from tensorflow import distributions as ds + + +def permute_dims(z, opt): + permuted_rows = [] + zuse = tf.identity(z, name="input") + cat_part = zuse[:, :opt.non_noise_cat] + cat_part_sf = tf.random_shuffle(cat_part) + for i in range(zuse.get_shape()[1]): + if i >= opt.non_noise_cat: + permuted_rows.append(tf.random_shuffle(zuse[:, i])) + permuted_samples = tf.stack(permuted_rows, axis=1) + permuted_samples = tf.concat([cat_part_sf, permuted_samples], axis = 1) + permuted_output = tf.identity(permuted_samples, name = "output") + return permuted_output + +def con_noise_prior(con_tensor, batch_size, dim): + rez = np.zeros([batch_size, dim]) + for t in range(con_tensor.shape[1]): + sam_index = sample(range(con_tensor.shape[0]), batch_size) + input_realize = con_tensor[sam_index, t] + rez[:, t] = input_realize + if dim > con_tensor.shape[1]: + rez[:, con_tensor.shape[1]:] = np.random.normal(0.0, scale = 1.0, size=(batch_size, dim - (con_tensor.shape[1]))) + return rez + + +def random_fix_prior(con_tensor, con_place, batch_size, dim, opt): + rez = np.zeros([batch_size, dim]) + for t in range(con_tensor.shape[1]): + sam_index = sample(range(con_tensor.shape[0]), batch_size) + input_realize = con_tensor[sam_index, t] + rez[:, t] = input_realize + if dim > con_tensor.shape[1]: + rez[:, con_tensor.shape[1]:] = np.random.normal(0.0, scale = 1.0, size=(batch_size, dim - (con_tensor.shape[1]))) + z_use = rez.copy() + h_data = z_use[:, opt.non_noise_cat + con_place] + + rez_2 = np.zeros([batch_size, dim]) + for t in range(con_tensor.shape[1]): + sam_index = 
sample(range(con_tensor.shape[0]), batch_size) + input_realize = con_tensor[sam_index, t] + rez_2[:, t] = input_realize + if dim > con_tensor.shape[1]: + rez_2[:, con_tensor.shape[1]:] = np.random.normal(0.0, scale = 1.0, size=(batch_size, dim - (con_tensor.shape[1]))) + rez_2[:, opt.non_noise_cat + con_place] = h_data + return rez, rez_2 + + +def fix_noise_prior(con_tensor, batch_size, dim): + input_realize = con_tensor + rez = np.zeros([batch_size, dim]) + if dim == (input_realize.shape[1]): + rez = input_realize + else: + rez[:, :(input_realize.shape[1])] = input_realize + rez[:, (input_realize.shape[1]):] = np.random.normal(0.0, scale = 1.0, size=(batch_size, dim - (input_realize.shape[1]))) + return rez + +def noise_prior(batch_size, dim): + temp_norm = np.random.normal(0.0, scale = 1.0, size=(batch_size, dim)) + return temp_norm + +def prior(batch_size, dim): + shp = [batch_size, dim] + loc = tf.zeros(shp) + scale = tf.ones(shp) + return ds.Normal(loc, scale) + +def random_uc(insize, opt): + idxs = np.random.randint(opt.non_noise_cat, size = insize) + onehot = np.zeros((insize, opt.non_noise_cat)) + onehot[np.arange(insize), idxs] = 1 + return onehot, idxs + +def random_z(size, opt): + rez = np.zeros([size, opt.noise_input_size]) + rez[:, :opt.non_noise_cat], idxs = random_uc(size, opt) + rez[:, opt.non_noise_cat:] = noise_prior(size, opt.noise_input_size - opt.non_noise_cat) + return rez, idxs + +def random_fix_z(y_data, opt): + rez = np.zeros([y_data.shape[0], opt.noise_input_size]) + rez[:, :opt.non_noise_cat] = y_data + rez[:, opt.non_noise_cat:] = noise_prior(y_data.shape[0], opt.noise_input_size - opt.non_noise_cat) + return rez + +# con_place equal to 0 or 1 for 2 continuous variables of ground truth +def random_fix_z_con(size, con_place, opt): + z_data, _ = random_z(size, opt) + z_use = z_data.copy() + h_data = z_use[:, opt.non_noise_cat + con_place] + z_data_2, _ = random_z(size, opt) + z_data_2[:, opt.non_noise_cat + con_place] = h_data + return z_data, z_data_2 + +def random_fix_noise_prior(size, con_place, opt): + z_data = noise_prior(size, opt.noise_input_size_2) + z_use = z_data.copy() + h_data = z_use[:, opt.non_noise_cat + con_place] + z_data_2 = noise_prior(size, opt.noise_input_size_2) + z_data_2[:, opt.non_noise_cat + con_place] = h_data + return z_data, z_data_2 + +def log(x, opt): + return tf.log(x + opt.epsilon_use) + +def sample_X(X, size): + start_idx = np.random.randint(0, X.shape[0] - size) + return X[start_idx:start_idx + size, :] + +def sample_XY(X, Y, size): + start_idx = np.random.randint(0, X.shape[0] - size) + return X[start_idx:start_idx + size, :], Y[start_idx:start_idx + size, :] + + +def preds2score(PYX, PY, eps = 1e-6, splits=3): + scores = [] + for i in range(splits): + part = PYX[(i * PYX.shape[0] // splits):((i + 1) * PYX.shape[0] // splits), :] + part = part + eps + kl = part * (np.log(part) - np.log(np.expand_dims(PY, 0))) + kl = np.mean(np.sum(kl, 1)) + scores.append(np.exp(kl)) + return np.mean(scores), np.std(scores) + + +def generateTheta(L, ndim): + # This function generates L random samples from the unit `ndim'-u + theta=[w/np.sqrt((w**2).sum()) for w in np.random.normal(size=(L, ndim))] + return np.asarray(theta) + + + +def calculate_statistics(numpy_data): + mu = np.mean(numpy_data, axis = 0) + sigma = np.cov(numpy_data, rowvar = False) + return mu, sigma + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps = 1e-6): + diff = mu1 - mu2 + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp = False) + if not 
np.isfinite(covmean).all(): + msg = ( + 'fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates' % eps + ) + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol = 1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Cell component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + + return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean ) + + +def calculate_fid_score(data1, data2): + m1, s1 = calculate_statistics(data1) + m2, s2 = calculate_statistics(data2) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + return fid_value diff --git a/models/Adam_prediction.py b/models/Adam_prediction.py new file mode 100755 index 0000000..a20cf3a --- /dev/null +++ b/models/Adam_prediction.py @@ -0,0 +1,225 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Adam for TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops + + +class Adam_Prediction_Optimizer(optimizer.Optimizer): + """Optimizer that implements the Adam algorithm. + See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). + """ + + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, prediction=False, + use_locking=False, name="Adam"): + """Construct a new Adam optimizer. + Initialization: + ``` + m_0 <- 0 (Initialize initial 1st moment vector) + v_0 <- 0 (Initialize initial 2nd moment vector) + t <- 0 (Initialize timestep) + ``` + The update rule for `variable` with gradient `g` uses an optimization + described at the end of section2 of the paper: + ``` + t <- t + 1 + lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) + m_t <- beta1 * m_{t-1} + (1 - beta1) * g + v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g + variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon) + ``` + The default value of 1e-8 for epsilon might not be a good default in + general. For example, when training an Inception network on ImageNet a + current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + formulation just before Section 2.1 of the Kingma and Ba paper rather than + the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon + hat" in the paper. 
+ The sparse implementation of this algorithm (used when the gradient is an + IndexedSlices object, typically because of `tf.gather` or an embedding + lookup in the forward pass) does apply momentum to variable slices even if + they were not used in the forward pass (meaning they have a gradient equal + to zero). Momentum decay (beta1) is also applied to the entire momentum + accumulator. This means that the sparse behavior is equivalent to the dense + behavior (in contrast to some momentum implementations which ignore momentum + unless a variable slice was actually used). + Args: + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + """ + super(Adam_Prediction_Optimizer, self).__init__(use_locking, name) + self._lr = learning_rate + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + self._prediction = prediction + + # Tensor versions of the constructor arguments, created in _prepare(). + self._lr_t = None + self._beta1_t = None + self._beta2_t = None + self._epsilon_t = None + + # Variables to accumulate the powers of the beta parameters. + # Created in _create_slots when we know the variables to optimize. + self._beta1_power = None + self._beta2_power = None + + # Created in SparseApply if needed. + self._updated_lr = None + + def _get_beta_accumulators(self): + return self._beta1_power, self._beta2_power + + def _non_slot_variables(self): + return self._get_beta_accumulators() + + def _create_slots(self, var_list): + # Create the beta1 and beta2 accumulators on the same device as the first + # variable. Sort the var_list to make sure this device is consistent across + # workers (these need to go on the same PS, otherwise some updates are + # silently ignored). + first_var = min(var_list, key=lambda x: x.name) + + create_new = self._beta1_power is None + if not create_new and context.in_graph_mode(): + create_new = (self._beta1_power.graph is not first_var.graph) + + if create_new: + with ops.colocate_with(first_var): + self._beta1_power = variable_scope.variable(self._beta1, + name="beta1_power", + trainable=False) + self._beta2_power = variable_scope.variable(self._beta2, + name="beta2_power", + trainable=False) + # Create slots for the first and second moments. 
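+    # ("m" accumulates the running first moment (mean of gradients) and "v" the
+    # running second moment (mean of squared gradients); one zero-initialized
+    # slot of each is kept per trainable variable.)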
+ for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + def _prepare(self): + self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate") + self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1") + self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2") + self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon") + + def _apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + return training_ops.apply_adam( + var, m, v, + math_ops.cast(self._beta1_power, var.dtype.base_dtype), + math_ops.cast(self._beta2_power, var.dtype.base_dtype), + math_ops.cast(self._lr_t, var.dtype.base_dtype), + math_ops.cast(self._beta1_t, var.dtype.base_dtype), + math_ops.cast(self._beta2_t, var.dtype.base_dtype), + math_ops.cast(self._epsilon_t, var.dtype.base_dtype), + grad, use_locking=self._use_locking).op + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + return training_ops.resource_apply_adam( + var.handle, m.handle, v.handle, + math_ops.cast(self._beta1_power, grad.dtype.base_dtype), + math_ops.cast(self._beta2_power, grad.dtype.base_dtype), + math_ops.cast(self._lr_t, grad.dtype.base_dtype), + math_ops.cast(self._beta1_t, grad.dtype.base_dtype), + math_ops.cast(self._beta2_t, grad.dtype.base_dtype), + math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), + grad, use_locking=self._use_locking) + + def _apply_sparse_shared(self, grad, var, indices, scatter_add): + beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, "m") + m_scaled_g_values = grad * (1 - beta1_t) + m_t = state_ops.assign(m, m * beta1_t, + use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, "v") + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) + v_sqrt = math_ops.sqrt(v_t) + + # Prediction + if self._prediction : # for G + var2 = 2 * (var - lr * m_t / (v_sqrt + epsilon_t)) + var_update = state_ops.assign_sub(var2, var, use_locking=self._use_locking) + else : + var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) + + return control_flow_ops.group(*[var_update, m_t, v_t]) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, var, grad.indices, + lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda + x, i, v, use_locking=self._use_locking)) + + def _resource_scatter_add(self, x, i, v): + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add( + x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared( + grad, var, indices, self._resource_scatter_add) + + def _finish(self, update_ops, name_scope): + # Update 
the power accumulators. + with ops.control_dependencies(update_ops): + with ops.colocate_with(self._beta1_power): + update_beta1 = self._beta1_power.assign( + self._beta1_power * self._beta1_t, + use_locking=self._use_locking) + update_beta2 = self._beta2_power.assign( + self._beta2_power * self._beta2_t, + use_locking=self._use_locking) + return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], + name=name_scope) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..836e3e8 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- diff --git a/models/cgan.py b/models/cgan.py new file mode 100644 index 0000000..e4bcaa8 --- /dev/null +++ b/models/cgan.py @@ -0,0 +1,581 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import os +import time +import logging + +import numpy as np +from scipy import sparse +import pandas as pd + +import tensorflow as tf +from tensorflow import distributions as ds + +from .util import * +from .Adam_prediction import Adam_Prediction_Optimizer +from .gan import BaseGAN + +from metrics.GenerationMetrics import * +log = logging.getLogger(__file__) + + +class CWGAN_GP(BaseGAN): + + """ + Conditional Wasserstein GAN with gradient penalty (CWGAN-GP or PCWGAN-GP) + """ + def __init__(self, x_dimension, z_dimension = 10, noise_dimension = 118, **kwargs): + super().__init__(x_dimension, z_dimension, **kwargs) + + self.n_dim = noise_dimension + self.if_PCGAN = kwargs.get("if_PCGAN", True) + + with tf.device(self.device): + + self.x = tf.placeholder(tf.float32, shape = [None, self.x_dim], name = "data") + self.z = tf.placeholder(tf.float32, shape = [None, self.z_dim], name = "latent") + self.noise = tf.placeholder(tf.float32, shape = [None, self.n_dim], name = "latent_noise") + + self.create_network() + self.loss_function() + + config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True) + config.gpu_options.per_process_gpu_memory_fraction = 0.6 + self.sess = tf.Session(config = config) + self.saver = tf.train.Saver(max_to_keep = 1) + self.init = tf.global_variables_initializer().run(session = self.sess) + + def generatorDropOut(self): + """ + generator with dropout layers + """ + with tf.variable_scope('generatorDropOut', reuse = tf.AUTO_REUSE): + + znoise = tf.concat([self.z, self.noise], axis = 1) + + ge_dense1 = tf.layers.dense(inputs = znoise, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + ge_dense1 = tf.layers.batch_normalization(ge_dense1, training = self.is_training) + ge_dense1 = tf.nn.leaky_relu(ge_dense1) + ge_dense1 = tf.layers.dropout(ge_dense1, self.dropout_rate, training=self.is_training) + + ge_dense2 = tf.layers.dense(inputs = ge_dense1, units = self.inflate_to_size2, activation=None, + kernel_initializer = self.init_w) + ge_dense2 = tf.layers.batch_normalization(ge_dense2, training = self.is_training) + ge_dense2 = tf.nn.leaky_relu(ge_dense2) + ge_dense2 = tf.layers.dropout(ge_dense2, self.dropout_rate, training=self.is_training) + + ge_dense3 = tf.layers.dense(inputs = ge_dense2, units = self.inflate_to_size3, activation=None, + kernel_initializer = self.init_w) + ge_dense3 = tf.layers.batch_normalization(ge_dense3, training = self.is_training) + ge_dense3 = tf.nn.relu(ge_dense3) + ge_dense3 = tf.layers.dropout(ge_dense3, self.dropout_rate, training=self.is_training) + + ge_output = tf.layers.dense(inputs = ge_dense3, units= self.x_dim, activation=None) + + return ge_output + + 
def generator(self): + """ + generator without dropout layers + """ + with tf.variable_scope('generator', reuse = tf.AUTO_REUSE): + + znoise = tf.concat([self.z, self.noise], axis = 1) + + ge_dense1 = tf.layers.dense(inputs = znoise, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + ge_dense1 = tf.layers.batch_normalization(ge_dense1, training = self.is_training) + ge_dense1 = tf.nn.leaky_relu(ge_dense1) + + ge_dense2 = tf.layers.dense(inputs = ge_dense1, units = self.inflate_to_size2, activation=None, + kernel_initializer = self.init_w) + ge_dense2 = tf.layers.batch_normalization(ge_dense2, training = self.is_training) + ge_dense2 = tf.nn.leaky_relu(ge_dense2) + + ge_dense3 = tf.layers.dense(inputs = ge_dense2, units = self.inflate_to_size3, activation=None, + kernel_initializer = self.init_w) + ge_dense3 = tf.layers.batch_normalization(ge_dense3, training = self.is_training) + ge_dense3 = tf.nn.relu(ge_dense3) + + ge_output = tf.layers.dense(inputs = ge_dense3, units= self.x_dim, activation=None) + + return ge_output + + def discriminatorPCGAN(self, x_input, z_input): + """ + discriminator of PCWGAN-GP + """ + with tf.variable_scope('discriminatorPCGAN', reuse = tf.AUTO_REUSE): + + disc_dense1 = tf.layers.dense(inputs= x_input, units= self.disc_internal_size1, activation = None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense1 = tf.layers.batch_normalization(disc_dense1, training = self.is_training) + disc_dense1 = tf.nn.leaky_relu(disc_dense1) + + disc_dense2 = tf.layers.dense(inputs = disc_dense1, units= self.disc_internal_size2, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense2 = tf.layers.batch_normalization(disc_dense2, training = self.is_training) + disc_dense2 = tf.nn.leaky_relu(disc_dense2) + + disc_dense3 = tf.layers.dense(inputs=disc_dense2, units= self.disc_internal_size3, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = tf.contrib.layers.xavier_initializer()) + disc_dense3 = tf.layers.batch_normalization(disc_dense3, training = self.is_training) + disc_dense3 = tf.nn.relu(disc_dense3) + + disc_dense4 = tf.layers.dense(inputs=disc_dense3, units=1,activation=None) + + disc_output = disc_dense4 + tf.reduce_sum(z_input * disc_dense3, axis = 1, keepdims = True) + + return disc_output, disc_dense3 + + def discriminatorCGAN(self, x_input, z_input): + """ + discriminator of CWGAN-GP + """ + with tf.variable_scope('discriminatorCGAN', reuse = tf.AUTO_REUSE): + + xz_input = tf.concat([x_input, z_input], axis = 1) + + disc_dense1 = tf.layers.dense(inputs= xz_input, units= self.disc_internal_size1, activation = None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense1 = tf.layers.batch_normalization(disc_dense1, training = self.is_training) + disc_dense1 = tf.nn.leaky_relu(disc_dense1) + + disc_dense2 = tf.layers.dense(inputs = disc_dense1, units= self.disc_internal_size2, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense2 = tf.layers.batch_normalization(disc_dense2, training = self.is_training) + disc_dense2 = tf.nn.leaky_relu(disc_dense2) + + disc_dense3 = tf.layers.dense(inputs=disc_dense2, units= self.disc_internal_size3, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense3 = tf.layers.batch_normalization(disc_dense3, training = self.is_training) + disc_dense3 = tf.nn.relu(disc_dense3) + + disc_output = 
tf.layers.dense(inputs=disc_dense3, units=1,activation=None) + + return disc_output, disc_dense3 + + + def create_network(self): + """ + construct the networks + """ + + if self.if_dropout: + self.x_gen_data = self.generatorDropOut() + else: + self.x_gen_data = self.generator() + + if self.if_PCGAN: + self.Dx_real, self.Dx_real_hidden = self.discriminatorPCGAN(self.x, self.z) + self.Dx_fake, self.Dx_fake_hidden = self.discriminatorPCGAN(self.x_gen_data, self.z) + else: + self.Dx_real, self.Dx_real_hidden = self.discriminatorCGAN(self.x, self.z) + self.Dx_fake, self.Dx_fake_hidden = self.discriminatorCGAN(self.x_gen_data, self.z) + + + def compute_gp(self, x, x_gen_data, z, discriminator): + """ + gradient penalty of discriminator + """ + epsilon_x = tf.random_uniform([], 0.0, 1.0) + x_hat = x * epsilon_x + (1 - epsilon_x) * x_gen_data + d_hat, _ = discriminator(x_hat, z) + + gradients = tf.gradients(d_hat, x_hat)[0] + slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) + + gradient_penalty = tf.reduce_mean((slopes - 1.0) ** 2) + + return gradient_penalty + + def loss_function(self): + """ + loss function + """ + + D_raw_loss = tf.reduce_mean(self.Dx_real) - tf.reduce_mean(self.Dx_fake) + self.G_loss = tf.reduce_mean(self.Dx_fake) + + if self.if_PCGAN: + self.gradient_penalty = self.compute_gp(self.x, self.x_gen_data, self.z, self.discriminatorPCGAN) + else: + self.gradient_penalty = self.compute_gp(self.x, self.x_gen_data, self.z, self.discriminatorCGAN) + + self.D_loss = D_raw_loss + self.lamb_gp * self.gradient_penalty + + tf_vars_all = tf.trainable_variables() + if self.if_PCGAN: + dvars = [var for var in tf_vars_all if var.name.startswith("discriminatorPCGAN")] + else: + dvars = [var for var in tf_vars_all if var.name.startswith("discriminatorCGAN")] + + if self.if_dropout: + gvars = [var for var in tf_vars_all if var.name.startswith("generatorDropOut")] + else: + gvars = [var for var in tf_vars_all if var.name.startswith("generator")] + + self.parameter_count = tf.reduce_sum( + [tf.reduce_prod(tf.shape(v)) for v in dvars + gvars]) + + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + self.g_solver = Adam_Prediction_Optimizer(learning_rate = self.learning_rate, + beta1=0.9, beta2=0.999, prediction=True).minimize(self.G_loss, var_list = gvars) + self.d_solver = Adam_Prediction_Optimizer(learning_rate = self.learning_rate, + beta1=0.9, beta2=0.999, prediction=False).minimize(self.D_loss, var_list = dvars) + + + @property + def model_parameter(self): + """ + report the number of training parameters + """ + self.total_param = self.sess.run(self.parameter_count) + return "There are {} parameters in conditional WGAN-GP.".format(self.total_param) + + def generate_cells(self, z_data): + """ + generate data from latent samples + """ + noise_data = self.sample_z(len(z_data), self.n_dim) + gen_data = self.sess.run(self.x_gen_data, feed_dict = {self.z: z_data, self.noise: noise_data, self.is_training: False}) + return gen_data + + + def restore_model(self, model_path): + """ + restore model from model_path + """ + self.saver.restore(self.sess, model_path) + + + def save_model(self, model_save_path, epoch): + """ + save the trained model to the model_save_path + """ + os.makedirs(model_save_path, exist_ok = True) + model_save_name = os.path.join(model_save_path, "model") + save_path = self.saver.save(self.sess, model_save_name, global_step = epoch) + + np.save(os.path.join(model_save_path, "training_time.npy"), self.training_time) + 
np.save(os.path.join(model_save_path, "train_loss_D.npy"), self.train_loss_D) + np.save(os.path.join(model_save_path, "train_loss_G.npy"), self.train_loss_G) + np.save(os.path.join(model_save_path, "valid_loss_D.npy"), self.valid_loss_D) + np.save(os.path.join(model_save_path, "valid_loss_G.npy"), self.valid_loss_G) + + def train(self, train_data, train_cond, use_validation = False, valid_data = None, valid_cond = None, use_test_during_train = False, test_data = None, + test_cond = None, test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, verbose = False): + """ + train conditional WGAN-GP with train_data (AnnData), train_cond (numpy) and optional valid_data (numpy array) for n_epochs. + """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception('valid_data is None but use_validation is True.') + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + if sparse.issparse(train_data.X): + pca_data_fit = pca_data_50.fit(train_data.X.A) + else: + pca_data_fit = pca_data_50.fit(train_data.X) + + train_data_copy = train_data.copy() + for epoch in range(1, n_epochs + 1): + begin = time.time() + + if shuffle: + train_data, train_cond = shuffle_adata_cond(train_data, train_cond) + + train_loss_D, train_loss_G = 0.0, 0.0 + valid_loss_D, valid_loss_G = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + + # D step + for _ in range(self.Diters): + x_mb, z_mb = self.sample_data_cond(train_data, train_cond, batch_size) + n_mb = self.sample_z(batch_size, self.n_dim) + self.sess.run(self.d_solver, feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, + self.is_training: self.if_BNTrainingMode}) + + # G step + x_mb, z_mb = self.sample_data_cond(train_data, train_cond, batch_size) + n_mb = self.sample_z(batch_size, self.n_dim) + _, current_loss_D, current_loss_G = self.sess.run([self.g_solver, self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, self.is_training: self.if_BNTrainingMode}) + + train_loss_D += (current_loss_D * batch_size) + train_loss_G += (current_loss_G * batch_size) + + train_loss_D /= n_train + train_loss_G /= n_train + + if use_validation: + for _ in range(1, n_valid // batch_size + 1): + x_mb, z_mb = self.sample_data_cond(valid_data, valid_cond, batch_size) + n_mb = self.sample_z(batch_size, self.n_dim) + + current_loss_valid_D, current_loss_valid_G = self.sess.run([self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, self.is_training: False}) + + valid_loss_D += current_loss_valid_D + valid_loss_G += current_loss_valid_G + + valid_loss_D /= n_valid + valid_loss_G /= n_valid + + self.train_loss_D.append(train_loss_D) + self.train_loss_G.append(train_loss_G) + self.valid_loss_D.append(valid_loss_D) + self.valid_loss_G.append(valid_loss_G) + self.training_time += (time.time() - begin) + + # testing for generation metrics + if (epoch - 1) % 
test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + sampled_indices = sample(range(n_train), test_size) + + if sparse.issparse(train_data_copy.X): + test_data = train_data_copy[sampled_indices, :].X.A + else: + test_data = train_data_copy[sampled_indices, :].X + + test_cond = train_cond[sampled_indices, :] + + gen_data = self.generate_cells(test_cond) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + gen_data = self.generate_cells(test_cond) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + if save: + genmetrics_pd.to_csv(os.path.join(output_save_path, "GenerationMetrics.csv")) + + if reset_test_data: + test_data = None + test_cond = None + + if verbose: + print(f"Epoch {epoch}: D Train Loss: {train_loss_D} G Train Loss: {train_loss_G} D Valid Loss: {valid_loss_D} G Valid Loss: {valid_loss_G}") + + # early stopping + if use_validation and epoch > 1: + if abs(self.valid_loss_D[epoch - 2] - self.valid_loss_D[epoch - 1]) > min_delta or abs(self.valid_loss_G[epoch - 2] - self.valid_loss_G[epoch - 1]) > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training finished.") + + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + + def train_np(self, train_data, train_cond, use_validation = False, valid_data = None, valid_cond = None, use_test_during_train = False, test_data = None, + test_cond = None, test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, verbose = False): + """ + train conditional WGAN-GP with train_data (numpy), train_cond (numpy) and optional valid_data (numpy array) for n_epochs. 
+ """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception('valid_data is None but use_validation is True.') + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + pca_data_fit = pca_data_50.fit(train_data) + + if shuffle: + index_shuffle = list(range(n_train)) + + for epoch in range(1, n_epochs + 1): + begin = time.time() + + if shuffle: + np.random.shuffle(index_shuffle) + train_data = train_data[index_shuffle] + train_cond = train_cond[index_shuffle] + + if inception_score_data is not None: + inception_score_data = inception_score_data[index_shuffle] + + train_loss_D, train_loss_G = 0.0, 0.0 + valid_loss_D, valid_loss_G = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + + # D step + for _ in range(self.Diters): + x_mb, z_mb = self.sample_data_cond_np(train_data, train_cond, batch_size) + n_mb = self.sample_z(batch_size, self.n_dim) + self.sess.run(self.d_solver, feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, + self.is_training: self.if_BNTrainingMode}) + + # G step + x_mb, z_mb = self.sample_data_cond_np(train_data, train_cond, batch_size) + n_mb = self.sample_z(batch_size, self.n_dim) + _, current_loss_D, current_loss_G = self.sess.run([self.g_solver, self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, self.is_training: self.if_BNTrainingMode}) + + train_loss_D += (current_loss_D * batch_size) + train_loss_G += (current_loss_G * batch_size) + + train_loss_D /= n_train + train_loss_G /= n_train + + if use_validation: + for _ in range(1, n_valid // batch_size + 1): + x_mb, z_mb = self.sample_data_cond_np(valid_data, valid_cond, batch_size) + n_mb = self.sample_z(batch_size, self.n_dim) + + current_loss_valid_D, current_loss_valid_G = self.sess.run([self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, self.is_training: False}) + + valid_loss_D += current_loss_valid_D + valid_loss_G += current_loss_valid_G + + valid_loss_D /= n_valid + valid_loss_G /= n_valid + + self.train_loss_D.append(train_loss_D) + self.train_loss_G.append(train_loss_G) + self.valid_loss_D.append(valid_loss_D) + self.valid_loss_G.append(valid_loss_G) + self.training_time += (time.time() - begin) + + # testing for generation metrics + if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + sampled_indices = sample(range(n_train), test_size) + + test_data = train_data[sampled_indices, :] + test_cond = train_cond[sampled_indices, :] + + gen_data = self.generate_cells(test_cond) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + gen_data = self.generate_cells(test_cond) + + if 
inception_score_data is not None: + inception_score_subdata = inception_score_data + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + if save: + genmetrics_pd.to_csv(os.path.join(output_save_path, "GenerationMetrics.csv")) + + if reset_test_data: + test_data = None + test_cond = None + + if verbose: + print(f"Epoch {epoch}: D Train Loss: {train_loss_D} G Train Loss: {train_loss_G} D Valid Loss: {valid_loss_D} G Valid Loss: {valid_loss_G}") + + # early stopping + if use_validation and epoch > 1: + if abs(self.valid_loss_D[epoch - 2] - self.valid_loss_D[epoch - 1]) > min_delta or abs(self.valid_loss_G[epoch - 2] - self.valid_loss_G[epoch - 1]) > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. 
Training finished.") + + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) \ No newline at end of file diff --git a/models/gan.py b/models/gan.py new file mode 100644 index 0000000..78d83ca --- /dev/null +++ b/models/gan.py @@ -0,0 +1,657 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import os +import time +import logging + +import numpy as np +from scipy import sparse +import pandas as pd + +import tensorflow as tf +from tensorflow import distributions as ds + +from .util import * +from .Adam_prediction import Adam_Prediction_Optimizer +from metrics.GenerationMetrics import * +log = logging.getLogger(__file__) + + +class BaseGAN: + """ + Basic GAN class + """ + def __init__(self, x_dimension, z_dimension = 10, **kwargs): + tf.compat.v1.reset_default_graph() + self.x_dim = x_dimension + self.z_dim = z_dimension + self.learning_rate = kwargs.get("learning_rate", 5e-4) + self.dropout_rate = kwargs.get("dropout_rate", 0.2) + self.fix_std = kwargs.get("fix_std", True) + self.lamb_gp = kwargs.get("lambda_gp", 10.0) + self.alpha = kwargs.get("alpha", 1.0) + self.sample_c = kwargs.get('sample_c', 1) + self.Diters = kwargs.get('Diters', 5) + self.inflate_to_size1 = kwargs.get("inflate_size_1", 256) + self.inflate_to_size2 = kwargs.get("inflate_size_2", 512) + self.inflate_to_size3 = kwargs.get("inflate_size_3", 1024) + self.disc_internal_size1 = kwargs.get("disc_size_1", 1024) + self.disc_internal_size2 = kwargs.get("disc_size_2", 512) + self.disc_internal_size3 = kwargs.get("disc_size_3", 10) + self.if_dropout = kwargs.get("if_dropout", True) + self.if_BNTrainingMode = kwargs.get("BNTrainingMode", True) + self.is_training = tf.placeholder(tf.bool, name = "training_flag") + self.init_w = tf.contrib.layers.xavier_initializer() + self.regu_w = tf.contrib.layers.l2_regularizer(scale=0.8) + self.device = kwargs.get("device", '/device:GPU:0') + + self.train_loss_D = [] + self.train_loss_G = [] + self.valid_loss_D = [] + self.valid_loss_G = [] + self.training_time = 0.0 + + + def sample_z(self, batch_size, z_dim): + """ + sample the standard normal noises + """ + return np.random.normal(0.0, scale = 1.0, size = (batch_size, z_dim)) + + def sample_data(self, data, batch_size): + """ + sample data from AnnData datatype + """ + lower = np.random.randint(0, data.shape[0] - batch_size) + upper = lower + batch_size + if sparse.issparse(data.X): + x_mb = data[lower:upper, :].X.A + else: + x_mb = data[lower:upper, :].X + return x_mb + + + def sample_data_np(self, data, batch_size): + """ + sample data from numpy array datatype + """ + lower = np.random.randint(0, data.shape[0] - batch_size) + upper = lower + batch_size + + return data[lower:upper] + + def sample_data_cond(self, data, cond, batch_size): + """ + sample data from AnnData datatype along with its labels (numpy array) + """ + assert data.shape[0] == cond.shape[0] + lower = np.random.randint(0, data.shape[0] - batch_size) + upper = lower + batch_size + + if sparse.issparse(data.X): + x_mb = data[lower:upper, :].X.A + else: + x_mb = data[lower:upper, :].X + + cond_mb = cond[lower:upper, :] + + return x_mb, cond_mb + + def sample_data_cond_np(self, data, cond, batch_size): + """ + sample data from numpy array datatype along with its labels (numpy array) + """ + assert data.shape[0] == cond.shape[0] + + lower = np.random.randint(0, data.shape[0] - batch_size) + upper = lower + batch_size + + return data[lower:upper, :], cond[lower:upper, :] + + + +class WGAN_GP(BaseGAN): + + """ + Wasserstein GAN 
with gradient penalty (WGAN-GP) + """ + def __init__(self, x_dimension, z_dimension = 10, **kwargs): + super().__init__(x_dimension, z_dimension, **kwargs) + + with tf.device(self.device): + + self.x = tf.placeholder(tf.float32, shape = [None, self.x_dim], name = "data") + self.z = tf.placeholder(tf.float32, shape = [None, self.z_dim], name = "latent") + self.create_network() + self.loss_function() + + config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True) + config.gpu_options.per_process_gpu_memory_fraction = 0.6 + self.sess = tf.Session(config = config) + self.saver = tf.train.Saver(max_to_keep = 1) + self.init = tf.global_variables_initializer().run(session = self.sess) + + + def generatorDropOut(self): + """ + generator with dropout layers of WGAN-GP + """ + with tf.variable_scope('generatorDropOut', reuse = tf.AUTO_REUSE): + + ge_dense1 = tf.layers.dense(inputs = self.z, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + ge_dense1 = tf.layers.batch_normalization(ge_dense1, training = self.is_training) + ge_dense1 = tf.nn.leaky_relu(ge_dense1) + ge_dense1 = tf.layers.dropout(ge_dense1, self.dropout_rate, training=self.is_training) + + ge_dense2 = tf.layers.dense(inputs = ge_dense1, units = self.inflate_to_size2, activation=None, + kernel_initializer = self.init_w) + ge_dense2 = tf.layers.batch_normalization(ge_dense2, training = self.is_training) + ge_dense2 = tf.nn.leaky_relu(ge_dense2) + ge_dense2 = tf.layers.dropout(ge_dense2, self.dropout_rate, training=self.is_training) + + ge_dense3 = tf.layers.dense(inputs = ge_dense2, units = self.inflate_to_size3, activation=None, + kernel_initializer = self.init_w) + ge_dense3 = tf.layers.batch_normalization(ge_dense3, training = self.is_training) + ge_dense3 = tf.nn.relu(ge_dense3) + ge_dense3 = tf.layers.dropout(ge_dense3, self.dropout_rate, training=self.is_training) + + ge_output = tf.layers.dense(inputs = ge_dense3, units= self.x_dim, activation=None) + + return ge_output + + def generator(self): + """ + generator without dropout layers of WGAN-GP + """ + with tf.variable_scope('generator', reuse = tf.AUTO_REUSE): + ge_dense1 = tf.layers.dense(inputs = self.z, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + ge_dense1 = tf.layers.batch_normalization(ge_dense1, training = self.is_training) + ge_dense1 = tf.nn.leaky_relu(ge_dense1) + + ge_dense2 = tf.layers.dense(inputs = ge_dense1, units = self.inflate_to_size2, activation=None, + kernel_initializer = self.init_w) + ge_dense2 = tf.layers.batch_normalization(ge_dense2, training = self.is_training) + ge_dense2 = tf.nn.leaky_relu(ge_dense2) + + ge_dense3 = tf.layers.dense(inputs = ge_dense2, units = self.inflate_to_size3, activation=None, + kernel_initializer = self.init_w) + ge_dense3 = tf.layers.batch_normalization(ge_dense3, training = self.is_training) + ge_dense3 = tf.nn.relu(ge_dense3) + + ge_output = tf.layers.dense(inputs = ge_dense3, units= self.x_dim, activation=None) + + return ge_output + + def discriminator(self, x_input): + """ + discriminator of WGAN-GP + """ + with tf.variable_scope('discriminator', reuse = tf.AUTO_REUSE): + disc_dense1 = tf.layers.dense(inputs= x_input, units= self.disc_internal_size1, activation = None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense1 = tf.layers.batch_normalization(disc_dense1, training = self.is_training) + disc_dense1 = tf.nn.leaky_relu(disc_dense1) + + disc_dense2 = tf.layers.dense(inputs = disc_dense1, 
units= self.disc_internal_size2, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense2 = tf.layers.batch_normalization(disc_dense2, training = self.is_training) + disc_dense2 = tf.nn.leaky_relu(disc_dense2) + + disc_dense3 = tf.layers.dense(inputs=disc_dense2, units= self.disc_internal_size3, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense3 = tf.layers.batch_normalization(disc_dense3, training = self.is_training) + disc_dense3 = tf.nn.relu(disc_dense3) + + disc_output = tf.layers.dense(inputs=disc_dense3, units=1,activation=None) + return disc_output, disc_dense3 + + + def Q_mutual_info_network(self, disc_dense3): + """ + Q network of Q(C|X) to get H(C|G(C)) = E_{X = G(C), C~P(C)} [logQ(C|X)] + """ + with tf.variable_scope('mutual_info_bound', reuse = tf.AUTO_REUSE): + q_dense1 = tf.layers.dense(inputs = disc_dense3, units = self.disc_internal_size3, activation = None, + kernel_initializer = self.init_w) + q_dense1 = tf.layers.batch_normalization(q_dense1, training = self.is_training) + q_dense1 = tf.nn.leaky_relu(q_dense1) + + q_output = tf.layers.dense(inputs=q_dense1, units= (self.z_dim if self.fix_std else self.z_dim * 2), activation=None) + + return q_output + + def c_mutual_sample(self, c_vector): + """ + function to sample the C from q(C|X). For now, we only consider continuous + representations/latent variables + """ + + if self.fix_std: + mean_vec = c_vector + std_vec = tf.ones_like(mean_vec) + else: + mean_vec = c_vector[:, :self.z_dim] + std_vec = c_vector[:, self.z_dim:(self.z_dim * 2)] + std_vec = tf.nn.softplus(std_vec) + dist_c_vector = ds.Normal(mean_vec, std_vec) + c_gen = dist_c_vector.sample(self.sample_c) + c_gen = tf.reshape(c_gen,tf.shape(c_gen)[1:]) + return c_gen + + def create_network(self): + """ + construct the WGAN-GP networks + """ + if self.if_dropout: + self.x_gen_data = self.generatorDropOut() + else: + self.x_gen_data = self.generator() + + self.Dx_real, self.Dx_real_hidden = self.discriminator(self.x) + self.Dx_fake, self.Dx_fake_hidden = self.discriminator(self.x_gen_data) + self.c_mutual_fake = self.Q_mutual_info_network(self.Dx_fake_hidden) + self.c_mutual_real = self.Q_mutual_info_network(self.Dx_real_hidden) + self.c_gen_fake = self.c_mutual_sample(self.c_mutual_fake) + self.c_gen_real = self.c_mutual_sample(self.c_mutual_real) + + def compute_gp(self, x, x_gen_data, discriminator): + """ + gradient penalty of discriminator + """ + epsilon_x = tf.random_uniform([], 0.0, 1.0) + x_hat = x * epsilon_x + (1 - epsilon_x) * x_gen_data + + d_hat, _ = discriminator(x_hat) + gradients = tf.gradients(d_hat, x_hat)[0] + + slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) + gradient_penalty = tf.reduce_mean((slopes - 1.0) ** 2) + + return gradient_penalty + + def loss_function(self): + """ + loss function of WGAN-GP + """ + + D_raw_loss = tf.reduce_mean(self.Dx_real) - tf.reduce_mean(self.Dx_fake) + self.G_loss = tf.reduce_mean(self.Dx_fake) + self.gradient_penalty = self.compute_gp(self.x, self.x_gen_data, self.discriminator) + self.D_loss = D_raw_loss + self.lamb_gp * self.gradient_penalty + + tf_vars_all = tf.trainable_variables() + + dvars = [var for var in tf_vars_all if var.name.startswith("discriminator")] + + if self.if_dropout: + gvars = [var for var in tf_vars_all if var.name.startswith("generatorDropOut")] + else: + gvars = [var for var in tf_vars_all if var.name.startswith("generator")] + + self.parameter_count = tf.reduce_sum( 
+ [tf.reduce_prod(tf.shape(v)) for v in dvars + gvars]) + + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + self.g_solver = Adam_Prediction_Optimizer(learning_rate = self.learning_rate, + beta1=0.9, beta2=0.999, prediction=True).minimize(self.G_loss, var_list = gvars) + self.d_solver = Adam_Prediction_Optimizer(learning_rate = self.learning_rate, + beta1=0.9, beta2=0.999, prediction=False).minimize(self.D_loss, var_list = dvars) + + @property + def model_parameter(self): + """ + report the number of training parameters + """ + self.total_param = self.sess.run(self.parameter_count) + return "There are {} parameters in WGAN-GP.".format(self.total_param) + + def generate_cells(self, z_data): + """ + generate data from latent samples + """ + gen_data = self.sess.run(self.x_gen_data, feed_dict = {self.z: z_data, self.is_training: False}) + return gen_data + + + def restore_model(self, model_path): + """ + restore model from model_path + """ + self.saver.restore(self.sess, model_path) + + + def save_model(self, model_save_path, epoch): + """ + save the trained model to the model_save_path + """ + os.makedirs(model_save_path, exist_ok = True) + model_save_name = os.path.join(model_save_path, "model") + save_path = self.saver.save(self.sess, model_save_name, global_step = epoch) + + np.save(os.path.join(model_save_path, "training_time.npy"), self.training_time) + np.save(os.path.join(model_save_path, "train_loss_D.npy"), self.train_loss_D) + np.save(os.path.join(model_save_path, "train_loss_G.npy"), self.train_loss_G) + np.save(os.path.join(model_save_path, "valid_loss_D.npy"), self.valid_loss_D) + np.save(os.path.join(model_save_path, "valid_loss_G.npy"), self.valid_loss_G) + + def train(self, train_data, use_validation = False, valid_data = None, use_test_during_train = False, test_data = None, + test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, + verbose = False): + """ + train WGAN-GP with train_data (AnnData) and optional valid_data (numpy array) for n_epochs. 
+ """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception('valid_data is None but use_validation is True.') + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + if sparse.issparse(train_data.X): + pca_data_fit = pca_data_50.fit(train_data.X.A) + else: + pca_data_fit = pca_data_50.fit(train_data.X) + + train_data_copy = train_data.copy() + for epoch in range(1, n_epochs + 1): + begin = time.time() + + if shuffle: + train_data = shuffle_adata(train_data) + train_loss_D, train_loss_G = 0.0, 0.0 + valid_loss_D, valid_loss_G = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + + # D step + for _ in range(self.Diters): + x_mb = self.sample_data(train_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + self.sess.run(self.d_solver, feed_dict = {self.x: x_mb, self.z: z_mb, self.is_training: self.if_BNTrainingMode}) + + # G step + x_mb = self.sample_data(train_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + _, current_loss_D, current_loss_G = self.sess.run([self.g_solver, self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.is_training: self.if_BNTrainingMode}) + + train_loss_D += (current_loss_D * batch_size) + train_loss_G += (current_loss_G * batch_size) + + train_loss_D /= n_train + train_loss_G /= n_train + + if use_validation: + for _ in range(1, n_valid // batch_size + 1): + x_mb = self.sample_data(valid_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + + current_loss_valid_D, current_loss_valid_G = self.sess.run([self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.is_training: False}) + + valid_loss_D += current_loss_valid_D + valid_loss_G += current_loss_valid_G + + valid_loss_D /= n_valid + valid_loss_G /= n_valid + + self.train_loss_D.append(train_loss_D) + self.train_loss_G.append(train_loss_G) + self.valid_loss_D.append(valid_loss_D) + self.valid_loss_G.append(valid_loss_G) + self.training_time += (time.time() - begin) + + # testing for generation metrics + if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + sampled_indices = sample(range(n_train), test_size) + + if sparse.issparse(train_data_copy.X): + test_data = train_data_copy[sampled_indices, :].X.A + else: + test_data = train_data_copy[sampled_indices, :].X + + z_data = self.sample_z(test_size, self.z_dim) + gen_data = self.generate_cells(z_data) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + z_data = self.sample_z(test_size, self.z_dim) + gen_data = self.generate_cells(z_data) + + if inception_score_data is not None: + inception_score_subdata = 
inception_score_data + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + + + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + if save: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + + if reset_test_data: + test_data = None + + if verbose: + print(f"Epoch {epoch}: D Train Loss: {train_loss_D} G Train Loss: {train_loss_G} D Valid Loss: {valid_loss_D} G Valid Loss: {valid_loss_G}") + + # early stopping + if use_validation and epoch > 1: + if abs(self.valid_loss_D[epoch - 2] - self.valid_loss_D[epoch - 1]) > min_delta or abs(self.valid_loss_G[epoch - 2] - self.valid_loss_G[epoch - 1]) > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training finished.") + + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + + + def train_np(self, train_data, use_validation = False, valid_data = None, use_test_during_train = False, test_data = None, + test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, + verbose = False): + """ + train WGAN-GP with train_data (numpy array) and optional valid_data (numpy array) for n_epochs. 
+ """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception('valid_data is None but use_validation is True.') + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + pca_data_fit = pca_data_50.fit(train_data) + + if shuffle: + index_shuffle = list(range(n_train)) + + for epoch in range(1, n_epochs + 1): + + begin = time.time() + + if shuffle: + np.random.shuffle(index_shuffle) + train_data = train_data[index_shuffle] + if inception_score_data is not None: + inception_score_data = inception_score_data[index_shuffle] + + train_loss_D, train_loss_G = 0.0, 0.0 + valid_loss_D, valid_loss_G = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + + # D step + for _ in range(self.Diters): + x_mb = self.sample_data_np(train_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + self.sess.run(self.d_solver, feed_dict = {self.x: x_mb, self.z: z_mb, self.is_training: self.if_BNTrainingMode}) + + # G step + x_mb = self.sample_data_np(train_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + _, current_loss_D, current_loss_G = self.sess.run([self.g_solver, self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.is_training: self.if_BNTrainingMode}) + + train_loss_D += (current_loss_D * batch_size) + train_loss_G += (current_loss_G * batch_size) + + train_loss_D /= n_train + train_loss_G /= n_train + + if use_validation: + for _ in range(1, n_valid // batch_size + 1): + x_mb = self.sample_data_np(valid_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + + current_loss_valid_D, current_loss_valid_G = self.sess.run([self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.is_training: False}) + + valid_loss_D += current_loss_valid_D + valid_loss_G += current_loss_valid_G + + valid_loss_D /= n_valid + valid_loss_G /= n_valid + + self.train_loss_D.append(train_loss_D) + self.train_loss_G.append(train_loss_G) + self.valid_loss_D.append(valid_loss_D) + self.valid_loss_G.append(valid_loss_G) + self.training_time += (time.time() - begin) + + # testing for generation metrics + if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + sampled_indices = sample(range(n_train), test_size) + + test_data = train_data[sampled_indices, :] + + z_data = self.sample_z(test_size, self.z_dim) + gen_data = self.generate_cells(z_data) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + z_data = self.sample_z(test_size, self.z_dim) + gen_data = self.generate_cells(z_data) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data + mean_is_real, std_is_real = 
genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + if save: + genmetrics_pd.to_csv(os.path.join(output_save_path, "GenerationMetrics.csv")) + + if reset_test_data: + test_data = None + + if verbose: + print(f"Epoch {epoch}: D Train Loss: {train_loss_D} G Train Loss: {train_loss_G} D Valid Loss: {valid_loss_D} G Valid Loss: {valid_loss_G}") + + # early stopping + if use_validation and epoch > 1: + if abs(self.valid_loss_D[epoch - 2] - self.valid_loss_D[epoch - 1]) > min_delta or abs(self.valid_loss_G[epoch - 2] - self.valid_loss_G[epoch - 1]) > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training finished.") + + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + + + \ No newline at end of file diff --git a/models/gaussianmixture.py b/models/gaussianmixture.py new file mode 100644 index 0000000..735a652 --- /dev/null +++ b/models/gaussianmixture.py @@ -0,0 +1,88 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +import numpy as np +import pandas as pd +from sklearn import mixture + +class GaussianMixtureModel: + """ + Gaussian mixture model + """ + def __init__(self): + super().__init__() + + def GMModel(self, n_components, covariance_type): + """ + sklearn GMM + """ + gmm = mixture.GaussianMixture(n_components = n_components, + covariance_type = covariance_type) + + return gmm + + def SelectTrain(self, train_data, use_validation = False, valid_data = None, + n_components_range = range(1, 101), + cv_types = ['spherical', 'tied', 'diag', 'full']): + """ + cross validation to select n_components and covariance_type + """ + bic_train = {} + bic_valid = {} + + for covariance_type in cv_types: + + bic_train[covariance_type] = [] + bic_valid[covariance_type] = [] + + for n_components in n_components_range: + + gmm = self.GMModel(n_components, covariance_type) + + try: + gmm.fit(train_data) + bic_train[covariance_type].append(gmm.bic(train_data)) + + if use_validation: + bic_valid[covariance_type].append(gmm.bic(valid_data)) + + except: + bic_train[covariance_type].append(0) + + if use_validation: + bic_valid[covariance_type].append(0) + + + + + train_loss_df = pd.DataFrame(bic_train) + valid_loss_df = pd.DataFrame(bic_valid) + + return train_loss_df, valid_loss_df + + + def fit(self, train_data, n_components, covariance_type, use_validation = False, valid_data = None): + """ + fit GMM model with training data + """ 
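+        # Illustrative only: `n_components` and `covariance_type` would typically be
+        # chosen first with SelectTrain(), which tabulates BIC (lower is better) over
+        # the candidate settings. A minimal usage sketch, with placeholder variable
+        # names (`gm`, `latents`, `valid_latents`) that are not part of this repo:
+        #   gm = GaussianMixtureModel()
+        #   bic_train, bic_valid = gm.SelectTrain(latents, use_validation=True, valid_data=valid_latents)
+        #   gmm, train_bic, valid_bic = gm.fit(latents, n_components=10, covariance_type='full',
+        #                                      use_validation=True, valid_data=valid_latents)
+        #   samples, labels = gm.reconstruct(gmm, latents)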
+ gmm = self.GMModel(n_components, covariance_type) + gmm.fit(train_data) + + + train_loss = gmm.bic(train_data) + valid_loss = 0.0 + if use_validation: + valid_loss = gmm.bic(valid_data) + + return gmm, train_loss, valid_loss + + def reconstruct(self, gmm, x_data): + """ + generate data from GMM + """ + + x_rec_data, rec_label = gmm.sample(x_data.shape[0]) + + return x_rec_data, rec_label + + diff --git a/models/infogan.py b/models/infogan.py new file mode 100644 index 0000000..5a043df --- /dev/null +++ b/models/infogan.py @@ -0,0 +1,589 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import os +import time +import logging + +import numpy as np +from scipy import sparse +import pandas as pd + +import tensorflow as tf +from tensorflow import distributions as ds + +from .util import * +from .Adam_prediction import Adam_Prediction_Optimizer +from .gan import BaseGAN + +from metrics.GenerationMetrics import * +log = logging.getLogger(__file__) + + + +class InfoWGAN_GP(BaseGAN): + + """ + Information Maximizing Wasserstein GAN with gradient penalty (InfoWGAN-GP) + """ + def __init__(self, x_dimension, z_dimension = 10, noise_dimension = 118, **kwargs): + + super().__init__(x_dimension, z_dimension, **kwargs) + + self.n_dim = noise_dimension + self.noise = tf.placeholder(tf.float32, shape = [None, self.n_dim], name = "latent_noise") + self.lamb_mi_fake = kwargs.get("lamb_mi_fake", 1.0) + + with tf.device(self.device): + + self.x = tf.placeholder(tf.float32, shape = [None, self.x_dim], name = "data") + self.z = tf.placeholder(tf.float32, shape = [None, self.z_dim], name = "latent") + self.noise = tf.placeholder(tf.float32, shape = [None, self.n_dim], name = "latent_noise") + + self.create_network() + self.loss_function() + + config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True) + config.gpu_options.per_process_gpu_memory_fraction = 0.6 + self.sess = tf.Session(config = config) + self.saver = tf.train.Saver(max_to_keep = 1) + self.init = tf.global_variables_initializer().run(session = self.sess) + + + def generatorDropOut(self): + """ + generator with dropout layers of InfoWGAN-GP + """ + with tf.variable_scope('generatorDropOut', reuse = tf.AUTO_REUSE): + + znoise = tf.concat([self.z, self.noise], axis = 1) + + ge_dense1 = tf.layers.dense(inputs = znoise, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + ge_dense1 = tf.layers.batch_normalization(ge_dense1, training = self.is_training) + ge_dense1 = tf.nn.leaky_relu(ge_dense1) + ge_dense1 = tf.layers.dropout(ge_dense1, self.dropout_rate, training=self.is_training) + + ge_dense2 = tf.layers.dense(inputs = ge_dense1, units = self.inflate_to_size2, activation=None, + kernel_initializer = self.init_w) + ge_dense2 = tf.layers.batch_normalization(ge_dense2, training = self.is_training) + ge_dense2 = tf.nn.leaky_relu(ge_dense2) + ge_dense2 = tf.layers.dropout(ge_dense2, self.dropout_rate, training=self.is_training) + + ge_dense3 = tf.layers.dense(inputs = ge_dense2, units = self.inflate_to_size3, activation=None, + kernel_initializer = self.init_w) + ge_dense3 = tf.layers.batch_normalization(ge_dense3, training = self.is_training) + ge_dense3 = tf.nn.relu(ge_dense3) + ge_dense3 = tf.layers.dropout(ge_dense3, self.dropout_rate, training=self.is_training) + + ge_output = tf.layers.dense(inputs = ge_dense3, units= self.x_dim, activation=None) + + return ge_output + + def generator(self): + """ + generator without dropout layers of InfoWGAN-GP + """ + with tf.variable_scope('generator', reuse = 
tf.AUTO_REUSE): + + znoise = tf.concat([self.z, self.noise], axis = 1) + + ge_dense1 = tf.layers.dense(inputs = znoise, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + ge_dense1 = tf.layers.batch_normalization(ge_dense1, training = self.is_training) + ge_dense1 = tf.nn.leaky_relu(ge_dense1) + + ge_dense2 = tf.layers.dense(inputs = ge_dense1, units = self.inflate_to_size2, activation=None, + kernel_initializer = self.init_w) + ge_dense2 = tf.layers.batch_normalization(ge_dense2, training = self.is_training) + ge_dense2 = tf.nn.leaky_relu(ge_dense2) + + ge_dense3 = tf.layers.dense(inputs = ge_dense2, units = self.inflate_to_size3, activation=None, + kernel_initializer = self.init_w) + ge_dense3 = tf.layers.batch_normalization(ge_dense3, training = self.is_training) + ge_dense3 = tf.nn.relu(ge_dense3) + + ge_output = tf.layers.dense(inputs = ge_dense3, units= self.x_dim, activation=None) + + return ge_output + + + def discriminatorInfo(self, x_input): + """ + discriminator of InfoWGAN-GP + """ + with tf.variable_scope('discriminatorInfo', reuse = tf.AUTO_REUSE): + + disc_dense1 = tf.layers.dense(inputs= x_input, units= self.disc_internal_size1, activation = None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense1 = tf.layers.batch_normalization(disc_dense1, training = self.is_training) + disc_dense1 = tf.nn.leaky_relu(disc_dense1) + + disc_dense2 = tf.layers.dense(inputs = disc_dense1, units= self.disc_internal_size2, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense2 = tf.layers.batch_normalization(disc_dense2, training = self.is_training) + disc_dense2 = tf.nn.leaky_relu(disc_dense2) + + disc_dense3 = tf.layers.dense(inputs=disc_dense2, units= self.disc_internal_size3, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense3 = tf.layers.batch_normalization(disc_dense3, training = self.is_training) + disc_dense3 = tf.nn.relu(disc_dense3) + + disc_output = tf.layers.dense(inputs=disc_dense3, units= 1, activation=None, + kernel_initializer = self.init_w) + + # Q network part + q_dense1 = tf.layers.dense(inputs = disc_dense3, units = self.disc_internal_size3, activation = None, + kernel_initializer = self.init_w) + q_dense1 = tf.layers.batch_normalization(q_dense1, training = self.is_training) + q_dense1 = tf.nn.leaky_relu(q_dense1) + + q_output = tf.layers.dense(inputs = q_dense1, + units= (self.z_dim if self.fix_std else self.z_dim * 2), + activation = None) + + + return disc_output, q_output + + def c_mutual_info(self, c_vector, z_sample): + """ + function to compute the mutual information lower bound + """ + c_sample = z_sample[:, :self.z_dim] + + if self.fix_std: + mean_vec = c_vector + std_vec = tf.ones_like(mean_vec) + else: + mean_vec = c_vector[:, self.z_dim] + std_vec = c_vector[:, self.z_dim:(self.z_dim * 2)] + std_vec = tf.nn.softplus(std_vec) + dist_c_vector = ds.Normal(mean_vec, std_vec) + + ll_logp = dist_c_vector.log_prob(c_sample) + ll_logp_sum = tf.reduce_sum(ll_logp, [1]) + + mi_bound = tf.reduce_mean(ll_logp_sum) + + return mi_bound + + def create_network(self): + """ + construct the InfoWGAN-GP networks + """ + + if self.if_dropout: + self.x_gen_data = self.generatorDropOut() + else: + self.x_gen_data = self.generator() + + self.Dx_real, self.c_mutual_real = self.discriminatorInfo(self.x) + self.Dx_fake, self.c_mutual_fake = self.discriminatorInfo(self.x_gen_data) + + self.q_vector = 
self.c_mutual_info(self.c_mutual_fake, self.z) + self.q_fake_mutual = tf.reduce_mean(self.q_vector) + + def compute_gp(self, x, x_gen_data, discriminator): + """ + gradient penalty of discriminator + """ + epsilon_x = tf.random_uniform([], 0.0, 1.0) + x_hat = x * epsilon_x + (1 - epsilon_x) * x_gen_data + d_hat, _ = discriminator(x_hat) + + gradients = tf.gradients(d_hat, x_hat)[0] + slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) + + gradient_penalty = tf.reduce_mean((slopes - 1.0) ** 2) + + return gradient_penalty + + def loss_function(self): + """ + loss function of InfoWGAN-GP + """ + + D_raw_loss = tf.reduce_mean(self.Dx_real) - tf.reduce_mean(self.Dx_fake) + self.G_loss = tf.reduce_mean(self.Dx_fake) - (self.q_fake_mutual * self.lamb_mi_fake) + + self.gradient_penalty = self.compute_gp(self.x, self.x_gen_data, self.discriminatorInfo) + self.D_loss = D_raw_loss + self.lamb_gp * self.gradient_penalty - (self.q_fake_mutual * self.lamb_mi_fake) + + tf_vars_all = tf.trainable_variables() + dvars = [var for var in tf_vars_all if var.name.startswith("discriminatorInfo")] + + if self.if_dropout: + gvars = [var for var in tf_vars_all if var.name.startswith("generatorDropOut")] + else: + gvars = [var for var in tf_vars_all if var.name.startswith("generator")] + + self.parameter_count = tf.reduce_sum( + [tf.reduce_prod(tf.shape(v)) for v in dvars + gvars]) + + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + self.g_solver = Adam_Prediction_Optimizer(learning_rate = self.learning_rate, + beta1=0.9, beta2=0.999, prediction=True).minimize(self.G_loss, var_list = gvars) + self.d_solver = Adam_Prediction_Optimizer(learning_rate = self.learning_rate, + beta1=0.9, beta2=0.999, prediction=False).minimize(self.D_loss, var_list = dvars) + + @property + def model_parameter(self): + """ + report the number of training parameters + """ + self.total_param = self.sess.run(self.parameter_count) + return "There are {} parameters in InfoWGAN-GP.".format(self.total_param) + + + def generate_cells(self, z_data): + """ + generate data from latent samples + """ + noise_data = self.sample_z(len(z_data), self.n_dim) + gen_data = self.sess.run(self.x_gen_data, feed_dict = {self.z: z_data, + self.noise: noise_data, self.is_training: False}) + return gen_data + + + def restore_model(self, model_path): + """ + restore model from model_path + """ + self.saver.restore(self.sess, model_path) + + + def save_model(self, model_save_path, epoch): + """ + save the trained model to the model_save_path + """ + os.makedirs(model_save_path, exist_ok = True) + model_save_name = os.path.join(model_save_path, "model") + save_path = self.saver.save(self.sess, model_save_name, global_step = epoch) + + np.save(os.path.join(model_save_path, "training_time.npy"), self.training_time) + np.save(os.path.join(model_save_path, "train_loss_D.npy"), self.train_loss_D) + np.save(os.path.join(model_save_path, "train_loss_G.npy"), self.train_loss_G) + np.save(os.path.join(model_save_path, "valid_loss_D.npy"), self.valid_loss_D) + np.save(os.path.join(model_save_path, "valid_loss_G.npy"), self.valid_loss_G) + + def train(self, train_data, use_validation = False, valid_data = None, use_test_during_train = False, test_data = None, + test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, + verbose = False): + """ + train 
InfoWGAN-GP with train_data (AnnData) and optional valid_data (numpy array) for n_epochs. + """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception('valid_data is None but use_validation is True.') + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + if sparse.issparse(train_data.X): + pca_data_fit = pca_data_50.fit(train_data.X.A) + else: + pca_data_fit = pca_data_50.fit(train_data.X) + + train_data_copy = train_data.copy() + for epoch in range(1, n_epochs + 1): + begin = time.time() + + if shuffle: + train_data = shuffle_adata(train_data) + train_loss_D, train_loss_G = 0.0, 0.0 + valid_loss_D, valid_loss_G = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + + # D step + for _ in range(self.Diters): + x_mb = self.sample_data(train_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + self.sess.run(self.d_solver, feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, self.is_training: self.if_BNTrainingMode}) + + # G step + x_mb = self.sample_data(train_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + _, current_loss_D, current_loss_G = self.sess.run([self.g_solver, self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, self.is_training: self.if_BNTrainingMode}) + + train_loss_D += (current_loss_D * batch_size) + train_loss_G += (current_loss_G * batch_size) + + train_loss_D /= n_train + train_loss_G /= n_train + + if use_validation: + for _ in range(1, n_valid // batch_size + 1): + x_mb = self.sample_data(valid_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + current_loss_valid_D, current_loss_valid_G = self.sess.run([self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, self.is_training: False}) + + valid_loss_D += current_loss_valid_D + valid_loss_G += current_loss_valid_G + + valid_loss_D /= n_valid + valid_loss_G /= n_valid + + self.train_loss_D.append(train_loss_D) + self.train_loss_G.append(train_loss_G) + self.valid_loss_D.append(valid_loss_D) + self.valid_loss_G.append(valid_loss_G) + self.training_time += (time.time() - begin) + + # testing for generation metrics + if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + sampled_indices = sample(range(n_train), test_size) + + if sparse.issparse(train_data_copy.X): + test_data = train_data_copy[sampled_indices, :].X.A + else: + test_data = train_data_copy[sampled_indices, :].X + + z_data = self.sample_z(test_size, self.z_dim) + gen_data = self.generate_cells(z_data) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = 
mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + z_data = self.sample_z(test_size, self.z_dim) + gen_data = self.generate_cells(z_data) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + + + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + if save: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + + if reset_test_data: + test_data = None + + if verbose: + print(f"Epoch {epoch}: D Train Loss: {train_loss_D} G Train Loss: {train_loss_G} D Valid Loss: {valid_loss_D} G Valid Loss: {valid_loss_G}") + + # early stopping + if use_validation and epoch > 1: + if abs(self.valid_loss_D[epoch - 2] - self.valid_loss_D[epoch - 1]) > min_delta or abs(self.valid_loss_G[epoch - 2] - self.valid_loss_G[epoch - 1]) > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training finished.") + + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + + + def train_np(self, train_data, use_validation = False, valid_data = None, use_test_during_train = False, test_data = None, + test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, + verbose = False): + """ + train InfoWGAN-GP with train_data (numpy array) and optional valid_data (numpy array) for n_epochs. 
+ """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception('valid_data is None but use_validation is True.') + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + pca_data_fit = pca_data_50.fit(train_data) + + if shuffle: + index_shuffle = list(range(n_train)) + + for epoch in range(1, n_epochs + 1): + + begin = time.time() + + if shuffle: + np.random.shuffle(index_shuffle) + train_data = train_data[index_shuffle] + if inception_score_data is not None: + inception_score_data = inception_score_data[index_shuffle] + + train_loss_D, train_loss_G = 0.0, 0.0 + valid_loss_D, valid_loss_G = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + + # D step + for _ in range(self.Diters): + x_mb = self.sample_data_np(train_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + self.sess.run(self.d_solver, feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, + self.is_training: self.if_BNTrainingMode}) + + # G step + x_mb = self.sample_data_np(train_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + _, current_loss_D, current_loss_G = self.sess.run([self.g_solver, self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, self.is_training: self.if_BNTrainingMode}) + + train_loss_D += (current_loss_D * batch_size) + train_loss_G += (current_loss_G * batch_size) + + train_loss_D /= n_train + train_loss_G /= n_train + + if use_validation: + for _ in range(1, n_valid // batch_size + 1): + x_mb = self.sample_data_np(valid_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + + current_loss_valid_D, current_loss_valid_G = self.sess.run([self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, self.noise: n_mb, self.is_training: False}) + + valid_loss_D += current_loss_valid_D + valid_loss_G += current_loss_valid_G + + valid_loss_D /= n_valid + valid_loss_G /= n_valid + + self.train_loss_D.append(train_loss_D) + self.train_loss_G.append(train_loss_G) + self.valid_loss_D.append(valid_loss_D) + self.valid_loss_G.append(valid_loss_G) + self.training_time += (time.time() - begin) + + # testing for generation metrics + if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + sampled_indices = sample(range(n_train), test_size) + + test_data = train_data[sampled_indices, :] + + z_data = self.sample_z(test_size, self.z_dim) + gen_data = self.generate_cells(z_data) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + z_data = 
self.sample_z(test_size, self.z_dim) + gen_data = self.generate_cells(z_data) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + if save: + genmetrics_pd.to_csv(os.path.join(output_save_path, "GenerationMetrics.csv")) + + if reset_test_data: + test_data = None + + if verbose: + print(f"Epoch {epoch}: D Train Loss: {train_loss_D} G Train Loss: {train_loss_G} D Valid Loss: {valid_loss_D} G Valid Loss: {valid_loss_G}") + + # early stopping + if use_validation and epoch > 1: + if abs(self.valid_loss_D[epoch - 2] - self.valid_loss_D[epoch - 1]) > min_delta or abs(self.valid_loss_G[epoch - 2] - self.valid_loss_G[epoch - 1]) > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training finished.") + + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) diff --git a/models/michigan.py b/models/michigan.py new file mode 100644 index 0000000..4c0137e --- /dev/null +++ b/models/michigan.py @@ -0,0 +1,267 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import os +import time +import logging + +import numpy as np +from scipy import sparse +import pandas as pd + +from .cgan import * +from .ppca import * +from .vae import * + +log = logging.getLogger(__file__) + + +class MichiGAN: + """ + MichiGAN class: + DisentangleMethod is pre-trained PCA, VAE or beta-TCVAE + GenerationMethod is untrained GANs + """ + def __init__(self, DisentangleMethod, GenerationMethod): + super().__init__() + self.DisentangleMethod = DisentangleMethod + self.GenerationMethod = GenerationMethod + + + def train_mean(self, train_data, use_validation = False, valid_data = None, use_test_during_train = False, test_data = None, + test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, + verbose = False): + """ + train GenerationMethod with latent representation conditions of posterior means + train_data (AnnData) and optional valid_data (numpy array) for n_epochs. 
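+ A rough usage sketch (the conditional GAN class name and constructor are assumptions based on the
+ `from .cgan import *` import above; all argument values are illustrative):
+ >>> vae = BetaTCVAE(num_cells_train = adata.shape[0], x_dimension = adata.shape[1], beta = 1.0)
+ >>> vae.train(adata, n_epochs = 100)                       # or restore a previously trained model
+ >>> gan = CWGAN_GP(x_dimension = adata.shape[1], z_dimension = 10)
+ >>> michigan = MichiGAN(DisentangleMethod = vae, GenerationMethod = gan)
+ >>> michigan.train_mean(adata, n_epochs = 100, batch_size = 32)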
+ """ + if sparse.issparse(train_data.X): + train_dense = train_data.X.A + else: + train_dense = train_data.X + + train_cond = self.DisentangleMethod.encode_mean(train_dense) + + if use_validation: + if sparse.issparse(valid_data.X): + valid_dense = valid_data.X.A + else: + valid_dense = valid_data.X + valid_cond = self.DisentangleMethod.encode_mean(valid_dense) + else: + valid_cond = None + + if test_data is not None: + test_cond = self.DisentangleMethod.encode_mean(test_cond) + else: + test_cond = None + + self.GenerationMethod.train(train_data = train_data, train_cond = train_cond, use_validation = use_validation, + valid_data = valid_data, valid_cond = valid_cond, use_test_during_train = use_test_during_train, + test_data = test_data, test_cond = test_cond, test_every_n_epochs = test_every_n_epochs, + test_size = test_size, inception_score_data = inception_score_data, n_epochs = n_epochs, + batch_size = batch_size, early_stop_limit = early_stop_limit, threshold = threshold, + shuffle = shuffle, save = save, model_save_path = model_save_path, output_save_path = output_save_path, + verbose = verbose) + + def train_mean_np(self, train_data, use_validation = False, valid_data = None, use_test_during_train = False, test_data = None, + test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, + verbose = False): + """ + train GenerationMethod with latent representation conditions of posterior means + train_data (numpy array) and optional valid_data (numpy array) for n_epochs. + """ + + train_cond = self.DisentangleMethod.encode_mean(train_data) + + if use_validation: + valid_cond = self.DisentangleMethod.encode_mean(valid_data) + else: + valid_cond = None + + if test_data is not None: + test_cond = self.DisentangleMethod.encode_mean(test_data) + else: + test_cond = None + + self.GenerationMethod.train_np(train_data = train_data, train_cond = train_cond, use_validation = use_validation, + valid_data = valid_data, valid_cond = valid_cond, use_test_during_train = use_test_during_train, + test_data = test_data, test_cond = test_cond, test_every_n_epochs = test_every_n_epochs, + test_size = test_size, inception_score_data = inception_score_data, n_epochs = n_epochs, + batch_size = batch_size, early_stop_limit = early_stop_limit, threshold = threshold, + shuffle = shuffle, save = save, model_save_path = model_save_path, output_save_path = output_save_path, + verbose = verbose) + + def train_postvars_np(self, train_data, use_validation = False, valid_data = None, use_test_during_train = False, test_data = None, + test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, + verbose = False): + """ + train GenerationMethod with latent representation conditions of posterior samples + train_data (numpy array) and optional valid_data (numpy array) for n_epochs. 
+ """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception('valid_data is None but use_validation is True.') + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + pca_data_fit = pca_data_50.fit(train_data) + + if shuffle: + index_shuffle = list(range(n_train)) + + for epoch in range(1, n_epochs + 1): + begin = time.time() + + if shuffle: + np.random.shuffle(index_shuffle) + train_data = train_data[index_shuffle] + + if inception_score_data is not None: + inception_score_data = inception_score_data[index_shuffle] + + train_loss_D, train_loss_G = 0.0, 0.0 + valid_loss_D, valid_loss_G = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + + # D step + for _ in range(self.GenerationMethod.Diters): + x_mb = self.GenerationMethod.sample_data_np(train_data, batch_size) + z_mb = self.DisentangleMethod.encode(x_mb) + n_mb = self.GenerationMethod.sample_z(batch_size, self.GenerationMethod.n_dim) + self.GenerationMethod.sess.run(self.GenerationMethod.d_solver, + feed_dict = {self.GenerationMethod.x: x_mb, + self.GenerationMethod.z: z_mb, + self.GenerationMethod.noise: n_mb, + self.GenerationMethod.is_training: self.GenerationMethod.if_BNTrainingMode}) + + # G step + x_mb = self.GenerationMethod.sample_data_np(train_data, batch_size) + z_mb = self.DisentangleMethod.encode(x_mb) + n_mb = self.GenerationMethod.sample_z(batch_size, self.GenerationMethod.n_dim) + _, current_loss_D, current_loss_G = self.GenerationMethod.sess.run( + [self.GenerationMethod.g_solver, self.GenerationMethod.D_loss, self.GenerationMethod.G_loss], + feed_dict = {self.GenerationMethod.x: x_mb, + self.GenerationMethod.z: z_mb, + self.GenerationMethod.noise: n_mb, + self.GenerationMethod.is_training: self.GenerationMethod.if_BNTrainingMode}) + + train_loss_D += (current_loss_D * batch_size) + train_loss_G += (current_loss_G * batch_size) + + train_loss_D /= n_train + train_loss_G /= n_train + + if use_validation: + for _ in range(1, n_valid // batch_size + 1): + x_mb = self.GenerationMethod.sample_data_np(valid_data, batch_size) + z_mb = self.DisentangleMethod.encode(x_mb) + n_mb = self.GenerationMethod.sample_z(batch_size, self.GenerationMethod.n_dim) + + current_loss_valid_D, current_loss_valid_G = self.GenerationMethod.sess.run( + [self.GenerationMethod.D_loss, self.GenerationMethod.G_loss], + feed_dict = {self.GenerationMethod.x: x_mb, + self.GenerationMethod.z: z_mb, + self.GenerationMethod.noise: n_mb, self.GenerationMethod.is_training: False}) + + valid_loss_D += current_loss_valid_D + valid_loss_G += current_loss_valid_G + + valid_loss_D /= n_valid + valid_loss_G /= n_valid + + self.GenerationMethod.train_loss_D.append(train_loss_D) + self.GenerationMethod.train_loss_G.append(train_loss_G) + self.GenerationMethod.valid_loss_D.append(valid_loss_D) + self.GenerationMethod.valid_loss_G.append(valid_loss_G) + self.GenerationMethod.training_time += (time.time() - begin) + + # testing for generation metrics + if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + 
sampled_indices = sample(range(n_train), test_size) + + test_data = train_data[sampled_indices, :] + test_cond = self.DisentangleMethod.encode(test_data) + + gen_data = self.GenerationMethod.generate_cells(test_cond) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + test_cond = self.DisentangleMethod.encode(test_data) + gen_data = self.GenerationMethod.generate_cells(test_cond) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + if save: + genmetrics_pd.to_csv(os.path.join(output_save_path, "GenerationMetrics.csv")) + + if reset_test_data: + test_data = None + test_cond = None + + if verbose: + print(f"Epoch {epoch}: D Train Loss: {train_loss_D} G Train Loss: {train_loss_G} D Valid Loss: {valid_loss_D} G Valid Loss: {valid_loss_G}") + + # early stopping + if use_validation and epoch > 1: + if abs(self.GenerationMethod.valid_loss_D[epoch - 2] - self.GenerationMethod.valid_loss_D[epoch - 1]) > min_delta or abs(self.GenerationMethod.valid_loss_G[epoch - 2] - self.GenerationMethod.valid_loss_G[epoch - 1]) > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. 
Training finished.") + + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) \ No newline at end of file diff --git a/models/ppca.py b/models/ppca.py new file mode 100644 index 0000000..6b15d0f --- /dev/null +++ b/models/ppca.py @@ -0,0 +1,199 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +import numpy as np +from scipy import linalg +import logging +import tensorflow as tf +from tensorflow import distributions as ds + +from sklearn.decomposition import PCA +from .util import tf_log, logsumexp, total_correlation, shuffle_adata + +log = logging.getLogger(__file__) + + +class ProbabilisticPCA: + """ + Probabilistic PCA object from the fitted PCA object, + self.z_mean is for the sampled representation using diagonized covariance matrix, + self.z_multivar_mean is with the fully-specified covariance matrix + """ + + def __init__(self, pca_fit): + super().__init__() + self.pca_fit = pca_fit + self.z_dim = self.pca_fit.n_components_ + self.Wmatrix = self.pca_fit.components_.T + self.pca_posterior() + self.sess = tf.Session() + self.mu = tf.placeholder(tf.float32, shape = (None, self.z_dim)) + self.z_multivar_mean = tf.placeholder(tf.float32, shape = (None, self.z_dim)) + self.pca_tensorflow() + + def pca_posterior(self): + """Compute the posterior mean and variance of principal components. + + Returns + ------- + Mmatrix : array, shape=(n_features, n_features) + M matrix + MintWT: the matrix to compute the posterior mean + post_var: posterior variance of the pca + """ + n_features = self.pca_fit.components_.shape[1] + + # handle corner cases first + if self.pca_fit.n_components_ == 0: + return np.eye(n_features) / self.pca_fit.noise_variance_ + if self.pca_fit.n_components_ == n_features: + return linalg.inv(self.pca_fit.get_covariance()) + + # Get precision using matrix inversion lemma + components_ = self.pca_fit.components_ + exp_var = self.pca_fit.explained_variance_ + if self.pca_fit.whiten: + components_ = components_ * np.sqrt(exp_var[:, np.newaxis]) + exp_var_diff = np.maximum(exp_var - self.pca_fit.noise_variance_, 0.) + + precision = np.dot(components_, components_.T) / self.pca_fit.noise_variance_ + precision.flat[::len(precision) + 1] += 1. 
/ exp_var_diff + + self.Mmatrix = precision.copy() + self.MinvWT = np.dot(linalg.inv(precision), components_) + self.post_var = precision/(self.pca_fit.noise_variance_**2) + + def encode_mean(self, x_data): + """ + encode data to the latent means + """ + Xr = x_data - self.pca_fit.mean_ + return np.dot(Xr, self.MinvWT.T) + + def log_prob_z_vector_post(self, z_vector): + """ + log probabilities of latent variables on given posterior normal distributions + """ + ll_con_dist = ds.Normal(self.mu, self.std) + con_gen = ll_con_dist.log_prob(z_vector) + + return con_gen + + def pca_sample(self, mean_tensor, std_value, sample_size = 1): + """ + sample the posterior latent samples for representation + """ + ll_dist = ds.Normal(mean_tensor, std_value) + ll_sample = ll_dist.sample(sample_size) + ll_sample = tf.reshape(ll_sample, tf.shape(ll_sample)[1:]) + + return ll_sample + + def pca_tensorflow(self): + """ + construct the PCA tensors + """ + + self.std = tf.convert_to_tensor(np.sqrt(self.post_var.diagonal())) + self.z_mean = self.pca_sample(self.mu, self.std) + self.z_marginal_entropy, self.z_joint_entropy = self.qz_entropies() + self.z_tc = total_correlation(self.z_marginal_entropy, self.z_joint_entropy) + + + def qz_entropies(self): + """ + estimate the large sample entropies of the q(Z) and q(Zj) + """ + + weights = - tf_log(tf.to_float(tf.shape(self.mu)[0])) + + function_to_map = lambda x: self.log_prob_z_vector_post(tf.reshape(x,[1, self.z_dim])) + logqz_i_m = tf.map_fn(function_to_map, self.z_mean, dtype = tf.float32) + logqz_i_margin = logsumexp(logqz_i_m + weights, dim = 1, keepdims = False) + logqz_value = tf.reduce_sum(logqz_i_m, axis = 2, keepdims = False) + logqz_v_joint = logsumexp(logqz_value + weights, dim = 1, keepdims = False) + logqz_sum = logqz_v_joint + logqz_i_sum = logqz_i_margin + + marginal_entropies = (- tf.reduce_mean(logqz_i_sum, axis = 0)) + joint_entropies = (- tf.reduce_mean(logqz_sum)) + + return marginal_entropies, joint_entropies + + def encode(self, x_data): + """ + encode data to the latent samples + """ + z_post_mean = self.encode_mean(x_data) + z_post_var = self.post_var + + + z_data = None + for t in range(len(z_post_mean)): + z_row = z_post_mean[t] + if z_data is None: + z_data = np.random.multivariate_normal(z_row, z_post_var).reshape((1, self.z_dim)) + else: + z_sample = np.random.multivariate_normal(z_row, z_post_var).reshape((1, self.z_dim)) + z_data = np.concatenate((z_data, z_sample), axis = 0) + + return z_data + + def decode(self, z_data): + """ + decode latent samples to reconstructed data + tensorflow makes the computation faster + """ + # x_dim = self.Wmatrix.shape[0] + + # mu_x = np.dot(z_data, self.Wmatrix.T) + self.pca_fit.mean_ + # std_x = self.pca_fit.noise_variance_**2 + # sigma_x = np.zeros((x_dim, x_dim)) + # np.fill_diagonal(sigma_x, std_x) + + # x_rec_data = None + # for t in range(len(mu_x)): + # mu_row = mu_x[t] + # if x_rec_data is None: + # x_sample = np.random.multivariate_normal(mu_row, sigma_x) + # x_rec_data = x_sample.reshape((1, x_dim)) + # else: + # x_sample = np.random.multivariate_normal(mu_row, sigma_x).reshape((1, x_dim)) + # x_rec_data = np.concatenate((x_rec_data, x_sample), axis = 0) + + # return x_rec_data + x_dim = self.Wmatrix.shape[0] + self.W = tf.convert_to_tensor(self.Wmatrix) + self.mean_x = tf.convert_to_tensor(self.pca_fit.mean_) + + self.mu_x = tf.tensordot(self.z_multivar_mean, tf.transpose(self.W), axes = 1) + self.mean_x + self.scale_x = tf.convert_to_tensor(np.repeat(self.pca_fit.noise_variance_, x_dim)) + 
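+ # tile the isotropic noise scale across the batch so pca_sample draws each reconstructed cell from an independent normal centered at mu_x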
+ self.sigma_x = tf.reshape(tf.tile(self.scale_x, [tf.shape(self.mu_x)[0]]), + [tf.shape(self.mu_x)[0], tf.shape(self.scale_x)[0]]) + + self.sigma_x = tf.cast(self.sigma_x, tf.float32) + self.x_rec = self.pca_sample(self.mu_x, self.sigma_x) + + feed_dict = {self.z_multivar_mean: z_data} + x_rec_data = self.sess.run(self.x_rec, feed_dict = feed_dict) + + return x_rec_data + + def reconstruct(self, x_data): + """ + reconstruct data from original data + """ + z_data = self.encode(x_data) + x_rec_data = self.decode(z_data) + + return x_rec_data + + + + + + + + + diff --git a/models/ssinfogan.py b/models/ssinfogan.py new file mode 100644 index 0000000..d5ea738 --- /dev/null +++ b/models/ssinfogan.py @@ -0,0 +1,592 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import os +import time +import logging + +import numpy as np +from scipy import sparse +import pandas as pd + +import tensorflow as tf +from tensorflow import distributions as ds + +from .util import * +from .Adam_prediction import Adam_Prediction_Optimizer +from .gan import BaseGAN + +from metrics.GenerationMetrics import * +log = logging.getLogger(__file__) + + +class ssInfoWGAN_GP(BaseGAN): + + """ + Semi-Supervised Information Maximizing Wasserstein GAN + with gradient penalty (ssInfoWGAN-GP) + """ + def __init__(self, x_dimension, z_dimension = 10, noise_dimension = 118, **kwargs): + + super().__init__(x_dimension, z_dimension, **kwargs) + self.n_dim = noise_dimension + self.noise = tf.placeholder(tf.float32, shape = [None, self.n_dim], name = "latent_noise") + self.lamb_mi_fake = kwargs.get("lamb_mi_fake", 1.0) + self.lamb_mi_real = kwargs.get("lamb_mi_real", 1.0) + + with tf.device(self.device): + + self.x = tf.placeholder(tf.float32, shape = [None, self.x_dim], name = "data") + self.z = tf.placeholder(tf.float32, shape = [None, self.z_dim], name = "latent") + self.y = tf.placeholder(tf.float32, shape = [None, self.z_dim], name = "latent_real") + self.noise = tf.placeholder(tf.float32, shape = [None, self.n_dim], name = "latent_noise") + + self.create_network() + self.loss_function() + + config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True) + config.gpu_options.per_process_gpu_memory_fraction = 0.6 + self.sess = tf.Session(config = config) + self.saver = tf.train.Saver(max_to_keep = 1) + self.init = tf.global_variables_initializer().run(session = self.sess) + + def generatorDropOut(self): + """ + generator with dropout layers of ssInfoWGAN-GP + """ + with tf.variable_scope('generatorDropOut', reuse = tf.AUTO_REUSE): + + znoise = tf.concat([self.z, self.noise], axis = 1) + + ge_dense1 = tf.layers.dense(inputs = znoise, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + ge_dense1 = tf.layers.batch_normalization(ge_dense1, training = self.is_training) + ge_dense1 = tf.nn.leaky_relu(ge_dense1) + ge_dense1 = tf.layers.dropout(ge_dense1, self.dropout_rate, training=self.is_training) + + ge_dense2 = tf.layers.dense(inputs = ge_dense1, units = self.inflate_to_size2, activation=None, + kernel_initializer = self.init_w) + ge_dense2 = tf.layers.batch_normalization(ge_dense2, training = self.is_training) + ge_dense2 = tf.nn.leaky_relu(ge_dense2) + ge_dense2 = tf.layers.dropout(ge_dense2, self.dropout_rate, training=self.is_training) + + ge_dense3 = tf.layers.dense(inputs = ge_dense2, units = self.inflate_to_size3, activation=None, + kernel_initializer = self.init_w) + ge_dense3 = tf.layers.batch_normalization(ge_dense3, training = self.is_training) + ge_dense3 = 
tf.nn.relu(ge_dense3) + ge_dense3 = tf.layers.dropout(ge_dense3, self.dropout_rate, training=self.is_training) + + ge_output = tf.layers.dense(inputs = ge_dense3, units= self.x_dim, activation=None) + + return ge_output + + def generator(self): + """ + generator without dropout layers of ssInfoWGAN-GP + """ + with tf.variable_scope('generator', reuse = tf.AUTO_REUSE): + + znoise = tf.concat([self.z, self.noise], axis = 1) + + ge_dense1 = tf.layers.dense(inputs = znoise, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + ge_dense1 = tf.layers.batch_normalization(ge_dense1, training = self.is_training) + ge_dense1 = tf.nn.leaky_relu(ge_dense1) + + ge_dense2 = tf.layers.dense(inputs = ge_dense1, units = self.inflate_to_size2, activation=None, + kernel_initializer = self.init_w) + ge_dense2 = tf.layers.batch_normalization(ge_dense2, training = self.is_training) + ge_dense2 = tf.nn.leaky_relu(ge_dense2) + + ge_dense3 = tf.layers.dense(inputs = ge_dense2, units = self.inflate_to_size3, activation=None, + kernel_initializer = self.init_w) + ge_dense3 = tf.layers.batch_normalization(ge_dense3, training = self.is_training) + ge_dense3 = tf.nn.relu(ge_dense3) + + ge_output = tf.layers.dense(inputs = ge_dense3, units= self.x_dim, activation=None) + + return ge_output + + + def discriminatorInfo(self, x_input): + """ + discriminator of ssInfoWGAN-GP + """ + with tf.variable_scope('discriminatorInfo', reuse = tf.AUTO_REUSE): + + disc_dense1 = tf.layers.dense(inputs= x_input, units= self.disc_internal_size1, activation = None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense1 = tf.layers.batch_normalization(disc_dense1, training = self.is_training) + disc_dense1 = tf.nn.leaky_relu(disc_dense1) + + disc_dense2 = tf.layers.dense(inputs = disc_dense1, units= self.disc_internal_size2, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense2 = tf.layers.batch_normalization(disc_dense2, training = self.is_training) + disc_dense2 = tf.nn.leaky_relu(disc_dense2) + + disc_dense3 = tf.layers.dense(inputs=disc_dense2, units= self.disc_internal_size3, activation=None, + kernel_regularizer = self.regu_w, kernel_initializer = self.init_w) + disc_dense3 = tf.layers.batch_normalization(disc_dense3, training = self.is_training) + disc_dense3 = tf.nn.relu(disc_dense3) + + disc_output = tf.layers.dense(inputs=disc_dense3, units= 1, activation=None, + kernel_initializer = self.init_w) + + # Q network part + q_dense1 = tf.layers.dense(inputs = disc_dense3, units = self.disc_internal_size3, activation = None, + kernel_initializer = self.init_w) + q_dense1 = tf.layers.batch_normalization(q_dense1, training = self.is_training) + q_dense1 = tf.nn.leaky_relu(q_dense1) + + q_output = tf.layers.dense(inputs = q_dense1, + units= (self.z_dim if self.fix_std else self.z_dim * 2), + activation = None) + + + return disc_output, q_output + + def c_mutual_info(self, c_vector, z_sample): + """ + function to compute the mutual information lower bound + """ + c_sample = z_sample[:, :self.z_dim] + + if self.fix_std: + mean_vec = c_vector + std_vec = tf.ones_like(mean_vec) + else: + mean_vec = c_vector[:, self.z_dim] + std_vec = c_vector[:, self.z_dim:(self.z_dim * 2)] + std_vec = tf.nn.softplus(std_vec) + dist_c_vector = ds.Normal(mean_vec, std_vec) + + ll_logp = dist_c_vector.log_prob(c_sample) + ll_logp_sum = tf.reduce_sum(ll_logp, [1]) + + mi_bound = tf.reduce_mean(ll_logp_sum) + + return mi_bound + + def 
create_network(self): + """ + construct the ssInfoWGAN-GP networks + """ + + if self.if_dropout: + self.x_gen_data = self.generatorDropOut() + else: + self.x_gen_data = self.generator() + + self.Dx_real, self.c_mutual_real = self.discriminatorInfo(self.x) + self.Dx_fake, self.c_mutual_fake = self.discriminatorInfo(self.x_gen_data) + + self.q_vector = self.c_mutual_info(self.c_mutual_fake, self.z) + self.q_fake_mutual = tf.reduce_mean(self.q_vector) + + self.q_vector_real = self.c_mutual_info(self.c_mutual_real, self.y) + self.q_real_mutual = tf.reduce_mean(self.q_vector_real) + + def compute_gp(self, x, x_gen_data, discriminator): + """ + gradient penalty of discriminator + """ + epsilon_x = tf.random_uniform([], 0.0, 1.0) + x_hat = x * epsilon_x + (1 - epsilon_x) * x_gen_data + d_hat, _ = discriminator(x_hat) + + gradients = tf.gradients(d_hat, x_hat)[0] + slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) + + gradient_penalty = tf.reduce_mean((slopes - 1.0) ** 2) + + return gradient_penalty + + def loss_function(self): + """ + loss function of InfoWGAN-GP + """ + D_raw_loss = tf.reduce_mean(self.Dx_real) - tf.reduce_mean(self.Dx_fake) + self.G_loss = tf.reduce_mean(self.Dx_fake) - (self.q_fake_mutual * self.lamb_mi_fake) - (self.q_real_mutual * self.lamb_mi_real) + + self.gradient_penalty = self.compute_gp(self.x, self.x_gen_data, self.discriminatorInfo) + self.D_loss = D_raw_loss + self.lamb_gp * self.gradient_penalty - (self.q_fake_mutual * self.lamb_mi_fake) - (self.q_real_mutual * self.lamb_mi_real) + + tf_vars_all = tf.trainable_variables() + dvars = [var for var in tf_vars_all if var.name.startswith("discriminatorInfo")] + + if self.if_dropout: + gvars = [var for var in tf_vars_all if var.name.startswith("generatorDropOut")] + else: + gvars = [var for var in tf_vars_all if var.name.startswith("generator")] + + self.parameter_count = tf.reduce_sum( + [tf.reduce_prod(tf.shape(v)) for v in dvars + gvars]) + + + with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + self.g_solver = Adam_Prediction_Optimizer(learning_rate = self.learning_rate, + beta1=0.9, beta2=0.999, prediction=True).minimize(self.G_loss, var_list = gvars) + self.d_solver = Adam_Prediction_Optimizer(learning_rate = self.learning_rate, + beta1=0.9, beta2=0.999, prediction=False).minimize(self.D_loss, var_list = dvars) + + @property + def model_parameter(self): + """ + report the number of training parameters + """ + self.total_param = self.sess.run(self.parameter_count) + return "There are {} parameters in InfoWGAN-GP.".format(self.total_param) + + + def generate_cells(self, z_data): + """ + generate data from latent samples + """ + noise_data = self.sample_z(len(z_data), self.n_dim) + gen_data = self.sess.run(self.x_gen_data, feed_dict = {self.z: z_data, self.noise: noise_data, self.is_training: False}) + return gen_data + + + def restore_model(self, model_path): + """ + restore model from model_path + """ + self.saver.restore(self.sess, model_path) + + + def save_model(self, model_save_path, epoch): + """ + save the trained model to the model_save_path + """ + os.makedirs(model_save_path, exist_ok = True) + model_save_name = os.path.join(model_save_path, "model") + save_path = self.saver.save(self.sess, model_save_name, global_step = epoch) + + np.save(os.path.join(model_save_path, "training_time.npy"), self.training_time) + np.save(os.path.join(model_save_path, "train_loss_D.npy"), self.train_loss_D) + np.save(os.path.join(model_save_path, "train_loss_G.npy"), 
self.train_loss_G) + np.save(os.path.join(model_save_path, "valid_loss_D.npy"), self.valid_loss_D) + np.save(os.path.join(model_save_path, "valid_loss_G.npy"), self.valid_loss_G) + + def train(self, train_data, train_cond, use_validation = False, valid_data = None, valid_cond = None, use_test_during_train = False, test_data = None, + test_cond = None, test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, verbose = False): + """ + train ssInfoWGAN-GP with train_data (AnnData), train_cond (numpy) and optional valid_data (numpy array) for n_epochs. + """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception('valid_data is None but use_validation is True.') + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + if sparse.issparse(train_data.X): + pca_data_fit = pca_data_50.fit(train_data.X.A) + else: + pca_data_fit = pca_data_50.fit(train_data.X) + + train_data_copy = train_data.copy() + for epoch in range(1, n_epochs + 1): + begin = time.time() + + if shuffle: + train_data, train_cond = shuffle_adata_cond(train_data, train_cond) + + train_loss_D, train_loss_G = 0.0, 0.0 + valid_loss_D, valid_loss_G = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + + # D step + for _ in range(self.Diters): + x_mb, y_mb = self.sample_data_cond(train_data, train_cond, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + self.sess.run(self.d_solver, feed_dict = {self.x: x_mb, self.y: y_mb, self.z: z_mb, self.noise: n_mb, + self.is_training: self.if_BNTrainingMode}) + + # G step + x_mb, y_mb = self.sample_data_cond(train_data, train_cond, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + _, current_loss_D, current_loss_G = self.sess.run([self.g_solver, self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.y: y_mb, self.z: z_mb, self.noise: n_mb, self.is_training: self.if_BNTrainingMode}) + + train_loss_D += (current_loss_D * batch_size) + train_loss_G += (current_loss_G * batch_size) + + train_loss_D /= n_train + train_loss_G /= n_train + + if use_validation: + for _ in range(1, n_valid // batch_size + 1): + x_mb, y_mb = self.sample_data_cond(valid_data, valid_cond, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + + current_loss_valid_D, current_loss_valid_G = self.sess.run([self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.y: y_mb, self.z: z_mb, self.noise: n_mb, self.is_training: False}) + + valid_loss_D += current_loss_valid_D + valid_loss_G += current_loss_valid_G + + valid_loss_D /= n_valid + valid_loss_G /= n_valid + + self.train_loss_D.append(train_loss_D) + self.train_loss_G.append(train_loss_G) + self.valid_loss_D.append(valid_loss_D) + self.valid_loss_G.append(valid_loss_G) + self.training_time += (time.time() - begin) + + # testing for generation 
metrics + if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + sampled_indices = sample(range(n_train), test_size) + + if sparse.issparse(train_data_copy.X): + test_data = train_data_copy[sampled_indices, :].X.A + else: + test_data = train_data_copy[sampled_indices, :].X + + test_cond = train_cond[sampled_indices, :] + + gen_data = self.generate_cells(test_cond) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + gen_data = self.generate_cells(test_cond) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + if save: + genmetrics_pd.to_csv(os.path.join(output_save_path, "GenerationMetrics.csv")) + + if reset_test_data: + test_data = None + test_cond = None + + if verbose: + print(f"Epoch {epoch}: D Train Loss: {train_loss_D} G Train Loss: {train_loss_G} D Valid Loss: {valid_loss_D} G Valid Loss: {valid_loss_G}") + + # early stopping + if use_validation and epoch > 1: + if abs(self.valid_loss_D[epoch - 2] - self.valid_loss_D[epoch - 1]) > min_delta or abs(self.valid_loss_G[epoch - 2] - self.valid_loss_G[epoch - 1]) > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training finished.") + + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + + def train_np(self, train_data, train_cond, use_validation = False, valid_data = None, valid_cond = None, use_test_during_train = False, test_data = None, + test_cond = None, test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 32, + early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None, output_save_path = None, verbose = False): + """ + train ssInfoWGAN-GP with train_data (numpy), train_cond (numpy) and optional valid_data (numpy array) for n_epochs. 
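+ train_cond supplies latent codes for the real cells (fed to the y placeholder), so the mutual-information
+ term is maximized for both real and generated data; its second dimension must equal z_dimension.
+ A rough sketch with placeholder arrays (shapes and values are illustrative):
+ >>> x_np = np.random.rand(1000, 2000).astype('float32')
+ >>> y_np = np.random.rand(1000, 10).astype('float32')
+ >>> model = ssInfoWGAN_GP(x_dimension = 2000, z_dimension = 10)
+ >>> model.train_np(x_np, y_np, n_epochs = 5, batch_size = 32)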
+ """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception('valid_data is None but use_validation is True.') + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + pca_data_fit = pca_data_50.fit(train_data) + + if shuffle: + index_shuffle = list(range(n_train)) + + for epoch in range(1, n_epochs + 1): + begin = time.time() + + if shuffle: + np.random.shuffle(index_shuffle) + train_data = train_data[index_shuffle] + train_cond = train_cond[index_shuffle] + + if inception_score_data is not None: + inception_score_data = inception_score_data[index_shuffle] + + train_loss_D, train_loss_G = 0.0, 0.0 + valid_loss_D, valid_loss_G = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + + # D step + for _ in range(self.Diters): + x_mb, y_mb = self.sample_data_cond_np(train_data, train_cond, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + self.sess.run(self.d_solver, feed_dict = {self.x: x_mb, self.z: z_mb, self.y: y_mb, + self.noise: n_mb, self.is_training: self.if_BNTrainingMode}) + + # G step + x_mb, y_mb = self.sample_data_cond_np(train_data, train_cond, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + _, current_loss_D, current_loss_G = self.sess.run([self.g_solver, self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.y: y_mb, self.z: z_mb, self.noise: n_mb, + self.is_training: self.if_BNTrainingMode}) + + train_loss_D += (current_loss_D * batch_size) + train_loss_G += (current_loss_G * batch_size) + + train_loss_D /= n_train + train_loss_G /= n_train + + if use_validation: + for _ in range(1, n_valid // batch_size + 1): + x_mb, y_mb = self.sample_data_cond_np(valid_data, valid_cond, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + n_mb = self.sample_z(batch_size, self.n_dim) + + current_loss_valid_D, current_loss_valid_G = self.sess.run([self.D_loss, self.G_loss], + feed_dict = {self.x: x_mb, self.y: y_mb, self.z: z_mb, self.noise: n_mb, + self.is_training: False}) + + valid_loss_D += current_loss_valid_D + valid_loss_G += current_loss_valid_G + + valid_loss_D /= n_valid + valid_loss_G /= n_valid + + self.train_loss_D.append(train_loss_D) + self.train_loss_G.append(train_loss_G) + self.valid_loss_D.append(valid_loss_D) + self.valid_loss_G.append(valid_loss_G) + self.training_time += (time.time() - begin) + + # testing for generation metrics + if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + sampled_indices = sample(range(n_train), test_size) + + test_data = train_data[sampled_indices, :] + test_cond = train_cond[sampled_indices, :] + + gen_data = self.generate_cells(test_cond) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + 
mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + gen_data = self.generate_cells(test_cond) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + if save: + genmetrics_pd.to_csv(os.path.join(output_save_path, "GenerationMetrics.csv")) + + if reset_test_data: + test_data = None + test_cond = None + + if verbose: + print(f"Epoch {epoch}: D Train Loss: {train_loss_D} G Train Loss: {train_loss_G} D Valid Loss: {valid_loss_D} G Valid Loss: {valid_loss_G}") + + # early stopping + if use_validation and epoch > 1: + if abs(self.valid_loss_D[epoch - 2] - self.valid_loss_D[epoch - 1]) > min_delta or abs(self.valid_loss_G[epoch - 2] - self.valid_loss_G[epoch - 1]) > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. 
Training finished.") + + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) diff --git a/models/util.py b/models/util.py new file mode 100644 index 0000000..8da4362 --- /dev/null +++ b/models/util.py @@ -0,0 +1,68 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import tensorflow as tf +import numpy as np +import sys +import pandas as pd +from random import shuffle +from scipy import sparse + + +def tf_log(value): + """ + tensorflow logrithmic + """ + return tf.log(value + 1e-16) + +def logsumexp(value, dim = None, keepdims = False): + """ + calculate q(Z) = sum_{X}[sum_{j}q(Zj|X)] and q(Zj) = sum_{X}[q(Zj|X)] + """ + if dim is not None: + m = tf.reduce_max(value, axis = dim, keepdims = True) + value0 = tf.subtract(value, m) + if keepdims is False: + m = tf.squeeze(m, dim) + return tf.add(m, tf_log(tf.reduce_sum(tf.exp(value0), axis = dim, keepdims = keepdims))) + + else: + m = tf.reduce_max(value) + sum_exp = tf.reduce_sum(tf.exp(tf.subtract(value, m))) + return tf.add(m, tf_log(sum_exp)) + + +def total_correlation(marginal_entropies, joint_entropies): + """ + calculate total correlation from the marginal and joint entropies + """ + return tf.reduce_sum(marginal_entropies) - tf.reduce_sum(joint_entropies) + + +def shuffle_adata(adata): + """ + shuffle adata + """ + if sparse.issparse(adata.X): + adata.X = adata.X.A + + ind_list = list(range(adata.shape[0])) + shuffle(ind_list) + new_adata = adata[ind_list, :] + + return new_adata + + +def shuffle_adata_cond(adata, cond): + """ + Shuffle adata with the label + """ + + if sparse.issparse(adata.X): + adata.X = adata.X.A + ind_list = list(range(adata.shape[0])) + shuffle(ind_list) + + new_adata = adata[ind_list, :] + new_cond = cond[ind_list, :] + + return new_adata, new_cond \ No newline at end of file diff --git a/models/vae.py b/models/vae.py new file mode 100644 index 0000000..156a043 --- /dev/null +++ b/models/vae.py @@ -0,0 +1,630 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import os +import time +import logging + +import numpy as np +import pandas as pd +from scipy import sparse + +import tensorflow as tf +from tensorflow import distributions as ds + +from .util import tf_log, logsumexp, total_correlation, shuffle_adata +from metrics.GenerationMetrics import * + +log = logging.getLogger(__file__) + +class BetaTCVAE: + """ + General VAE (beta = 0) and beta-TCVAE class + """ + def __init__(self, num_cells_train, x_dimension, z_dimension = 10, **kwargs): + tf.compat.v1.reset_default_graph() + self.num_cells_train = num_cells_train + self.x_dim = x_dimension + self.z_dim = z_dimension + self.learning_rate = kwargs.get("learning_rate", 1e-3) + self.dropout_rate = kwargs.get("dropout_rate", 0.2) + self.beta = kwargs.get("beta", 0.0) + self.alpha = kwargs.get("alpha", 1.0) + self.inflate_to_size1 = kwargs.get("inflate_size_1", 256) + self.inflate_to_size2 = kwargs.get("inflate_size_2", 512) + self.disc_internal_size2 = kwargs.get("disc_size_2", 512) + self.disc_internal_size3 = kwargs.get("disc_size_3", 256) + self.if_BNTrainingMode = kwargs.get("BNTrainingMode", True) + self.is_training = tf.placeholder(tf.bool, name = "training_flag") + + self.init_w = tf.contrib.layers.xavier_initializer() + self.device = kwargs.get("device", '/device:GPU:0') + + with tf.device(self.device): + self.x = tf.placeholder(tf.float32, shape = [None, self.x_dim], name = "data") + self.z = tf.placeholder(tf.float32, shape = [None, self.z_dim], name = "latent") + self.create_network() + self.loss_function() + 
+ config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True) + config.gpu_options.per_process_gpu_memory_fraction = 0.6 + self.sess = tf.Session(config = config) + + self.saver = tf.train.Saver(max_to_keep = 1) + self.init = tf.global_variables_initializer().run(session = self.sess) + self.train_loss = [] + self.valid_loss = [] + self.training_time = 0.0 + + def encoder(self): + """ + encoder of VAE + """ + with tf.variable_scope('encoder', reuse = tf.AUTO_REUSE): + en_dense2 = tf.layers.dense(inputs = self.x, units = self.inflate_to_size2, activation = None, + kernel_initializer = self.init_w) + en_dense2 = tf.layers.batch_normalization(en_dense2, training = self.is_training) + en_dense2 = tf.nn.leaky_relu(en_dense2) + en_dense2 = tf.layers.dropout(en_dense2, self.dropout_rate, training = self.is_training) + + en_dense3 = tf.layers.dense(inputs = en_dense2, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + en_dense3 = tf.layers.batch_normalization(en_dense3, training = self.is_training) + en_dense3 = tf.nn.relu(en_dense3) + en_dense3 = tf.layers.dropout(en_dense3, self.dropout_rate, training = self.is_training) + + en_loc = tf.layers.dense(inputs=en_dense3, units= self.z_dim, activation=None, kernel_initializer = self.init_w) + en_scale = tf.layers.dense(inputs = en_dense3, units= self.z_dim, activation=None, kernel_initializer = self.init_w) + en_scale = tf.nn.softplus(en_scale) + return en_loc, en_scale + + def decoder(self): + """ + decoder of VAE + """ + with tf.variable_scope("decoder", reuse = tf.AUTO_REUSE): + de_dense1 = tf.layers.dense(inputs = self.z_mean, units = self.inflate_to_size1, activation = None, + kernel_initializer = self.init_w) + de_dense1 = tf.layers.batch_normalization(de_dense1, training = self.is_training) + de_dense1 = tf.nn.leaky_relu(de_dense1) + de_dense1 = tf.layers.dropout(de_dense1, self.dropout_rate, training = self.is_training) + + de_dense2 = tf.layers.dense(inputs=de_dense1, units = self.inflate_to_size2, activation=None, + kernel_initializer = self.init_w) + de_dense2 = tf.layers.batch_normalization(de_dense2, training = self.is_training) + de_dense2 = tf.nn.leaky_relu(de_dense2) + de_dense2 = tf.layers.dropout(de_dense2, self.dropout_rate, training = self.is_training) + + de_loc = tf.layers.dense(inputs=de_dense2, units= self.x_dim, activation=None, kernel_initializer = self.init_w) + de_scale = tf.ones_like(de_loc) + return de_loc, de_scale + + def sample_posterior_z(self): + """ + sample the posterior latent samples for representation + """ + batch_size = tf.shape(self.mu)[0] + eps = tf.random_normal(shape = [batch_size, self.z_dim]) + return self.mu + self.std * eps + + def sample_x(self): + """ + sample the reconstructed data + """ + batch_size_x = tf.shape(self.mu_x)[0] + eps_x = tf.random_normal(shape = [batch_size_x, self.x_dim]) + return self.mu_x + self.std_x * eps_x + + def sample_z(self, batch_size, z_dim): + """ + sample the standard normal noises + """ + return np.random.normal(0.0, scale = 1.0, size = (batch_size, z_dim)) + + def sample_data(self, data, batch_size): + """ + sample data from AnnData datatype + """ + lower = np.random.randint(0, data.shape[0] - batch_size) + upper = lower + batch_size + if sparse.issparse(data.X): + x_mb = data[lower:upper, :].X.A + else: + x_mb = data[lower:upper, :].X + return x_mb + + def sample_data_np(self, data, batch_size): + """ + sample data from numpy array datatype + """ + + lower = np.random.randint(0, data.shape[0] - batch_size) + 
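# return a contiguous block of batch_size rows starting at the random offset (callers are expected to shuffle the data between epochs) +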
upper = lower + batch_size + + return data[lower:upper] + + def log_prob_z_prior_dist(self): + """ + tensorflow prior distribution of latent variables + """ + batch_size = tf.shape(self.mu)[0] + shape = [batch_size, self.z_dim] + return ds.Normal(tf.zeros(shape), tf.ones(shape)) + + def log_prob_z_prior(self): + """ + log probabilities of posterior latent samples on the prior + distribution + """ + return self.log_prob_z_prior_dist().log_prob(self.z_mean) + + def log_prob_x_dist(self): + """ + tensorflow normal distribution of the reconstructed data + """ + return ds.Normal(self.mu_x, self.std_x) + + def log_prob_z_post(self): + """ + log probabilities of posterior latent samples from their posterior distributions + """ + z_norm = (self.z_mean - self.mu) / self.std + z_var = tf.square(self.std) + return -0.5 * (z_norm * z_norm + tf.log(z_var) + np.log(2*np.pi)) + + def qz_mss_entropies(self): + """ + estimate the minibatch entropies of the q(Z) and q(Zj) using + Minibatch Stratified Sampling (MSS) + """ + dataset_size = tf.convert_to_tensor(self.num_cells_train) + batch_size = tf.shape(self.z_mean)[0] + # compute the weights + output = tf.zeros((batch_size - 1, 1)) + output = tf.concat([tf.ones((1,1)), output], axis = 0) + outpart_1 = tf.zeros((batch_size, 1)) + outpart_3 = tf.zeros((batch_size, batch_size - 2)) + output = tf.concat([outpart_1, output], axis = 1) + part_4 = - tf.concat([output, outpart_3], axis = 1)/tf.to_float(dataset_size) + + part_1 = tf.ones((batch_size, batch_size))/tf.to_float(batch_size - 1) + part_2 = tf.ones((batch_size, batch_size)) + part_2 = - tf.matrix_band_part(part_2, 1, 0)/tf.to_float(dataset_size) + + part_3 = tf.eye(batch_size) * (2/tf.to_float(dataset_size) - 1/tf.to_float(batch_size - 1)) + + weights = tf_log(part_1 + part_2 + part_3 + part_4) + + # the entropies + function_to_map = lambda x: self.log_prob_z_vector_post(tf.reshape(x, [1, self.z_dim])) + logqz_i_m = tf.map_fn(function_to_map, self.z_mean, dtype = tf.float32) + weights_expand = tf.expand_dims(weights, 2) + logqz_i_margin = logsumexp(logqz_i_m + weights_expand, dim = 1, keepdims = False) + logqz_value = tf.reduce_sum(logqz_i_m, axis = 2, keepdims = False) + logqz_v_joint = logsumexp(logqz_value + weights, dim = 1, keepdims = False) + logqz_sum = logqz_v_joint + logqz_i_sum = logqz_i_margin + + marginal_entropies = (- tf.reduce_mean(logqz_i_sum, axis = 0)) + joint_entropies = (- tf.reduce_mean(logqz_sum)) + + return marginal_entropies, joint_entropies + + def qz_entropies(self): + """ + estimate the large sample entropies of the q(Z) and q(Zj) + """ + batch_size = tf.shape(self.mu)[0] + weights = - tf_log(tf.to_float(batch_size)) + + function_to_map = lambda x: self.log_prob_z_vector_post(tf.reshape(x, [1, self.z_dim])) + logqz_i_m = tf.map_fn(function_to_map, self.z_mean, dtype = tf.float32) + logqz_i_margin = logsumexp(logqz_i_m + weights, dim = 1, keepdims = False) + logqz_value = tf.reduce_sum(logqz_i_m, axis = 2, keepdims = False) + logqz_v_joint = logsumexp(logqz_value + weights, dim = 1, keepdims = False) + logqz_sum = logqz_v_joint + logqz_i_sum = logqz_i_margin + + marginal_entropies = (- tf.reduce_mean(logqz_i_sum, axis = 0)) + joint_entropies = (- tf.reduce_mean(logqz_sum)) + + return marginal_entropies, joint_entropies + + + def create_network(self): + """ + construct the VAE networks + """ + self.mu, self.std = self.encoder() + self.z_mean = self.sample_posterior_z() + self.z_mss_marginal_entropy, self.z_mss_joint_entropy = self.qz_mss_entropies() + self.z_marginal_entropy, 
+
+    def create_network(self):
+        """
+        construct the VAE networks
+        """
+        self.mu, self.std = self.encoder()
+        self.z_mean = self.sample_posterior_z()
+        self.z_mss_marginal_entropy, self.z_mss_joint_entropy = self.qz_mss_entropies()
+        self.z_marginal_entropy, self.z_joint_entropy = self.qz_entropies()
+        self.z_tc = total_correlation(self.z_marginal_entropy, self.z_joint_entropy)
+        self.z_mss_tc = total_correlation(self.z_mss_marginal_entropy, self.z_mss_joint_entropy)
+        self.mu_x, self.std_x = self.decoder()
+        self.x_hat = self.sample_x()
+
+    def loss_function(self):
+        """
+        loss function of VAEs
+        """
+        # KL divergence
+        z_posterior = self.log_prob_z_post()
+        z_prior = self.log_prob_z_prior()
+        z_prior_sample = tf.reduce_sum(z_prior, [1])
+        z_post_sample = tf.reduce_sum(z_posterior, [1])
+        self.kl_loss = - tf.reduce_mean(z_prior_sample) + tf.reduce_mean(z_post_sample)
+
+        # reconstruction error
+        log_prob_x = self.log_prob_x_dist().log_prob(self.x)
+        log_prob_x_sample = tf.reduce_sum(log_prob_x, [1])
+        self.rec_x_loss = - tf.reduce_mean(log_prob_x_sample)
+
+        # variables
+        tf_vars_all = tf.trainable_variables()
+        evars = [var for var in tf_vars_all if var.name.startswith("encoder")]
+        dvars = [var for var in tf_vars_all if var.name.startswith("decoder")]
+
+        self.parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in evars + dvars])
+
+        # Total correlation is self.z_mss_tc
+        self.tcvae_loss = self.alpha * self.kl_loss + self.rec_x_loss + self.beta * self.z_mss_tc
+        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
+            self.solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.tcvae_loss,
+                                                                                              var_list = evars + dvars)
+
+    def log_prob_z_vector_post(self, z_vector):
+        """
+        log probabilities of latent variables on given posterior normal distributions
+        """
+        z_norm = (z_vector - self.mu) / self.std
+        z_var = tf.square(self.std)
+        return -0.5 * (z_norm * z_norm + tf.log(z_var) + np.log(2 * np.pi))
+
+    def encode(self, x_data):
+        """
+        encode data to the latent samples
+        """
+        latent = self.sess.run(self.z_mean, feed_dict = {self.x: x_data, self.is_training: False})
+        return latent
+
+    def encode_mean(self, x_data):
+        """
+        encode data to the latent means
+        """
+        latent = self.sess.run(self.mu, feed_dict = {self.x: x_data, self.is_training: False})
+        return latent
+
+    def avg_vector(self, data):
+        """
+        encode data to the latent sample means
+        """
+        latent = self.encode(data)
+        latent_avg = np.average(latent, axis = 0)
+        return latent_avg
+
+    @property
+    def model_parameter(self):
+        """
+        report the number of training parameters
+        """
+        self.total_param = self.sess.run(self.parameter_count)
+        return "There are {} parameters in VAE.".format(self.total_param)
+
+    def reconstruct(self, data, if_latent = False):
+        """
+        reconstruct data from original data or latent samples
+        """
+        if if_latent:
+            latent = data
+        else:
+            latent = self.encode(data)
+
+        rec_data = self.sess.run(self.x_hat, feed_dict = {self.z_mean: latent, self.is_training: False})
+        return rec_data
+
+    def restore_model(self, model_path):
+        """
+        restore model from model_path
+        """
+        self.saver.restore(self.sess, model_path)
+
+    def save_model(self, model_save_path, epoch):
+        """
+        save the trained model to the model_save_path
+        """
+        os.makedirs(model_save_path, exist_ok = True)
+        model_save_name = os.path.join(model_save_path, "model")
+        save_path = self.saver.save(self.sess, model_save_name, global_step = epoch)
+
+        np.save(os.path.join(model_save_path, "training_time.npy"), self.training_time)
+        np.save(os.path.join(model_save_path, "train_loss.npy"), self.train_loss)
+        np.save(os.path.join(model_save_path, "valid_loss.npy"), self.valid_loss)
+
+    def train(self, train_data, use_validation = False, valid_data = None, use_test_during_train = False, test_data = None,
+              test_every_n_epochs = 100, test_size = 3000, inception_score_data = None, n_epochs = 25, batch_size = 128,
+              early_stop_limit = 20, threshold = 0.0025, shuffle = True, save = False, model_save_path = None,
+              output_save_path = None, verbose = False):
+        """
+        train VAE with train_data (AnnData) and optional valid_data (numpy array) for n_epochs.
+        """
+        log.info("--- Training ---")
+        if use_validation and valid_data is None:
+            raise Exception("valid_data is None but use_validation is True.")
+
+        patience = early_stop_limit
+        min_delta = threshold
+        patience_cnt = 0
+
+        n_train = train_data.shape[0]
+        n_valid = None
+        if use_validation:
+            n_valid = valid_data.shape[0]
+
+        # generation performance at the PC space
+        if use_test_during_train:
+            pca_data_50 = PCA(n_components = 50, random_state = 42)
+            genmetric = MetricVisualize()
+            RFE = RandomForestError()
+            genmetrics_pd = pd.DataFrame({'epoch': [], 'is_real_mu': [], 'is_real_std': [],
+                                          'is_fake_mu': [], 'is_fake_std': [], 'rf_error': []})
+
+            if sparse.issparse(train_data.X):
+                pca_data_fit = pca_data_50.fit(train_data.X.A)
+            else:
+                pca_data_fit = pca_data_50.fit(train_data.X)
+
+        train_data_copy = train_data.copy()
+
+        for epoch in range(1, n_epochs + 1):
+
+            begin = time.time()
+            if shuffle:
+                train_data = shuffle_adata(train_data)
+            train_loss, valid_loss = 0.0, 0.0
+
+            for _ in range(1, n_train // batch_size + 1):
+                x_mb = self.sample_data(train_data, batch_size)
+                z_mb = self.sample_z(batch_size, self.z_dim)
+                _, current_loss_train = self.sess.run([self.solver, self.tcvae_loss],
+                                                      feed_dict = {self.x: x_mb, self.z: z_mb,
+                                                                   self.is_training: self.if_BNTrainingMode})
+                train_loss += current_loss_train * batch_size
+
+            train_loss /= n_train
+
+            if use_validation:
+                valid_loss = 0
+                for _ in range(1, n_valid // batch_size + 1):
+                    x_mb = self.sample_data(valid_data, batch_size)
+                    z_mb = self.sample_z(batch_size, self.z_dim)
+                    current_loss_valid = self.sess.run(self.tcvae_loss,
+                                                       feed_dict = {self.x: x_mb, self.z: z_mb,
+                                                                    self.is_training: False})
+                    valid_loss += current_loss_valid * batch_size
+
+                valid_loss /= n_valid
+
+            self.train_loss.append(train_loss)
+            self.valid_loss.append(valid_loss)
+            self.training_time += (time.time() - begin)
+
+            # testing for generation metrics
+            if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train:
+
+                if test_data is None:
+                    reset_test_data = True
+                    sampled_indices = sample(range(n_train), test_size)
+
+                    if sparse.issparse(train_data_copy.X):
+                        test_data = train_data_copy[sampled_indices, :].X.A
+                    else:
+                        test_data = train_data_copy[sampled_indices, :].X
+
+                    gen_data = self.reconstruct(test_data)
+
+                    if inception_score_data is not None:
+                        inception_score_subdata = inception_score_data[sampled_indices]
+                        mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data)
+                        mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data)
+                    else:
+                        mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0
+
+                else:
+                    assert test_data.shape[0] == test_size
+                    reset_test_data = False
+
+                    gen_data = self.reconstruct(test_data)
+
+                    if inception_score_data is not None:
+                        inception_score_subdata = inception_score_data
+                        mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data)
+                        mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data)
+                    else:
+                        mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0
+
+                errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0]
+                genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake,
+                                                                         errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])])
+
+                if save:
+                    genmetrics_pd.to_csv(os.path.join(output_save_path, "GenerationMetrics.csv"))
+                if reset_test_data:
+                    test_data = None
+
+            if verbose:
+                print(f"Epoch {epoch}: Train Loss: {train_loss} Valid Loss: {valid_loss}")
+
+            # early stopping
+            if use_validation and epoch > 1:
+                if self.valid_loss[epoch - 2] - self.valid_loss[epoch - 1] > min_delta:
+                    patience_cnt = 0
+                else:
+                    patience_cnt += 1
+
+                if patience_cnt > patience:
+                    if save:
+                        self.save_model(model_save_path, epoch)
+                        log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.")
+                        if verbose:
+                            print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.")
+                        if use_test_during_train:
+                            genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv"))
+                    break
+
+        if save:
+            self.save_model(model_save_path, epoch)
+            log.info(f"Model saved in file: {model_save_path}. Training finished.")
+            if verbose:
+                print(f"Model saved in file: {model_save_path}. Training finished.")
+            if use_test_during_train:
+                genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv"))
+ """ + log.info("--- Training ---") + if use_validation and valid_data is None: + raise Exception("valid_data is None but use_validation is True.") + + patience = early_stop_limit + min_delta = threshold + patience_cnt = 0 + + n_train = train_data.shape[0] + n_valid = None + if use_validation: + n_valid = valid_data.shape[0] + + # generation performance at the PC space + if use_test_during_train: + pca_data_50 = PCA(n_components = 50, random_state = 42) + genmetric = MetricVisualize() + RFE = RandomForestError() + genmetrics_pd = pd.DataFrame({'epoch':[], 'is_real_mu': [], 'is_real_std': [], + 'is_fake_mu':[], 'is_fake_std':[], 'rf_error':[]}) + + pca_data_fit = pca_data_50.fit(train_data) + + if shuffle: + index_shuffle = list(range(n_train)) + + for epoch in range(1, n_epochs + 1): + + begin = time.time() + + if shuffle: + np.random.shuffle(index_shuffle) + train_data = train_data[index_shuffle] + + if inception_score_data is not None: + inception_score_data = inception_score_data[index_shuffle] + + train_loss, valid_loss = 0.0, 0.0 + + for _ in range(1, n_train // batch_size + 1): + x_mb = self.sample_data_np(train_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + _, current_loss_train = self.sess.run([self.solver, self.tcvae_loss], + feed_dict = {self.x: x_mb, self.z: z_mb, + self.is_training: self.if_BNTrainingMode}) + + train_loss += current_loss_train * batch_size + + train_loss /= n_train + + if use_validation: + valid_loss = 0 + for _ in range(1, n_valid // batch_size + 1): + x_mb = self.sample_data_np(valid_data, batch_size) + z_mb = self.sample_z(batch_size, self.z_dim) + current_loss_valid = self.sess.run(self.tcvae_loss, + feed_dict = {self.x: x_mb, self.z: z_mb, + self.is_training: False}) + valid_loss += current_loss_valid * batch_size + + valid_loss /= n_valid + + self.train_loss.append(train_loss) + self.valid_loss.append(valid_loss) + self.training_time += (time.time() - begin) + + # testing for generation metrics + if (epoch - 1) % test_every_n_epochs == 0 and use_test_during_train: + + if test_data is None: + reset_test_data = True + sampled_indices = sample(range(n_train), test_size) + + test_data = train_data[sampled_indices, :] + gen_data = self.reconstruct(test_data) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data[sampled_indices] + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + else: + assert test_data.shape[0] == test_size + reset_test_data = False + + gen_data = self.reconstruct(test_data) + + if inception_score_data is not None: + inception_score_subdata = inception_score_data + mean_is_real, std_is_real = genmetric.InceptionScore(test_data, inception_score_subdata, test_data) + mean_is_fake, std_is_fake = genmetric.InceptionScore(test_data, inception_score_subdata, gen_data) + else: + mean_is_real = std_is_real = mean_is_fake = std_is_fake = 0.0 + + + errors_d = list(RFE.fit(test_data, gen_data, pca_data_fit, if_dataPC = True, output_AUC = False)['avg'])[0] + genmetrics_pd = pd.concat([genmetrics_pd, pd.DataFrame([[epoch, mean_is_real, std_is_real, mean_is_fake, std_is_fake, + errors_d]], columns = ['epoch', 'is_real_mu', 'is_real_std', 'is_fake_mu', 'is_fake_std', 'rf_error'])]) + + if save: + genmetrics_pd.to_csv(os.path.join(output_save_path, "GenerationMetrics.csv")) + if 
reset_test_data: + test_data = None + + + if verbose: + print(f"Epoch {epoch}: Train Loss: {train_loss} Valid Loss: {valid_loss}") + + # early stopping + if use_validation and epoch > 1: + if self.valid_loss[epoch - 2] - self.valid_loss[epoch - 1] > min_delta: + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt > patience: + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training stopped earlier at epoch: {epoch}.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) + break + + + if save: + self.save_model(model_save_path, epoch) + log.info(f"Model saved in file: {model_save_path}. Training finished.") + if verbose: + print(f"Model saved in file: {model_save_path}. Training finished.") + if use_test_during_train: + genmetrics_pd.to_csv(os.path.join(model_save_path, "GenerationMetrics.csv")) +
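
Reviewer note (not part of the patch): the added class relies on the standard reparameterization trick and closed-form diagonal-Gaussian log-densities (sample_posterior_z, log_prob_z_post, log_prob_z_prior, and the Monte Carlo kl_loss term in loss_function). The following self-contained NumPy sketch mirrors those computations for reference; every value in it is hypothetical.

    import numpy as np

    rng = np.random.default_rng(0)
    batch_size, z_dim = 4, 10

    # Hypothetical encoder outputs (stand-ins for self.mu and self.std above).
    mu = rng.normal(size=(batch_size, z_dim))
    std = np.log1p(np.exp(rng.normal(size=(batch_size, z_dim))))  # softplus, as in encoder()

    # Reparameterization trick (sample_posterior_z): z = mu + std * eps, eps ~ N(0, I).
    eps = rng.normal(size=(batch_size, z_dim))
    z = mu + std * eps

    # Elementwise diagonal-Gaussian log-density (log_prob_z_post / log_prob_z_prior).
    def log_normal(x, loc, scale):
        var = scale ** 2
        return -0.5 * (((x - loc) / scale) ** 2 + np.log(var) + np.log(2 * np.pi))

    log_q = log_normal(z, mu, std).sum(axis=1)   # log q(z|x), summed over latent dimensions
    log_p = log_normal(z, 0.0, 1.0).sum(axis=1)  # log p(z) under the standard normal prior

    # Single-sample Monte Carlo KL estimate, matching kl_loss in loss_function().
    kl_loss = np.mean(log_q) - np.mean(log_p)
    print(kl_loss)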
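Likewise, loss_function assembles the objective as tcvae_loss = alpha * kl_loss + rec_x_loss + beta * z_mss_tc, where the total-correlation term is built from the minibatch entropy estimates returned by qz_mss_entropies. Assuming total_correlation in lib.py follows the usual definition TC(z) = sum_j H(q(z_j)) - H(q(z)), that final combination looks roughly like the sketch below (all numbers are hypothetical).

    import numpy as np

    # Hypothetical entropy estimates (stand-ins for the outputs of qz_mss_entropies()).
    marginal_entropies = np.array([1.4, 1.3, 1.5, 1.2, 1.6])  # one H(q(z_j)) per latent dimension
    joint_entropy = 6.1                                        # H(q(z))

    # Assumed behaviour of lib.total_correlation: KL(q(z) || prod_j q(z_j)).
    z_mss_tc = np.sum(marginal_entropies) - joint_entropy

    alpha, beta = 1.0, 5.0             # hypothetical loss weights
    kl_loss, rec_x_loss = 3.2, 850.0   # hypothetical minibatch loss values
    tcvae_loss = alpha * kl_loss + rec_x_loss + beta * z_mss_tc
    print(tcvae_loss)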