Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tensorflow as tf
|
3 |
+
# Load compressed models from tensorflow_hub
|
4 |
+
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
|
5 |
+
import IPython.display as display
|
6 |
+
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import matplotlib as mpl
|
9 |
+
mpl.rcParams['figure.figsize'] = (12, 12)
|
10 |
+
mpl.rcParams['axes.grid'] = False
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import PIL.Image
|
14 |
+
|
15 |
+
def tensor_to_image(tensor):
|
16 |
+
tensor = tensor*255
|
17 |
+
tensor = np.array(tensor, dtype=np.uint8)
|
18 |
+
if np.ndim(tensor)>3:
|
19 |
+
assert tensor.shape[0] == 1
|
20 |
+
tensor = tensor[0]
|
21 |
+
return PIL.Image.fromarray(tensor)
|
22 |
+
|
23 |
+
def load_img(path_to_img):
|
24 |
+
max_dim = 1024
|
25 |
+
img = tf.io.read_file(path_to_img)
|
26 |
+
img = tf.image.decode_image(img, channels=3)
|
27 |
+
img = tf.image.convert_image_dtype(img, tf.float32)
|
28 |
+
|
29 |
+
shape = tf.cast(tf.shape(img)[:-1], tf.float32)
|
30 |
+
long_dim = max(shape)
|
31 |
+
scale = max_dim / long_dim
|
32 |
+
|
33 |
+
new_shape = tf.cast(shape * scale, tf.int32)
|
34 |
+
|
35 |
+
img = tf.image.resize(img, new_shape)
|
36 |
+
img = img[tf.newaxis, :]
|
37 |
+
return img
|
38 |
+
|
39 |
+
def imshow(image, title=None):
|
40 |
+
if len(image.shape) > 3:
|
41 |
+
image = tf.squeeze(image, axis=0)
|
42 |
+
|
43 |
+
plt.imshow(image)
|
44 |
+
if title:
|
45 |
+
plt.title(title)
|
46 |
+
|
47 |
+
content_layers = ['block5_conv2']
|
48 |
+
|
49 |
+
style_layers = ['block1_conv1',
|
50 |
+
'block2_conv1',
|
51 |
+
'block3_conv1',
|
52 |
+
'block4_conv1',
|
53 |
+
'block5_conv1']
|
54 |
+
|
55 |
+
num_content_layers = len(content_layers)
|
56 |
+
num_style_layers = len(style_layers)
|
57 |
+
|
58 |
+
def vgg_layers(layer_names):
|
59 |
+
""" Creates a vgg model that returns a list of intermediate output values."""
|
60 |
+
# Load our model. Load pretrained VGG, trained on imagenet data
|
61 |
+
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
|
62 |
+
vgg.trainable = False
|
63 |
+
|
64 |
+
outputs = [vgg.get_layer(name).output for name in layer_names]
|
65 |
+
|
66 |
+
model = tf.keras.Model([vgg.input], outputs)
|
67 |
+
return model
|
68 |
+
|
69 |
+
def gram_matrix(input_tensor):
|
70 |
+
result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
|
71 |
+
input_shape = tf.shape(input_tensor)
|
72 |
+
num_locations = tf.cast(input_shape[1]*input_shape[2], tf.float32)
|
73 |
+
return result/(num_locations)
|
74 |
+
|
75 |
+
class StyleContentModel(tf.keras.models.Model):
|
76 |
+
def __init__(self, style_layers, content_layers):
|
77 |
+
super(StyleContentModel, self).__init__()
|
78 |
+
self.vgg = vgg_layers(style_layers + content_layers)
|
79 |
+
self.style_layers = style_layers
|
80 |
+
self.content_layers = content_layers
|
81 |
+
self.num_style_layers = len(style_layers)
|
82 |
+
self.vgg.trainable = False
|
83 |
+
|
84 |
+
def call(self, inputs):
|
85 |
+
"Expects float input in [0,1]"
|
86 |
+
inputs = inputs*255.0
|
87 |
+
preprocessed_input = tf.keras.applications.vgg19.preprocess_input(inputs)
|
88 |
+
outputs = self.vgg(preprocessed_input)
|
89 |
+
style_outputs, content_outputs = (outputs[:self.num_style_layers],
|
90 |
+
outputs[self.num_style_layers:])
|
91 |
+
|
92 |
+
style_outputs = [gram_matrix(style_output)
|
93 |
+
for style_output in style_outputs]
|
94 |
+
|
95 |
+
content_dict = {content_name: value
|
96 |
+
for content_name, value
|
97 |
+
in zip(self.content_layers, content_outputs)}
|
98 |
+
|
99 |
+
style_dict = {style_name: value
|
100 |
+
for style_name, value
|
101 |
+
in zip(self.style_layers, style_outputs)}
|
102 |
+
|
103 |
+
return {'content': content_dict, 'style': style_dict}
|
104 |
+
|
105 |
+
extractor = StyleContentModel(style_layers, content_layers)
|
106 |
+
|
107 |
+
def clip_0_1(image):
|
108 |
+
return tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)
|
109 |
+
|
110 |
+
def high_pass_x_y(image):
|
111 |
+
x_var = image[:, :, 1:, :] - image[:, :, :-1, :]
|
112 |
+
y_var = image[:, 1:, :, :] - image[:, :-1, :, :]
|
113 |
+
|
114 |
+
return x_var, y_var
|
115 |
+
|
116 |
+
def total_variation_loss(image):
|
117 |
+
x_deltas, y_deltas = high_pass_x_y(image)
|
118 |
+
return tf.reduce_sum(tf.abs(x_deltas)) + tf.reduce_sum(tf.abs(y_deltas))
|
119 |
+
|
120 |
+
opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)
|
121 |
+
|
122 |
+
style_weight=1e-2
|
123 |
+
content_weight=1e4
|
124 |
+
total_variation_weight=30
|
125 |
+
|
126 |
+
epochs = 10
|
127 |
+
steps_per_epoch = 50
|
128 |
+
|
129 |
+
def transfer_style(content_path,style_path,transfer_mode,steps_per_epoch=100,style_weight=1e-2,content_weight=1e4,total_variation_weight=30):
|
130 |
+
try:
|
131 |
+
|
132 |
+
content_image = load_img(content_path)
|
133 |
+
style_image = load_img(style_path)
|
134 |
+
if transfer_mode == "Fast_transfer":
|
135 |
+
res = transfer_style_fast(content_image,style_image)
|
136 |
+
else:
|
137 |
+
res = transfer_style_custom(content_image,style_image,int(steps_per_epoch),style_weight,content_weight,total_variation_weight)
|
138 |
+
res = tensor_to_image(res)
|
139 |
+
except Exception as ex:
|
140 |
+
raise Exception(ex)
|
141 |
+
return res
|
142 |
+
|
143 |
+
def transfer_style_fast(content_image,style_image):
|
144 |
+
import tensorflow_hub as hub
|
145 |
+
hub_model = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')
|
146 |
+
return hub_model(tf.constant(content_image), tf.constant(style_image))[0]
|
147 |
+
|
148 |
+
def transfer_style_custom(content_image,style_image,steps_per_epoch=100,style_weight=1e-2,content_weight=1e4,total_variation_weight=30):
|
149 |
+
|
150 |
+
def style_content_loss(outputs):
|
151 |
+
style_outputs = outputs['style']
|
152 |
+
content_outputs = outputs['content']
|
153 |
+
style_loss = tf.add_n([tf.reduce_mean((style_outputs[name]-style_targets[name])**2)
|
154 |
+
for name in style_outputs.keys()])
|
155 |
+
style_loss *= style_weight / num_style_layers
|
156 |
+
|
157 |
+
content_loss = tf.add_n([tf.reduce_mean((content_outputs[name]-content_targets[name])**2)
|
158 |
+
for name in content_outputs.keys()])
|
159 |
+
content_loss *= content_weight / num_content_layers
|
160 |
+
loss = style_loss + content_loss
|
161 |
+
return loss
|
162 |
+
|
163 |
+
@tf.function()
|
164 |
+
def train_step(image):
|
165 |
+
with tf.GradientTape() as tape:
|
166 |
+
outputs = extractor(image)
|
167 |
+
loss = style_content_loss(outputs)
|
168 |
+
loss += total_variation_weight*tf.image.total_variation(image)
|
169 |
+
|
170 |
+
grad = tape.gradient(loss, image)
|
171 |
+
opt.apply_gradients([(grad, image)])
|
172 |
+
image.assign(clip_0_1(image))
|
173 |
+
try:
|
174 |
+
style_targets = extractor(style_image)['style']
|
175 |
+
content_targets = extractor(content_image)['content']
|
176 |
+
image = tf.Variable(content_image)
|
177 |
+
|
178 |
+
step = 0
|
179 |
+
for n in range(epochs):
|
180 |
+
for m in range(steps_per_epoch):
|
181 |
+
step += 1
|
182 |
+
train_step(image)
|
183 |
+
except Exception as ex:
|
184 |
+
raise Exception(ex)
|
185 |
+
|
186 |
+
return image
|
187 |
+
|
188 |
+
import gradio as gr
|
189 |
+
|
190 |
+
inputs = [
|
191 |
+
gr.inputs.Image(type="filepath"),
|
192 |
+
gr.inputs.Image(type="filepath"),
|
193 |
+
gr.inputs.Radio(["Fast_transfer","Custom_transfer"]),
|
194 |
+
gr.inputs.Slider(1,100,default=30,step=1),
|
195 |
+
gr.inputs.Number(1e-2),
|
196 |
+
gr.inputs.Number(1e4),
|
197 |
+
gr.inputs.Number(30)
|
198 |
+
]
|
199 |
+
|
200 |
+
iface = gr.Interface(
|
201 |
+
fn=transfer_style,
|
202 |
+
inputs=inputs,
|
203 |
+
examples=[["NST/etsii.jpg","NST/data/style_2.jpg","Fast_transfer",30,1e-2,1e4,30],
|
204 |
+
["NST/data/content_9.jpg","NST/ola.png","Fast_transfer",30,1e-2,1e4,30],
|
205 |
+
["NST/sailboat_cropped.jpg","NST/sketch_cropped.png","Fast_transfer",30,1e-2,1e4,30],
|
206 |
+
["NST/armadillo.jpg","NST/data/style_3.jpg","Fast_transfer",30,1e-2,1e4,30],
|
207 |
+
["NST/gato.jpg","NST/data/style_4.jpg","Fast_transfer",30,1e-2,1e4,30],
|
208 |
+
],
|
209 |
+
outputs="image").launch(share=True)
|