Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GANs mnist #172

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mnist_gan/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.ipynb_checkpoints
swaingotnochill marked this conversation as resolved.
Show resolved Hide resolved
saved_models
samples_posterior
samples_csv_files
mnist_first250_training_4s_and_9s.arm
mnist_gan
mnist_gan_generate
mnist_gan_generate.o
mnist_gan.o
36 changes: 36 additions & 0 deletions mnist_gan/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

TARGET := mnist_gan_generate
SRC := mnist_gan_generate.cpp
LIBS_NAME := armadillo mlpack

CXX := g++
CXXFLAGS += -std=c++11 -Wall -Wextra -O3 -DNDEBUG
# Use these CXXFLAGS instead if you want to compile with debugging symbols and
# without optimizations.
# CXXFLAGS += -std=c++11 -Wall -Wextra -g -O0
LDFLAGS += -fopenmp
LDFLAGS += -lboost_serialization
LDFLAGS += -larmadillo
LDFLAGS += -L /home/viole/mlpack/build/lib/ # /path/to/mlpack/library/ # if installed locally.
# Add header directories for any includes that aren't on the
# default compiler search path.
INCLFLAGS := -I /home/viole/mlpac/build/include/
CXXFLAGS += $(INCLFLAGS)

OBJS := $(SRC:.cpp=.o)
LIBS := $(addprefix -l,$(LIBS_NAME))
CLEAN_LIST := $(TARGET) $(OBJS)

# default rule
default: all

$(TARGET): $(OBJS)
$(CXX) $(CXXFLAGS) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS)

.PHONY: all
all: $(TARGET)

.PHONY: clean
clean:
@echo CLEAN $(CLEAN_LIST)
@rm -f $(CLEAN_LIST)
42 changes: 42 additions & 0 deletions mnist_gan/gan_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/**
* @file gan_utils.cpp
* @author Roshan Swain
* @author Atharva Khandait
*
* Utility function necessary for working with GAN models.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MODELS_GAN_UTILS_HPP
#define MODELS_GAN_UTILS_HPP

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>

using namespace mlpack;
using namespace mlpack::ann;

// Sample from the output distribution and post-process the outputs(because
// we pre-processed it before passing it to the model).
template<typename DataType = arma::mat>
void GetSample(DataType &input, DataType& samples, bool isBinary)
{
if (isBinary)
{
samples = arma::conv_to<DataType>::from(
arma::randu<DataType>(input.n_rows, input.n_cols) <= input);
samples *= 255;
}
else
{
samples = input / 2 + 0.5;
samples *= 255;
samples = arma::clamp(samples, 0, 255);
}
}

#endif
91 changes: 91 additions & 0 deletions mnist_gan/generate_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
@file generate_images.py
@author Atharva Khandait
Generates jpg files from csv.
mlpack is free software; you may redistribute it and/or modify it under the
terms of the 3-clause BSD license. You should have received a copy of the
3-clause BSD license along with mlpack. If not, see
http://www.opensource.org/licenses/BSD-3-Clause for more information.
"""

from PIL import Image
import numpy as np
import cv2
import os

def ImagesFromCSV(filename,
imgShape = (28, 28),
destination = 'samples',
saveIndividual = False):

# Import the data into a numpy matrix.
samples = np.genfromtxt(filename, delimiter = ',', dtype = np.uint8)

# Reshape and save it as an image in the destination.
tempImage = Image.fromarray(np.reshape(samples[:, 0], imgShape), 'L')
if saveIndividual:
tempImage.save(destination + '/sample0.jpg')

# All the images will be concatenated to this for a combined image.
allSamples = tempImage

for i in range(1, samples.shape[1]):
tempImage = np.reshape(samples[:, i], imgShape)

allSamples = np.concatenate((allSamples, tempImage), axis = 1)

tempImage = Image.fromarray(tempImage, 'L')
if saveIndividual:
tempImage.save(destination + '/sample' + str(i) + '.jpg')

tempImage = allSamples
allSamples = Image.fromarray(allSamples, 'L')
allSamples.save(destination + '/allSamples' + '.jpg')

print ('Samples saved in ' + destination + '/.')

return tempImage

# Save posterior samples.
ImagesFromCSV('./samples_csv_files/samples_posterior.csv', destination =
'samples_posterior')

# Save prior samples with individual latent varying.
latentSize = 10
allLatent = ImagesFromCSV('./samples_csv_files/samples_prior_latent0.csv',
destination = 'samples_prior')

for i in range(1, latentSize):
allLatent = np.concatenate((allLatent,
(ImagesFromCSV('./samples_csv_files/samples_prior_latent' + str(i) + '.csv',
destination = 'samples_prior'))), axis = 0)

saved = Image.fromarray(allLatent, 'L')
saved.save('./samples_prior/allLatent.jpg')

# Save prior samples with 2d latent varying.
nofSamples = 20
allLatent = ImagesFromCSV('./samples_csv_files/samples_prior_latent_2d0.csv',
destination = 'latent')

for i in range(1, nofSamples):
allLatent = np.concatenate((allLatent,
(ImagesFromCSV('./samples_csv_files/samples_prior_latent_2d' + str(i) +
'.csv', destination = 'samples_prior'))), axis = 0)

saved = Image.fromarray(allLatent, 'L')
saved.save('./samples_prior/2dLatent.jpg')

# AVI file
vid_fname = 'gans_celebface_training1.avi'
sample_dir = " "

files = [os.path.join(sample_dir, f) for f in os.listdir(sample_dir) if 'generated' in f]
files.sort()

out = cv2.VideoWriter(vid_fname, cv2.VideoWriter_fourcc(*'MP4V'), 1, (530, 530))
[out.write(cv2.imread(fname)) for fname in files]
out.release()


###Output###
Loading