# -*- coding: utf-8 -*-
"""
Differentially Private
Wasserstein GAN with gradient penalty
@author: Yale et al., ESANN 2019
adapted by Joseph Pedersen
1 2 3 4 5 6 7
1234567890123456789012345678901234567890123456789012345678901234567890123456789
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import sys
import time
from datetime import datetime
import pickle as pkl
import numpy as np
import pandas as pd
import tensorflow as tf
print('tf',tf.__version__)
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"
from tf_wgan_bds import computeBounds
myHome = "/home/username/"
thisDIR = myHome + "dataset1/"
outDIR = thisDIR + "output/"
recordGrads = False # record the gradients
recGradEvery = 10000 # number of epochs between recording gradients
def data_batcher(data, batch_size):
"""create yield function for given data and batch size"""
def get_all_batches():
"""yield function (generator) for all batches in data"""
# shuffle in place each time
np.random.shuffle(data)
# keep only the rows that fill a whole number of batches
# shape of (num_batches, batch_size, n_features)
batches = data[:(data.shape[0] // batch_size) * batch_size]
batches = batches.reshape(-1, batch_size, data.shape[1])
# go through all batches and yield them
for i, _ in enumerate(batches):
yield np.copy(batches[i])
def infinite_data_batcher():
"""a generator that yields new batches every time it is called"""
# once we run out of batches start over
while True:
# for each batch in one set of batches
for batch in get_all_batches():
yield batch
return infinite_data_batcher()
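# Illustrative sketch (hypothetical helper, not called anywhere in this
# script): a quick check of the shapes that data_batcher yields.
def _batcher_shape_check(n_rows=10, n_cols=3, batch_size=4):
    """Return the first batch from data_batcher on a tiny toy array."""
    demo = np.arange(n_rows * n_cols, dtype=np.float32).reshape(n_rows, n_cols)
    batcher = data_batcher(demo, batch_size)
    first = next(batcher)  # shuffles demo in place, then yields a copy
    assert first.shape == (batch_size, n_cols)
    return first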
class WGAN():
"""Wasserstein GAN with gradient penalties"""
def __init__(self,
outDIR=None,
trng_path=None,
test_path=None,
genL0_nodes=None,
critL1_nodes=None,
critic_iters=None,
num_epochs=None,
batch_size=None,
lamda=None,
N_synth=None,
S_synth=None,
DP_sigma=None
):
self.params = {
'outDIR': None, # path to directory to write log and dump saves
'trng_path': None, # path/filename of training data
'test_path': None, # path/filename of test data
'genL0_nodes': 100, # size of random vector for input to generator
'critL1_nodes': 64, # num nodes in 1st hidden layer of critic
'critic_iters': 5, # num critic iters per generator iter
'num_epochs': 100000, # num generator iters
'batch_size': None, # batch size for training
'lamda': 10, # multiplier for gradient penalty term of loss fn
'N_synth': 0, # num of synthetic samples to generate after trng
'S_synth': 10000, # size of synthetic samples to generate
'DP_sigma': 0 # standard deviation to use for DP noise
}
# path to save the output, and to the training/test data
if not outDIR:
raise ValueError("Where should output be saved?")
self.params['outDIR'] = outDIR
if (not trng_path) or (not test_path):
raise ValueError("Where is the data?")
self.params['trng_path'] = trng_path
self.params['test_path'] = test_path
# read in data
scratch = pd.read_csv(trng_path)
self.col_names = scratch.columns
self.train_data = scratch.values
self.test_data = pd.read_csv(test_path).values
###################################
# setting the sizes of the layers #
###################################
# size of random vector for input to generator
if genL0_nodes:
self.params['genL0_nodes'] = genL0_nodes
# layers of generator
self.params['n_features'] = self.train_data.shape[1]
self.params['genL1_nodes'] = 2 * self.params['n_features']
self.params['genL2_nodes'] = 3 * self.params['n_features'] // 2
self.params['genL3_nodes'] = self.params['n_features']
# layers of critic
if critL1_nodes:
self.params['critL1_nodes'] = critL1_nodes
self.params['critL2_nodes'] = 2 * self.params['critL1_nodes']
self.params['critL3_nodes'] = 4 * self.params['critL1_nodes']
# save critic shapes for adding noise
self.crit_grad_shapes = [
(self.params['n_features'], self.params['critL1_nodes']),
(self.params['critL1_nodes'], ),
(self.params['critL1_nodes'], self.params['critL2_nodes']),
(self.params['critL2_nodes'], ),
(self.params['critL2_nodes'], self.params['critL3_nodes']),
(self.params['critL3_nodes'], ),
(self.params['critL3_nodes'], 1),
(1, )
]
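# For illustration (n_features = 10 is a hypothetical value, critL1_nodes at
# its default of 64), the shapes above would be
# (10, 64), (64,), (64, 128), (128,), (128, 256), (256,), (256, 1), (1,)
# i.e. a weight matrix and a bias vector for each of the four critic layers.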
###################################
# setting other hyper-parameters #
###################################
# number of critic trng iterations for each generator trng iteration
if critic_iters: self.params['critic_iters'] = critic_iters
# number of generator iterations (also called epochs)
if num_epochs: self.params['num_epochs'] = num_epochs
self.params['n_observations'] = self.train_data.shape[0]
if batch_size:
self.params['batch_size'] = batch_size
else:
# batch_size is num observations divided by num of critic
# iterations rounded down to the nearest multiple of 100.
self.params['batch_size'] = (
(int(self.train_data.shape[0] /
self.params['critic_iters'])
// 100)
* 100)
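# For example (purely illustrative row count): with 5,170 training rows and
# 5 critic iterations, int(5170 / 5) = 1034, which rounds down to a batch
# size of 1000.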
# weight for gradient penalty term of loss function
if lamda:
self.params['lamda'] = lamda
# number & size of synthetic samples to generate
if N_synth:
self.params['N_synth'] = N_synth
if S_synth:
self.params['S_synth'] = S_synth
# differential privacy noise parameter
if DP_sigma:
self.params['DP_sigma'] = DP_sigma
# double check on sizing
assert self.test_data.shape[0] > self.params['batch_size']
assert (self.train_data.shape[0] / self.params['batch_size']
>= self.params['critic_iters'])
# define the method train_batcher using the def of data_batcher
self.train_batcher = data_batcher(self.train_data,
self.params['batch_size'])
##############################
# declaring other attributes #
# that will be set later #
##############################
self.real_data = None # real data
self.fake_data = None # fake data = gen(noise)
self.crit_real = None # crit(real_data)
self.crit_fake = None # crit(fake_data)
self.crit_grad_penalties = None # grad penalty term
# loss function terms
self.gen_loss = None # -mean(crit_fake)
self.crit_real_loss = None # -mean(crit_real)
self.crit_fake_loss = None # mean(crit_fake)
self.crit_main_loss = None # mean(crit_fake) - mean(crit_real)
self.crit_gp_loss = None # grad penalty term of loss function
self.crit_loss = None # crit_main_loss + crit_gp_loss
# training parameters, gradients, and training operations
self.gen_params = None # TF trainable params for generator
self.gen_grads = None # gradient of those params
self.gen_train_op = None # AdamOptimizer.minimize
self.crit_params = None # TF trainable params for critic
self.critic_wts = None # the weight matrices for the critic
self.critic_bias = None # the biases for the critic
self.crit_real_grads = None # grads for crit_real_loss
self.crit_fake_grads = None # grads for crit_fake_loss
self.crit_main_grads = None # grads for crit_main_loss
self.crit_gp_grads = None # grads for crit_gp_loss
self.crit_grads = None # grads for crit_loss
self.dpn = None # noise added for differential privacy
self.noisy_crit_grads = None # noisy grads for crit_loss
self.bds = None # the computed bounds on the gradients
self.wtBds = None # norm of bounds on gradients of weights, by layer
self.biasBds = None # norm of bounds on gradients of bias, by layer
self.gpBds = None # bounds on norm of gradients of gp term, by layer
self.sensitivity = None # the total L2 sensitivity each critic iter
self.crit_train_op = None
# noise for generating synthetic data after training
self.rand_noise_samples = None
self.samples_scores = None # critic scores of synthetic samples
self.real_samples = None # samples of real data for comparison
self.real_scores = None # critic scores of the real samples
##############################
# define lists to store data #
##############################
# loss function values during training
self.gen_loss_all = []
self.crit_fake_loss_all = []
self.crit_real_loss_all = []
self.crit_main_loss_all = []
self.crit_loss_all = []
# loss on test set
self.crit_main_loss_test_all = []
self.crit_loss_test_all = []
# time
self.time_all = []
# gradients during training
self.gen_grads_all = []
self.crit_real_grads_all = []
self.crit_fake_grads_all = []
self.crit_gp_grads_all = []
self.crit_grads_all = []
self.dpn_all = []
self.noisy_crit_grads_all = []
self.wtBds_all = []
self.biasBds_all = []
self.gpBds_all = []
self.sensitivity_all = []
### END of def __init__(self) ###
def print_settings(self, myfile=None):
"""print the WGAN parameters"""
print("\nWGAN parameters:", file=myfile)
for k, v in self.params.items():
print(f'{k + ":":18}{v}', file=myfile)
def generator(self, inpt):
"""create the generator graph"""
# first dense layer
fc1 = tf.contrib.layers.fully_connected(
inpt,
self.params['genL1_nodes'],
activation_fn=tf.nn.relu,
scope='generator.1',
reuse=tf.AUTO_REUSE)
# second dense layer
fc2 = tf.contrib.layers.fully_connected(
fc1,
self.params['genL2_nodes'],
activation_fn=tf.nn.relu,
scope='generator.2',
reuse=tf.AUTO_REUSE)
# third dense layer
fc3 = tf.contrib.layers.fully_connected(
fc2,
self.params['genL3_nodes'],
activation_fn=tf.nn.sigmoid,
scope='generator.3',
reuse=tf.AUTO_REUSE)
return fc3
### END of def generator(self, inpt) ###
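# With the layer sizes set in __init__, a dataset with n_features = 10
# (illustrative) gives a generator mapping a 100-dim uniform noise vector
# through dense layers of width 20 -> 15 -> 10, with a sigmoid output, so
# every synthetic feature lies in (0, 1); this presumes the training data
# has been scaled to [0, 1].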
def critic(self, x):
"""create the critic graph"""
# create first dense layer
fc1 = tf.contrib.layers.fully_connected(
x,
self.params['critL1_nodes'],
activation_fn=tf.nn.leaky_relu,
scope='critic.1',
reuse=tf.AUTO_REUSE)
# create second dense layer
fc2 = tf.contrib.layers.fully_connected(
fc1,
self.params['critL2_nodes'],
activation_fn=tf.nn.leaky_relu,
scope='critic.2',
reuse=tf.AUTO_REUSE)
# create third dense layer
fc3 = tf.contrib.layers.fully_connected(
fc2,
self.params['critL3_nodes'],
activation_fn=tf.nn.leaky_relu,
scope='critic.3',
reuse=tf.AUTO_REUSE)
# create fourth dense layer (output layer)
output = tf.contrib.layers.fully_connected(
fc3,
1,
activation_fn=None,
scope='critic.4',
reuse=tf.AUTO_REUSE)
return output
### END of def critic(self, x) ###
def dp_noise(self):
"""the noise for the critic gradients for differential privacy"""
# standard deviation of the Gaussian noise added to the gradients
# (note: this is twice the DP_sigma parameter)
sigma = 2 * self.params['DP_sigma']
dpn = [tf.random.normal(shape,
mean=0.0,
stddev=sigma,
dtype=tf.float32)
for shape in self.crit_grad_shapes]
return dpn
def create_graph(self):
"""create computation graph"""
print("\n*** creating graph ***\n")
self.print_settings()
# create the placeholder for real data and generator for fake
self.real_data = tf.placeholder(
dtype=tf.float32,
shape=[self.params['batch_size'], self.params['n_features']],
name="RealData")
# create input noise for generator
noise = tf.random_uniform(
shape=[self.params['batch_size'],self.params['genL0_nodes']],
minval=0,
maxval=1,
dtype=tf.float32,
seed=None,
name="InputNoise")
# create fake data by using the noise as input to the generator
self.fake_data = self.generator(noise)
# run the critic for both types of data
self.crit_real = self.critic(self.real_data) # D(real)
self.crit_fake = self.critic(self.fake_data) # D( G(noise) )
# create the loss for generator and critic
# L = D( G(noise) ) - D(real) + gradient penalty
# Adversarial game: max_G { min_D { L } }
# generator wants to max L, i.e. min -D( G(noise) )
self.gen_loss = -tf.reduce_mean(self.crit_fake)
# critic wants to min L
self.crit_fake_loss = tf.reduce_mean(self.crit_fake)
self.crit_real_loss = -tf.reduce_mean(self.crit_real)
# the 'main' loss, i.e. before adding the gradient penalty term
self.crit_main_loss = self.crit_fake_loss + self.crit_real_loss
# add the gradient penalty to crit loss, at points that are randomly
# located along straight lines between the real and fake data
alpha = tf.random_uniform(
shape=[self.params['batch_size'], 1], minval=0, maxval=1)
# points between real and fake (line 6 of WGAN-GP algorithm)
interpolates = (alpha * self.real_data) + (1 - alpha)*self.fake_data
# compute gradients of the critic at the interpolated points
gradients = tf.gradients(
self.critic(interpolates),
[interpolates]
)[0]
# compute the L2 norm of the gradients
slopes = tf.sqrt(
tf.reduce_sum(tf.square(gradients), axis=1)
)
# subtract 1, square
self.crit_grad_penalties = (slopes - 1.)**2
# the gradient penalty term of the loss function, with weight 'lamda'
self.crit_gp_loss = (
self.params['lamda'] * tf.reduce_mean(self.crit_grad_penalties)
)
# the entire loss function of the critic
self.crit_loss = (
self.crit_main_loss
+ self.crit_gp_loss
)
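# In symbols (WGAN-GP, Gulrajani et al. 2017) the critic loss built above is
#   L = E[D(G(z))] - E[D(x_real)]
#       + lamda * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2]
# with x_hat = alpha * x_real + (1 - alpha) * G(z), alpha ~ U[0, 1];
# the generator minimizes -E[D(G(z))].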
### use Adam optimizer ###
# get the TF variables (weights and biases) whose names contain 'generator':
self.gen_params = [
v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if 'generator' in v.name
]
# compute the gradient of gen_loss for those params
self.gen_grads = tf.train.AdamOptimizer(
learning_rate=1e-4, beta1=0.5, beta2=0.9).compute_gradients(
self.gen_loss, var_list=self.gen_params)
# minimize using gradients for gen_loss
self.gen_train_op = tf.train.AdamOptimizer(
learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(
self.gen_loss, var_list=self.gen_params)
# get the TF variables (weights and biases) whose names contain 'critic':
self.crit_params = [
v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
if 'critic' in v.name
]
# compute the gradients for those params:
critOpt = tf.train.AdamOptimizer(
learning_rate=1e-4, beta1=0.5, beta2=0.9)
# for crit_real_loss
self.crit_real_grads = critOpt.compute_gradients(
self.crit_real_loss, var_list=self.crit_params)
# for crit_fake_loss
self.crit_fake_grads = critOpt.compute_gradients(
self.crit_fake_loss, var_list=self.crit_params)
# for crit_gp_loss
# NOTE: recorded for inspection only; compute_gradients returns None for
# the bias entries of the GP term, so these gradients cannot be applied as-is
self.crit_gp_grads = critOpt.compute_gradients(
self.crit_gp_loss, var_list=self.crit_params)
# for crit_loss
self.crit_grads = critOpt.compute_gradients(
self.crit_loss, var_list=self.crit_params)
###############################################
# add noise to grads for differential privacy #
###############################################
# get Gaussian noise
self.dpn = self.dp_noise()
# add it to the critic gradients
self.noisy_crit_grads = [(gv[0] + self.dpn[n], gv[1])
for n, gv
in enumerate(self.crit_grads)
]
# record L2 sensitivity, to later compute the privacy loss
ib = tf.constant(
np.concatenate(
(np.zeros(
shape=(self.params['n_features'],1),
dtype=np.float32),
np.ones(
shape=(self.params['n_features'],1),
dtype=np.float32)
),
axis=1
),
dtype=tf.float32,
name='InputBound')
self.critic_wts = [v for v in self.crit_params if 'weights' in v.name]
self.critic_bias = [v for v in self.crit_params if 'biases' in v.name]
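# presumably the leaky_relu negative slopes (TF default 0.2) for the three
# hidden critic layers, and 1.0 for the linear output layer (an assumption;
# the exact meaning is defined in tf_wgan_bds.computeBounds)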
alphas = (0.2,0.2,0.2,1.0)
self.bds = computeBounds(inputBound=ib,
weights=self.critic_wts,
bias=self.critic_bias,
alphas=alphas,
compGP_g=True)
self.biasBds = [tf.norm(t[:,1]-t[:,0]) for t in self.bds[2]]
self.wtBds = [tf.norm(t[:,:,1]-t[:,:,0]) for t in self.bds[3]]
self.gpBds = self.bds[4]
self.sensitivity = (
tf.sqrt(
tf.reduce_sum(tf.square(self.biasBds))
+ tf.reduce_sum(tf.square(self.wtBds))
)
+ self.params['lamda'] *
tf.sqrt(tf.reduce_sum(tf.square(self.gpBds)))
)
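# For reference (a sketch, not used by this script): under the standard
# Gaussian mechanism, adding N(0, s^2) noise to a quantity with L2
# sensitivity Delta is (eps, delta)-differentially private for a single
# release when s >= Delta * sqrt(2 * ln(1.25 / delta)) / eps
# (Dwork & Roth, 2014). A hypothetical post-hoc helper using the recorded
# per-iteration sensitivity and the noise std (2 * DP_sigma) added above:
#     def per_step_epsilon(sensitivity, noise_std, delta=1e-5):
#         return sensitivity * np.sqrt(2.0 * np.log(1.25 / delta)) / noise_std
# Composing this over many critic iterations would additionally require an
# accountant (e.g. the moments accountant of Abadi et al., 2016).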
# minimize using the noisy gradients for crit_loss
self.crit_train_op = critOpt.apply_gradients(self.noisy_crit_grads)
### END of Adam optimizer code ###
# Generating synthetic data after training
rand_noise = tf.random_uniform(
shape=[self.params['S_synth'],self.params['genL0_nodes']],
minval=0,
maxval=1,
dtype=tf.float32,
seed=None,
name="NoiseForSamples")
self.rand_noise_samples = self.generator(rand_noise)
self.samples_scores = self.critic(self.rand_noise_samples)
self.real_samples = tf.placeholder(
dtype=tf.float32,
shape=[self.params['S_synth'], self.params['n_features']],
name="RealDataForScoring")
self.real_scores = self.critic(self.real_samples)
# save session graph information
with tf.Session() as session:
_ = tf.summary.FileWriter('./logs_new', session.graph)
### END def create_graph(self) ###
def train(self):
"""run the training loop"""
saver = tf.train.Saver() # saver object for saving the model
sessStart = time.time() # time at start of training session
sessStartDT = datetime.now() # datetime at start of training session
sessDTstring = sessStartDT.strftime('%Y%m%d_%H%M')
############################
# START TensorFlow session #
############################
with tf.Session() as session:
# initialize variables
session.run(tf.global_variables_initializer())
myvars = tf.trainable_variables()
print("\n*** starting to train at "+sessDTstring+" ***\n")
print("\nTensorFlow trainable parameters:")
for myvar in myvars:
print(myvar.name)
print(' shape =', myvar.get_shape())
### START of training loop ###
for epoch in range(self.params['num_epochs']):
start_time = time.time()
# create lists to store the critic losses this epoch
crit_fake_loss_list = []
crit_real_loss_list = []
crit_loss_list = []
crit_grads_list = []
dpn_list = []
noisy_crit_grads_list = []
sensitivity_list = []
### START of critic iterations ###
for i in range(self.params['critic_iters']):
# get a batch
train = next(self.train_batcher)
#############################
# run one critic iteration #
#############################
# compute the losses (which are NOT diff. private)
# and execute train_op (which IS diff. private)
(crit_fake_loss,
crit_real_loss,
crit_loss,
crit_grads,
dpn,
noisy_crit_grads,
sensitivity,
_) = session.run(
[self.crit_fake_loss,
self.crit_real_loss,
self.crit_loss,
self.crit_grads,
self.dpn,
self.noisy_crit_grads,
self.sensitivity,
self.crit_train_op],
feed_dict={self.real_data: train})
# append the losses to their lists
crit_fake_loss_list.append(crit_fake_loss)
crit_real_loss_list.append(crit_real_loss)
crit_loss_list.append(crit_loss)
sensitivity_list.append(sensitivity)
if epoch == 0 and i == 0:
print("\n*** Completed one critic training iteration!"
+ " ***\n")
# if recordGrads: record the gradients & noise added
if recordGrads and (epoch == 0
or epoch % recGradEvery == (recGradEvery - 1)
):
crit_grads_list.append(crit_grads)
dpn_list.append(dpn)
noisy_crit_grads_list.append(noisy_crit_grads)
### END of critic iterations ###
# on the first epoch, and on every epoch ending in 999, check the test loss
if epoch == 0 or epoch % 1000 == 999:
# shuffle test_data in place
np.random.shuffle(self.test_data)
# get the total loss and the main_loss (w/o GP)
test_crit_loss, test_crit_main_loss = session.run(
[self.crit_loss,
self.crit_main_loss],
feed_dict={
self.real_data:
self.test_data[:self.params['batch_size']]
}
)
# append the losses to their lists
self.crit_loss_test_all.append(test_crit_loss)
self.crit_main_loss_test_all.append(test_crit_main_loss)
# print the test loss to console
print(f'Test Epoch: [Test D loss: {test_crit_loss:7.4f}]')
######################################
# run one generator train iteration #
######################################
if epoch < (self.params['num_epochs'] - 1):
gen_loss, _ = session.run([self.gen_loss,
self.gen_train_op])
if epoch == 0:
print("\n*** Completed one generator training"
+ " iteration! ***\n")
# save the loss and time of epoch
self.time_all.append(time.time() - start_time)
self.crit_fake_loss_all.append(crit_fake_loss_list)
self.crit_real_loss_all.append(crit_real_loss_list)
self.crit_loss_all.append(crit_loss_list)
self.sensitivity_all.append(sensitivity_list)
self.gen_loss_all.append(gen_loss)
# if recordGrads: record the gradients & noise added
if recordGrads and (epoch == 0
or epoch % recGradEvery == (recGradEvery - 1)
):
self.crit_grads_all.append(crit_grads_list)
self.dpn_all.append(dpn_list)
self.noisy_crit_grads_all.append(noisy_crit_grads_list)
if epoch < 10 or epoch % 100 == 99:
# print the results
print(f'Epoch: {epoch:5}',
f'[D loss: {self.crit_loss_all[-1][-1]:7.4f}]',
f'[G loss: {self.gen_loss_all[-1]:7.4f}]',
f'[Time: {self.time_all[-1]:4.2f}]')
####################
# AFTER last epoch #
####################
print("\n*** Done training!!! ***\n")
# generate samples of fake data
if self.params['N_synth']:
print(
f"Generating {self.params['N_synth']} "
+ "samples of synthetic data.\n"
)
sample_files = []
score_files = []
real_score_files = []
for i in range(self.params['N_synth']):
sample_files.append(
outDIR
+ 'samples_'
+ sessDTstring
+ f'_synth_{i}.csv')
score_files.append(
outDIR
+ 'scores_'
+ sessDTstring
+ f'_synth_{i}.csv')
real_score_files.append(
outDIR
+ 'scores_'
+ sessDTstring
+ f'_real_{i}.csv')
samples = session.run(self.rand_noise_samples)
samples = pd.DataFrame(samples, columns=self.col_names)
samples.to_csv(sample_files[i], index=False)
scores = session.run(self.samples_scores)
scores = pd.DataFrame(scores)
scores.to_csv(score_files[i], index=False)
j = i*self.params['S_synth']
k = (i+1)*self.params['S_synth']
# only score real data if enough training rows remain for this slice
if k <= self.train_data.shape[0]:
real_scores = session.run(
self.real_scores,
feed_dict={self.real_samples:
self.train_data[j:k,:]}
)
real_scores = pd.DataFrame(real_scores)
real_scores.to_csv(real_score_files[i], index=False)
# dump record of time, loss functions, etc.
dump_file = outDIR + 'dump_' + sessDTstring + '.pkl'
print("Dumping loss, etc., in " + dump_file + "\n")
dump_dict = {
'time': self.time_all,
'crit_fake_loss': self.crit_fake_loss_all,
'crit_real_loss': self.crit_real_loss_all,
'crit_main_loss': self.crit_main_loss_all,
'crit_loss': self.crit_loss_all,
'gen_loss': self.gen_loss_all,
'test_main_loss': self.crit_main_loss_test_all,
'test_loss': self.crit_loss_test_all,
'sensitivity': self.sensitivity_all
}
if recordGrads:
dump_dict['crit_grads'] = self.crit_grads_all
dump_dict['dpn'] = self.dpn_all
dump_dict['noisy_crit_grads'] = self.noisy_crit_grads_all
with open(dump_file, 'wb') as f:
pkl.dump(dump_dict,f)
# save the TensorFlow session
print("Saving Tensorflow session.\n")
session_file = os.path.join(outDIR,
'model_' + sessDTstring + '.ckpt')
saver.save(session, session_file)
# write the graph to file
graph_file = outDIR + 'wgan_graph_' + sessDTstring + '.pbtxt'
tf.io.write_graph(session.graph, '.', graph_file)
# record the total time for training in a "myLog" .txt file
sessStop = time.time()
sessStopDT = datetime.now()
sessStopString = sessStopDT.strftime('%Y%m%d_%H%M')
totalTime = sessStop - sessStart
###########################
### Start of myLog file ###
###########################
log_file = outDIR + 'myLog' + sessDTstring + '.txt'
myfile = open(log_file, 'w', encoding='utf-8')
print("Writing log to " + log_file + "\n")
print(sys.argv[0] + "\n", file=myfile)
print(f"Started at {sessDTstring}\n"
+ f"Stopped at {sessStopString}\n"
+ f"Total time = {totalTime//3600}h"
+ f" {(totalTime%3600)//60}m"
+ f" {(totalTime%60)}s\n",
file=myfile)
self.print_settings(myfile)
print("\nTensorFlow trainable parameters:", file=myfile)
for myvar in myvars:
print(myvar.name, file=myfile)
print(' shape =', myvar.get_shape(), file=myfile)
print(f"\nTensorflow session saved to:\n{session_file} files",
file=myfile)
print(f"Graph written to:\n{graph_file}", file=myfile)
print(f"\nThe following was dumped in:\n{dump_file}", file=myfile)
for k in dump_dict.keys():
print(k, file=myfile)
print("\nThe following synthetic data samples were generated:",
file=myfile)
# (loops zero times if no synthetic samples were generated)
for i in range(self.params['N_synth']):
print(' ' + sample_files[i], file=myfile)
print('\n\n ######################################\n',
' #\n',
'The entire script used this session: #\n',
' #\n',
'######################################\n\n',
file=myfile)
with open(__file__) as f:
print(f.read(), file=myfile)
myfile.close()
### End of myLog file ###
#############################
# END of TensorFlow session #
#############################
### END of def train(self) ###
### END of class WGAN() ###
#------------------------------------------------------------------------------
if __name__ == '__main__':
# start with a fresh graph
tf.reset_default_graph()
# create object
wgan = WGAN(
outDIR = outDIR,
trng_path = thisDIR + "train_sdv.csv",
test_path = thisDIR + "test_sdv.csv",
genL0_nodes = 100,
critL1_nodes = 64,
critic_iters = 5,
num_epochs = 2000,
batch_size = 208,
lamda = 10,
N_synth = 2,
S_synth = 2085,
DP_sigma = 1e-6
)
wgan.create_graph() # define the computation graph
wgan.train() # train the model
print("*** ALL DONE at " + datetime.now().strftime('%Y%m%d_%H%M') + " ***\n")