File size: 1,913 Bytes
dd38ad1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import tensorflow as tf
from deep_heatmaps_model_primary_valid import DeepHeatmapsModel
import os
import numpy as np

# Learning-rate sweep: trains one DeepHeatmapsModel per learning rate in
# `params`, each run writing its model/sample/log outputs to its own
# sub-directory of FLAGS.output_dir (named after the learning rate).

num_tests = 10
params = np.logspace(-8, -2, num_tests)  # num_tests log-spaced learning rates from 1e-8 to 1e-2
max_iter = 80000

output_dir = 'tests_lr_fusion'
data_dir = '../conventional_landmark_detection_dataset'

flags = tf.app.flags
flags.DEFINE_string('output_dir', output_dir, "directory for saving the log file")
flags.DEFINE_string('img_path', data_dir, "data directory")
FLAGS = flags.FLAGS


def _ensure_dir(path):
    """Create directory `path`, including any missing parents, if it does not exist."""
    # os.makedirs (unlike os.mkdir) also creates intermediate directories, so a
    # nested --output_dir works. `exist_ok` is avoided to stay Python-2 compatible.
    if not os.path.exists(path):
        os.makedirs(path)


_ensure_dir(FLAGS.output_dir)

for param in params:
    test_name = str(param)
    test_dir = os.path.join(FLAGS.output_dir, test_name)
    _ensure_dir(test_dir)

    # A single-argument print(...) yields the same output under Python 2 and 3.
    print('##### RUNNING TESTS ##### current directory: %s' % test_dir)

    save_model_path = os.path.join(test_dir, 'model')
    save_sample_path = os.path.join(test_dir, 'sample')
    save_log_path = os.path.join(test_dir, 'logs')

    # create output sub-directories if they do not exist
    for sub_dir in (save_model_path, save_sample_path, save_log_path):
        _ensure_dir(sub_dir)

    # Start every run from an empty default graph so each model is built in a
    # fresh graph rather than on top of the previous run's ops.
    tf.reset_default_graph()

    # NOTE(review): step == max_iter (80000); if `step` is the LR-decay interval,
    # gamma would never take effect during training — confirm this is intended.
    model = DeepHeatmapsModel(mode='TRAIN', train_iter=max_iter, learning_rate=param, momentum=0.95, step=80000,
                              gamma=0.1, batch_size=4, image_size=256, c_dim=3, num_landmarks=68,
                              augment_basic=True, basic_start=0, augment_texture=True, p_texture=0.5,
                              augment_geom=True, p_geom=0.5, artistic_start=0, artistic_step=10,
                              img_path=FLAGS.img_path, save_log_path=save_log_path, save_sample_path=save_sample_path,
                              save_model_path=save_model_path)

    model.train()