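"""Learning-rate sweep for DeepHeatmapsModel: trains one model per learning rate
(log-spaced between 1e-8 and 1e-2) and writes each run's checkpoints, samples and
logs to its own subdirectory of output_dir."""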
import tensorflow as tf
from deep_heatmaps_model_primary_valid import DeepHeatmapsModel
import os
import numpy as np
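# sweep configuration: num_tests learning rates sampled log-uniformly between
# 1e-8 and 1e-2, each trained for max_iter iterations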
num_tests = 10
params = np.logspace(-8, -2, num_tests)
max_iter = 80000
output_dir = 'tests_lr_fusion'
data_dir = '../conventional_landmark_detection_dataset'
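# TF1-style command-line flags so output_dir and img_path can be overridden at run time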
flags = tf.app.flags
flags.DEFINE_string('output_dir', output_dir, "directory for saving the log file")
flags.DEFINE_string('img_path', data_dir, "data directory")
FLAGS = flags.FLAGS
if not os.path.exists(FLAGS.output_dir):
    os.mkdir(FLAGS.output_dir)

for param in params:
    # one sub-directory per learning rate, named after its value
    test_name = str(param)
    test_dir = os.path.join(FLAGS.output_dir, test_name)
    if not os.path.exists(test_dir):
        os.mkdir(test_dir)

    print('##### RUNNING TESTS ##### current directory: ' + test_dir)

    save_model_path = os.path.join(test_dir, 'model')
    save_sample_path = os.path.join(test_dir, 'sample')
    save_log_path = os.path.join(test_dir, 'logs')

    # create directories if they do not exist
    if not os.path.exists(save_model_path):
        os.mkdir(save_model_path)
    if not os.path.exists(save_sample_path):
        os.mkdir(save_sample_path)
    if not os.path.exists(save_log_path):
        os.mkdir(save_log_path)

    tf.reset_default_graph()  # reset graph between runs

    model = DeepHeatmapsModel(mode='TRAIN', train_iter=max_iter, learning_rate=param, momentum=0.95, step=80000,
                              gamma=0.1, batch_size=4, image_size=256, c_dim=3, num_landmarks=68,
                              augment_basic=True, basic_start=0, augment_texture=True, p_texture=0.5,
                              augment_geom=True, p_geom=0.5, artistic_start=0, artistic_step=10,
                              img_path=FLAGS.img_path, save_log_path=save_log_path, save_sample_path=save_sample_path,
                              save_model_path=save_model_path)

    model.train()