# -*- coding: utf-8 -*-
"""
Differentially Private
Wasserstein GAN with gradient penalty

@author: Yale et al., ESANN 2019
adapted by Joseph Pedersen

         1         2         3         4         5         6         7
1234567890123456789012345678901234567890123456789012345678901234567890123456789
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import sys
import time
from datetime import datetime
import pickle as pkl

import numpy as np
import pandas as pd
import tensorflow as tf

print('tf', tf.__version__)

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

from tf_wgan_bds import computeBounds

myHome = "/home/username/"
thisDIR = myHome + "dataset1/"
outDIR = thisDIR + "output/"

recordGrads = False   # record the gradients
recGradEvery = 10000  # number of epochs between recording gradients
def data_batcher(data, batch_size):
    """Create a yield function for the given data and batch size."""

    def get_all_batches():
        """Yield function (generator) for all batches in the data."""
        # shuffle in place each time
        np.random.shuffle(data)
        # keep only the evenly divisible batches and reshape to
        # (num_batches, batch_size, n_features)
        batches = data[:(data.shape[0] // batch_size) * batch_size]
        batches = batches.reshape(-1, batch_size, data.shape[1])
        # go through all batches and yield them
        for i, _ in enumerate(batches):
            yield np.copy(batches[i])

    def infinite_data_batcher():
        """A generator that yields new batches every time it is called."""
        # once we run out of batches, start over
        while True:
            # for each batch in one set of batches
            for batch in get_all_batches():
                yield batch

    return infinite_data_batcher()
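# Example usage (illustrative shapes only, not part of the training pipeline):
#   batcher = data_batcher(np.random.rand(1000, 8).astype(np.float32), 100)
#   batch = next(batcher)  # ndarray of shape (100, 8); the stream never ends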
class WGAN():
    """Wasserstein GAN with gradient penalties"""

    def __init__(self,
                 outDIR=None,
                 trng_path=None,
                 test_path=None,
                 genL0_nodes=None,
                 critL1_nodes=None,
                 critic_iters=None,
                 num_epochs=None,
                 batch_size=None,
                 lamda=None,
                 N_synth=None,
                 S_synth=None,
                 DP_sigma=None
                 ):
        self.params = {
            'outDIR': None,        # path to directory to write log and dump saves
            'trng_path': None,     # path/filename of training data
            'test_path': None,     # path/filename of test data
            'genL0_nodes': 100,    # size of random vector for input to generator
            'critL1_nodes': 64,    # num nodes in 1st hidden layer of critic
            'critic_iters': 5,     # num critic iters per generator iter
            'num_epochs': 100000,  # num generator iters
            'batch_size': None,    # batch size for training
            'lamda': 10,           # multiplier for gradient penalty term of loss fn
            'N_synth': 0,          # num of synthetic samples to generate after trng
            'S_synth': 10000,      # size of synthetic samples to generate
            'DP_sigma': 0          # standard deviation to use for DP noise
        }
        # path to save the output, and to the training/test data
        if not outDIR:
            raise ValueError("Where should output be saved?")
        self.params['outDIR'] = outDIR
        if (not trng_path) or (not test_path):
            raise ValueError("Where is the data?")
        self.params['trng_path'] = trng_path
        self.params['test_path'] = test_path
        # read in data
        scratch = pd.read_csv(trng_path)
        self.col_names = scratch.columns
        self.train_data = scratch.values
        self.test_data = pd.read_csv(test_path).values

        ###################################
        # setting the sizes of the layers #
        ###################################
        # size of random vector for input to generator
        if genL0_nodes:
            self.params['genL0_nodes'] = genL0_nodes
        # layers of generator
        self.params['n_features'] = self.train_data.shape[1]
        self.params['genL1_nodes'] = 2 * self.params['n_features']
        self.params['genL2_nodes'] = 3 * self.params['n_features'] // 2
        self.params['genL3_nodes'] = self.params['n_features']
        # layers of critic
        if critL1_nodes:
            self.params['critL1_nodes'] = critL1_nodes
        self.params['critL2_nodes'] = 2 * self.params['critL1_nodes']
        self.params['critL3_nodes'] = 4 * self.params['critL1_nodes']
        # save critic shapes for adding noise
        self.crit_grad_shapes = [
            (self.params['n_features'], self.params['critL1_nodes']),
            (self.params['critL1_nodes'], ),
            (self.params['critL1_nodes'], self.params['critL2_nodes']),
            (self.params['critL2_nodes'], ),
            (self.params['critL2_nodes'], self.params['critL3_nodes']),
            (self.params['critL3_nodes'], ),
            (self.params['critL3_nodes'], 1),
            (1, )
        ]
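        # Each pair above is the weight-matrix shape followed by the bias-vector
        # shape of one of the critic's four fully connected layers; dp_noise()
        # draws one Gaussian tensor per shape so the noise lines up with the
        # corresponding critic gradients.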
        ###################################
        # setting other hyper-parameters  #
        ###################################
        # number of critic trng iterations for each generator trng iteration
        if critic_iters:
            self.params['critic_iters'] = critic_iters
        # number of generator iterations (also called epochs)
        if num_epochs:
            self.params['num_epochs'] = num_epochs
        self.params['n_observations'] = self.train_data.shape[0]
        if batch_size:
            self.params['batch_size'] = batch_size
        else:
            # batch_size is num observations divided by num of critic
            # iterations, rounded down to the nearest multiple of 100.
            self.params['batch_size'] = (
                (int(self.train_data.shape[0] /
                     self.params['critic_iters'])
                 // 100)
                * 100)
        # weight for gradient penalty term of loss function
        if lamda:
            self.params['lamda'] = lamda
        # number & size of synthetic samples to generate
        if N_synth:
            self.params['N_synth'] = N_synth
        if S_synth:
            self.params['S_synth'] = S_synth
        # differential privacy noise parameter
        if DP_sigma:
            self.params['DP_sigma'] = DP_sigma
        # double check on sizing
        assert self.test_data.shape[0] > self.params['batch_size']
        assert (self.train_data.shape[0] / self.params['batch_size']
                >= self.params['critic_iters'])
        # define the method train_batcher using the def of data_batcher
        self.train_batcher = data_batcher(self.train_data,
                                          self.params['batch_size'])

        ##############################
        # declaring other attributes #
        # that will be set later     #
        ##############################
        self.real_data = None            # real data
        self.fake_data = None            # fake data = gen(noise)
        self.crit_real = None            # crit(real_data)
        self.crit_fake = None            # crit(fake_data)
        self.crit_grad_penalties = None  # grad penalty term
        # loss function terms
        self.gen_loss = None        # -mean(crit_fake)
        self.crit_real_loss = None  # -mean(crit_real)
        self.crit_fake_loss = None  # mean(crit_fake)
        self.crit_main_loss = None  # mean(crit_fake) - mean(crit_real)
        self.crit_gp_loss = None    # grad penalty term of loss function
        self.crit_loss = None       # crit_main_loss + crit_gp_loss
        # training parameters, gradients, and training operations
        self.gen_params = None        # TF trainable params for generator
        self.gen_grads = None         # gradient of those params
        self.gen_train_op = None      # AdamOptimizer.minimize
        self.crit_params = None       # TF trainable params for critic
        self.critic_wts = None        # the weight matrices for the critic
        self.critic_bias = None       # the biases for the critic
        self.crit_real_grads = None   # grads for crit_real_loss
        self.crit_fake_grads = None   # grads for crit_fake_loss
        self.crit_main_grads = None   # grads for crit_main_loss
        self.crit_gp_grads = None     # grads for crit_gp_loss
        self.crit_grads = None        # grads for crit_loss
        self.dpn = None               # noise added for differential privacy
        self.noisy_crit_grads = None  # noisy grads for crit_loss
        self.bds = None        # the computed bounds on the gradients
        self.wtBds = None      # norm of bounds on gradients of weights, by layer
        self.biasBds = None    # norm of bounds on gradients of bias, by layer
        self.gpBds = None      # bounds on norm of gradients of gp term, by layer
        self.sensitivity = None  # the total L2 sensitivity each critic iter
        self.crit_train_op = None
        # noise for generating synthetic data after training
        self.rand_noise_samples = None
        self.samples_scores = None  # critic scores of synthetic samples
        self.real_samples = None    # samples of real data for comparison
        self.real_scores = None     # critic scores of the real samples

        ##############################
        # define lists to store data #
        ##############################
        # loss function values during training
        self.gen_loss_all = []
        self.crit_fake_loss_all = []
        self.crit_real_loss_all = []
        self.crit_main_loss_all = []
        self.crit_loss_all = []
        # loss on test set
        self.crit_main_loss_test_all = []
        self.crit_loss_test_all = []
        # time
        self.time_all = []
        # gradients during training
        self.gen_grads_all = []
        self.crit_real_grads_all = []
        self.crit_fake_grads_all = []
        self.crit_gp_grads_all = []
        self.crit_grads_all = []
        self.dpn_all = []
        self.noisy_crit_grads_all = []
        self.wtBds_all = []
        self.biasBds_all = []
        self.gpBds_all = []
        self.sensitivity_all = []
        ### END of def __init__(self) ###
    def print_settings(self, myfile=None):
        """Print the WGAN parameters."""
        print("\nWGAN parameters:", file=myfile)
        for k, v in self.params.items():
            print(f'{k + ":":18}{v}', file=myfile)
    def generator(self, inpt):
        """Create the generator graph."""
        # first dense layer
        fc1 = tf.contrib.layers.fully_connected(
            inpt,
            self.params['genL1_nodes'],
            activation_fn=tf.nn.relu,
            scope='generator.1',
            reuse=tf.AUTO_REUSE)
        # second dense layer
        fc2 = tf.contrib.layers.fully_connected(
            fc1,
            self.params['genL2_nodes'],
            activation_fn=tf.nn.relu,
            scope='generator.2',
            reuse=tf.AUTO_REUSE)
        # third dense layer
        fc3 = tf.contrib.layers.fully_connected(
            fc2,
            self.params['genL3_nodes'],
            activation_fn=tf.nn.sigmoid,
            scope='generator.3',
            reuse=tf.AUTO_REUSE)
        return fc3
        ### END of def generator(self, inpt) ###
    def critic(self, x):
        """Create the critic graph."""
        # create first dense layer
        fc1 = tf.contrib.layers.fully_connected(
            x,
            self.params['critL1_nodes'],
            activation_fn=tf.nn.leaky_relu,
            scope='critic.1',
            reuse=tf.AUTO_REUSE)
        # create second dense layer
        fc2 = tf.contrib.layers.fully_connected(
            fc1,
            self.params['critL2_nodes'],
            activation_fn=tf.nn.leaky_relu,
            scope='critic.2',
            reuse=tf.AUTO_REUSE)
        # create third dense layer
        fc3 = tf.contrib.layers.fully_connected(
            fc2,
            self.params['critL3_nodes'],
            activation_fn=tf.nn.leaky_relu,
            scope='critic.3',
            reuse=tf.AUTO_REUSE)
        # create fourth dense layer (output layer)
        output = tf.contrib.layers.fully_connected(
            fc3,
            1,
            activation_fn=None,
            scope='critic.4',
            reuse=tf.AUTO_REUSE)
        return output
        ### END of def critic(self, x) ###
    def dp_noise(self):
        """The noise for the critic gradients for differential privacy."""
        # standard deviation of the Gaussian noise added to the gradients
        sigma = 2 * self.params['DP_sigma']
        dpn = [tf.random.normal(shape,
                                mean=0.0,
                                stddev=sigma,
                                dtype=tf.float32)
               for shape in self.crit_grad_shapes]
        return dpn
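    # dp_noise() follows the Gaussian-mechanism pattern used in DP-SGD-style
    # training: every critic update perturbs each gradient tensor with
    # independent N(0, (2*DP_sigma)^2) noise.  The per-iteration L2 sensitivity
    # recorded in create_graph() is meant to be combined with this sigma
    # afterwards to account for the overall privacy loss.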
    def create_graph(self):
        """Create the computation graph."""
        print("\n*** creating graph ***\n")
        self.print_settings()
        # create the placeholder for real data and generator for fake
        self.real_data = tf.placeholder(
            dtype=tf.float32,
            shape=[self.params['batch_size'], self.params['n_features']],
            name="RealData")
        # create input noise for generator
        noise = tf.random_uniform(
            shape=[self.params['batch_size'], self.params['genL0_nodes']],
            minval=0,
            maxval=1,
            dtype=tf.float32,
            seed=None,
            name="InputNoise")
        # create fake data by using the noise as input to the generator
        self.fake_data = self.generator(noise)
        # run the critic for both types of data
        self.crit_real = self.critic(self.real_data)  # D(real)
        self.crit_fake = self.critic(self.fake_data)  # D( G(noise) )
        # create the loss for generator and critic
        #   L = D( G(noise) ) - D(real) + gradient penalty
        #   Adversarial game: max_G { min_D { L } }
        # generator wants to max L, i.e. min -D( G(noise) )
        self.gen_loss = -tf.reduce_mean(self.crit_fake)
        # critic wants to min L
        self.crit_fake_loss = tf.reduce_mean(self.crit_fake)
        self.crit_real_loss = -tf.reduce_mean(self.crit_real)
        # the 'main' loss, i.e. before adding the gradient penalty term
        self.crit_main_loss = self.crit_fake_loss + self.crit_real_loss
        # add the gradient penalty to crit loss, at points that are randomly
        # located along straight lines between the real and fake data
        alpha = tf.random_uniform(
            shape=[self.params['batch_size'], 1], minval=0, maxval=1)
        # points between real and fake (line 6 of the WGAN-GP algorithm)
        interpolates = (alpha * self.real_data) + (1 - alpha) * self.fake_data
        # compute gradients of the critic (discriminator) at the interpolates
        gradients = tf.gradients(
            self.critic(interpolates),
            [interpolates]
        )[0]
        # calculate the 2-norm of the gradients
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=1)
        )
        # subtract 1, square
        self.crit_grad_penalties = (slopes - 1.)**2
        # the gradient penalty term of the loss function, with weight 'lamda'
        self.crit_gp_loss = (
            self.params['lamda'] * tf.reduce_mean(self.crit_grad_penalties)
        )
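        # In symbols, the critic objective assembled here is the WGAN-GP loss
        #   L_D = E[D(G(z))] - E[D(x)] + lamda * E[(||grad D(xhat)||_2 - 1)^2],
        # with xhat = alpha*x + (1 - alpha)*G(z) the interpolates defined above.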
        # the entire loss function of the critic
        self.crit_loss = (
            self.crit_main_loss
            + self.crit_gp_loss
        )

        ### use Adam optimizer ###
        # get the TF variables that have 'generator' in their name:
        self.gen_params = [
            v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if 'generator' in v.name
        ]
        # compute the gradient of gen_loss for those params
        self.gen_grads = tf.train.AdamOptimizer(
            learning_rate=1e-4, beta1=0.5, beta2=0.9).compute_gradients(
                self.gen_loss, var_list=self.gen_params)
        # minimize using gradients for gen_loss
        self.gen_train_op = tf.train.AdamOptimizer(
            learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(
                self.gen_loss, var_list=self.gen_params)
        # get the TF variables that have 'critic' in their name:
        self.crit_params = [
            v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if 'critic' in v.name
        ]
        # compute the gradients for those params:
        critOpt = tf.train.AdamOptimizer(
            learning_rate=1e-4, beta1=0.5, beta2=0.9)
        # for crit_real_loss
        self.crit_real_grads = critOpt.compute_gradients(
            self.crit_real_loss, var_list=self.crit_params)
        # for crit_fake_loss
        self.crit_fake_grads = critOpt.compute_gradients(
            self.crit_fake_loss, var_list=self.crit_params)
        # for crit_gp_loss
        # NOTE: this doesn't work on its own -- compute_gradients returns None
        # for the bias gradients of the gradient-penalty term
        self.crit_gp_grads = critOpt.compute_gradients(
            self.crit_gp_loss, var_list=self.crit_params)
        # for crit_loss
        self.crit_grads = critOpt.compute_gradients(
            self.crit_loss, var_list=self.crit_params)

        ###############################################
        # add noise to grads for differential privacy #
        ###############################################
        # get Gaussian noise
        self.dpn = self.dp_noise()
        # add it to the critic gradients
        self.noisy_crit_grads = [(gv[0] + self.dpn[n], gv[1])
                                 for n, gv
                                 in enumerate(self.crit_grads)
                                 ]
        # record L2 sensitivity, to later compute the privacy loss
        ib = tf.constant(
            np.concatenate(
                (np.zeros(
                    shape=(self.params['n_features'], 1),
                    dtype=np.float32),
                 np.ones(
                     shape=(self.params['n_features'], 1),
                     dtype=np.float32)
                 ),
                axis=1
            ),
            dtype=tf.float32,
            name='InputBound')
        self.critic_wts = [v for v in self.crit_params if 'weights' in v.name]
        self.critic_bias = [v for v in self.crit_params if 'biases' in v.name]
        alphas = (0.2, 0.2, 0.2, 1.0)
        self.bds = computeBounds(inputBound=ib,
                                 weights=self.critic_wts,
                                 bias=self.critic_bias,
                                 alphas=alphas,
                                 compGP_g=True)
        self.biasBds = [tf.norm(t[:, 1] - t[:, 0]) for t in self.bds[2]]
        self.wtBds = [tf.norm(t[:, :, 1] - t[:, :, 0]) for t in self.bds[3]]
        self.gpBds = self.bds[4]
        self.sensitivity = (
            tf.sqrt(
                tf.reduce_sum(tf.square(self.biasBds))
                + tf.reduce_sum(tf.square(self.wtBds))
            )
            + self.params['lamda'] *
            tf.sqrt(tf.reduce_sum(tf.square(self.gpBds)))
        )
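        # self.sensitivity is intended as the total L2 sensitivity of one
        # critic iteration (per the attribute comment in __init__): the
        # per-layer bounds on the weight and bias gradients are combined into
        # one L2 norm, and the bound on the gradient-penalty gradients is
        # added with its weight 'lamda'.  The per-layer bounds themselves come
        # from computeBounds in tf_wgan_bds.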
        # minimize using the noisy gradients for crit_loss
        self.crit_train_op = critOpt.apply_gradients(self.noisy_crit_grads)
        ### END of Adam optimizer code ###

        # Generating synthetic data after training
        rand_noise = tf.random_uniform(
            shape=[self.params['S_synth'], self.params['genL0_nodes']],
            minval=0,
            maxval=1,
            dtype=tf.float32,
            seed=None,
            name="NoiseForSamples")
        self.rand_noise_samples = self.generator(rand_noise)
        self.samples_scores = self.critic(self.rand_noise_samples)
        self.real_samples = tf.placeholder(
            dtype=tf.float32,
            shape=[self.params['S_synth'], self.params['n_features']],
            name="RealDataForScoring")
        self.real_scores = self.critic(self.real_samples)
        # save session graph information
        with tf.Session() as session:
            _ = tf.summary.FileWriter('./logs_new', session.graph)
        ### END of def create_graph(self) ###
    def train(self):
        """Run the training loop."""
        saver = tf.train.Saver()      # saver object for saving the model
        sessStart = time.time()       # time at start of training session
        sessStartDT = datetime.now()  # datetime at start of training session
        sessDTstring = sessStartDT.strftime('%Y%m%d_%H%M')

        ############################
        # START TensorFlow session #
        ############################
        with tf.Session() as session:
            # initialize variables
            session.run(tf.global_variables_initializer())
            myvars = tf.trainable_variables()
            print("\n*** starting to train at " + sessDTstring + " ***\n")
            print("\nTensorFlow trainable parameters:")
            for myvar in myvars:
                print(myvar.name)
                print(' shape =', myvar.get_shape())

            ### START of training loop ###
            for epoch in range(self.params['num_epochs']):
                start_time = time.time()
                # create lists to store the critic losses this epoch
                crit_fake_loss_list = []
                crit_real_loss_list = []
                crit_loss_list = []
                crit_grads_list = []
                dpn_list = []
                noisy_crit_grads_list = []
                sensitivity_list = []

                ### START of critic iterations ###
                for i in range(self.params['critic_iters']):
                    # get a batch
                    train = next(self.train_batcher)
                    #############################
                    # run one critic iteration  #
                    #############################
                    # compute the losses (which are NOT diff. private)
                    # and execute train_op (which IS diff. private)
                    (crit_fake_loss,
                     crit_real_loss,
                     crit_loss,
                     crit_grads,
                     dpn,
                     noisy_crit_grads,
                     sensitivity,
                     _) = session.run(
                         [self.crit_fake_loss,
                          self.crit_real_loss,
                          self.crit_loss,
                          self.crit_grads,
                          self.dpn,
                          self.noisy_crit_grads,
                          self.sensitivity,
                          self.crit_train_op],
                         feed_dict={self.real_data: train})
                    # append the losses to their lists
                    crit_fake_loss_list.append(crit_fake_loss)
                    crit_real_loss_list.append(crit_real_loss)
                    crit_loss_list.append(crit_loss)
                    sensitivity_list.append(sensitivity)
                    if epoch == 0 and i == 0:
                        print("\n*** Completed one critic training iteration!"
                              + " ***\n")
                    # if recordGrads: record the gradients & noise added
                    if recordGrads and (epoch == 0
                                        or epoch % recGradEvery == (recGradEvery - 1)
                                        ):
                        crit_grads_list.append(crit_grads)
                        dpn_list.append(dpn)
                        noisy_crit_grads_list.append(noisy_crit_grads)
                ### END of critic iterations ###

                # check the loss on the test set at the first epoch and every
                # 1000th epoch thereafter
                if epoch == 0 or epoch % 1000 == 999:
                    # shuffle test_data in place
                    np.random.shuffle(self.test_data)
                    # get the total loss and the main_loss (w/o GP)
                    test_crit_loss, test_crit_main_loss = session.run(
                        [self.crit_loss,
                         self.crit_main_loss],
                        feed_dict={
                            self.real_data:
                                self.test_data[:self.params['batch_size']]
                        }
                    )
                    # append the losses to their lists
                    self.crit_loss_test_all.append(test_crit_loss)
                    self.crit_main_loss_test_all.append(test_crit_main_loss)
                    # print the test loss to console
                    print(f'Test Epoch: [Test D loss: {test_crit_loss:7.4f}]')

                ######################################
                # run one generator train iteration  #
                ######################################
                if epoch < (self.params['num_epochs'] - 1):
                    gen_loss, _ = session.run([self.gen_loss,
                                               self.gen_train_op])
                    if epoch == 0:
                        print("\n*** Completed one generator training"
                              + " iteration! ***\n")
                # save the loss and time of epoch
                self.time_all.append(time.time() - start_time)
                self.crit_fake_loss_all.append(crit_fake_loss_list)
                self.crit_real_loss_all.append(crit_real_loss_list)
                self.crit_loss_all.append(crit_loss_list)
                self.sensitivity_all.append(sensitivity_list)
                self.gen_loss_all.append(gen_loss)
                # if recordGrads: record the gradients & noise added
                if recordGrads and (epoch == 0
                                    or epoch % recGradEvery == (recGradEvery - 1)
                                    ):
                    self.crit_grads_all.append(crit_grads_list)
                    self.dpn_all.append(dpn_list)
                    self.noisy_crit_grads_all.append(noisy_crit_grads_list)
                if epoch < 10 or epoch % 100 == 99:
                    # print the results
                    print(f'Epoch: {epoch:5}',
                          f'[D loss: {self.crit_loss_all[-1][-1]:7.4f}]',
                          f'[G loss: {self.gen_loss_all[-1]:7.4f}]',
                          f'[Time: {self.time_all[-1]:4.2f}]')
            ####################
            # AFTER last epoch #
            ####################
            print("\n*** Done training!!! ***\n")
            # generate samples of fake data
            sample_files = []
            score_files = []
            real_score_files = []
            if self.params['N_synth']:
                print(
                    f"Generating {self.params['N_synth']} "
                    + "samples of synthetic data.\n"
                )
                for i in range(self.params['N_synth']):
                    sample_files.append(
                        outDIR
                        + 'samples_'
                        + sessDTstring
                        + f'_synth_{i}.csv')
                    score_files.append(
                        outDIR
                        + 'scores_'
                        + sessDTstring
                        + f'_synth_{i}.csv')
                    real_score_files.append(
                        outDIR
                        + 'scores_'
                        + sessDTstring
                        + f'_real_{i}.csv')
                    samples = session.run(self.rand_noise_samples)
                    samples = pd.DataFrame(samples, columns=self.col_names)
                    samples.to_csv(sample_files[i], index=False)
                    scores = session.run(self.samples_scores)
                    scores = pd.DataFrame(scores)
                    scores.to_csv(score_files[i], index=False)
                    j = i * self.params['S_synth']
                    k = (i + 1) * self.params['S_synth']
                    if k <= self.params['batch_size']:
                        real_scores = session.run(
                            self.real_scores,
                            feed_dict={self.real_samples:
                                       self.train_data[j:k, :]}
                        )
                        real_scores = pd.DataFrame(real_scores)
                        real_scores.to_csv(real_score_files[i], index=False)

            # dump record of time, loss functions, etc.
            dump_file = outDIR + 'dump_' + sessDTstring + '.pkl'
            print("Dumping loss, etc., in " + dump_file + "\n")
            dump_dict = {
                'time': self.time_all,
                'crit_fake_loss': self.crit_fake_loss_all,
                'crit_real_loss': self.crit_real_loss_all,
                'crit_main_loss': self.crit_main_loss_all,
                'crit_loss': self.crit_loss_all,
                'gen_loss': self.gen_loss_all,
                'test_main_loss': self.crit_main_loss_test_all,
                'test_loss': self.crit_loss_test_all,
                'sensitivity': self.sensitivity_all
            }
            if recordGrads:
                dump_dict['crit_grads'] = self.crit_grads_all
                dump_dict['dpn'] = self.dpn_all
                dump_dict['noisy_crit_grads'] = self.noisy_crit_grads_all
            with open(dump_file, 'wb') as f:
                pkl.dump(dump_dict, f)

            # save the TensorFlow session
            print("Saving TensorFlow session.\n")
            session_file = os.path.join(outDIR,
                                        'model_' + sessDTstring + '.ckpt')
            saver.save(session, session_file)
            # write the graph to file
            graph_file = outDIR + 'wgan_graph_' + sessDTstring + '.pbtxt'
            tf.io.write_graph(session.graph, '.', graph_file)

            # record the total time for training in a "myLog" .txt file
            sessStop = time.time()
            sessStopDT = datetime.now()
            sessStopString = sessStopDT.strftime('%Y%m%d_%H%M')
            totalTime = sessStop - sessStart

            ###########################
            ### Start of myLog file ###
            ###########################
            log_file = outDIR + 'myLog' + sessDTstring + '.txt'
            myfile = open(log_file, 'w', encoding='utf-8')
            print("Writing log to " + log_file + "\n")
            print(sys.argv[0] + "\n", file=myfile)
            print(f"Started at {sessDTstring}\n"
                  + f"Stopped at {sessStopString}\n"
                  + f"Total time = {int(totalTime // 3600)}h"
                  + f" {int((totalTime % 3600) // 60)}m"
                  + f" {totalTime % 60:.0f}s\n",
                  file=myfile)
            self.print_settings(myfile)
            print("\nTensorFlow trainable parameters:", file=myfile)
            for myvar in myvars:
                print(myvar.name, file=myfile)
                print(' shape =', myvar.get_shape(), file=myfile)
            print(f"\nTensorFlow session saved to:\n{session_file} files",
                  file=myfile)
            print(f"Graph written to:\n{graph_file}", file=myfile)
            print(f"\nThe following was dumped in:\n{dump_file}", file=myfile)
            for k in dump_dict.keys():
                print(k, file=myfile)
            print("\nThe following synthetic data samples were generated:",
                  file=myfile)
            for sample_file in sample_files:
                print(' ' + sample_file, file=myfile)
            print('\n\n',
                  '########################################\n',
                  '#                                      #\n',
                  '# The entire script used this session: #\n',
                  '#                                      #\n',
                  '########################################\n\n',
                  file=myfile)
            with open(__file__) as f:
                print(f.read(), file=myfile)
            myfile.close()
            ### End of myLog file ###
        #############################
        # END of TensorFlow session #
        #############################
        ### END of def train(self) ###
    ### END of class WGAN() ###
#------------------------------------------------------------------------------
if __name__ == '__main__':
    # start with a fresh graph
    tf.reset_default_graph()
    # create object
    wgan = WGAN(
        outDIR=outDIR,
        trng_path=thisDIR + "train_sdv.csv",
        test_path=thisDIR + "test_sdv.csv",
        genL0_nodes=100,
        critL1_nodes=64,
        critic_iters=5,
        num_epochs=2000,
        batch_size=208,
        lamda=10,
        N_synth=2,
        S_synth=2085,
        DP_sigma=1e-6
    )
    wgan.create_graph()  # define the computation graph
    wgan.train()         # train the model

    print("*** ALL DONE at " + datetime.now().strftime('%Y%m%d_%H%M') + " ***\n")