Spaces:
Build error
Build error
| import torch | |
| import argparse | |
| import numpy as np | |
| from helper import * | |
| from config.GlobalVariables import * | |
| from SynthesisNetwork import SynthesisNetwork | |
| from DataLoader import DataLoader | |
| import convenience | |
| import gradio as gr | |
| #@title Demo | |
device = 'cpu'
num_samples = 10

net = SynthesisNetwork(weight_dim=256, num_layers=3).to(device)
# Load the pretrained weights unconditionally, mapped onto `device`. The old
# code only did this when CUDA was *unavailable*, so on a CUDA host the network
# was used without its weights ever being loaded (and `device` is 'cpu' here
# regardless).
_checkpoint = torch.load('./model/250000.pt', map_location=torch.device(device))
# The retrained checkpoint stores the weights (plus the loss) under
# "model_state_dict"; the original checkpoint is the bare state dict. Checking
# for the key replaces a bare `except:` that hid unrelated errors and re-read
# the file from disk.
if isinstance(_checkpoint, dict) and "model_state_dict" in _checkpoint:
    _checkpoint = _checkpoint["model_state_dict"]
net.load_state_dict(_checkpoint)
del _checkpoint

dl = DataLoader(num_writer=1, num_samples=num_samples, divider=5.0, datadir='./data/writers')
writer_options = [5, 14, 15, 16, 17, 22, 25, 80, 120, 137, 147, 151]

# Characters offered by the character dropdowns. The backslash is spelled `\\`
# explicitly: the previous bare `\ ` was an invalid escape sequence (a warning
# on modern Python) that only happened to produce a backslash.
avail_char = "0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z A B C D E F G H I J K L M N O P Q R S T U V W X Y Z ! ? \" ' * + - = : ; , . < > \\ / [ ] ( ) # $ % &"
avail_char_list = avail_char.split(" ")

# Samples for the two writers pre-selected in the "Blend Writers" tab, in
# order: Style 120, then Style 80.
all_loaded_data = [
    dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
    for writer_id in [120, 80]
]
# Writer 80's samples back the single-writer tabs.
default_loaded_data = all_loaded_data[-1]
| # for writer interpolation | |
def interpolate_writers(target_word, weight):
    """Render ``target_word`` as a blend of the two writers in ``all_loaded_data``.

    ``weight`` is the share given to the second writer; the first writer
    receives ``1 - weight``. Returns the sampled image converted to RGB.
    """
    mix = [1 - weight, weight]
    blended = convenience.sample_blended_writers(mix, target_word, net, all_loaded_data, device)
    return blended.convert("RGB")
def choose_blend_writers(writer1, writer2):
    """Load the two selected writers and relabel the blend slider.

    ``writer1`` and ``writer2`` are radio labels of the form ``"Style <id>"``.
    The module-level ``all_loaded_data`` list is refilled in place with the
    samples of the first writer, then the second, so ``interpolate_writers``
    blends the new pair. Returns a Gradio update that sets the slider label to
    ``"<writer1> vs. <writer2>"``.
    """
    # Bug fix: the second id used to be parsed from `writer1` too, so the
    # "blend" was always the first writer mixed with itself.
    ids = [int(label.split(" ")[1]) for label in (writer1, writer2)]
    # Fetch everything first so that a failing load does not leave the shared
    # list cleared or half-populated.
    new_data = [
        dl.next_batch(TYPE='TRAIN', uid=writer_id, tids=list(range(num_samples)))
        for writer_id in ids
    ]
    all_loaded_data[:] = new_data
    # `gr.update` works on Gradio 3 and 4; the component-specific
    # `gr.Slider.update` classmethod is gone in Gradio 4.
    return gr.update(label=f"{writer1} vs. {writer2}")
# NOTE(review): disabled single-writer selector for the character-grid tab.
# It is a bare string literal, so it is never executed. If revived, note that
# it clears the shared `all_loaded_data` list (also used by the Blend Writers
# tab), and re-check its `char_grid(c1, c2, c3, c4, val)` call against
# char_grid's signature.
'''
def choose_writer(writ, c1, c2, c3, c4, val):
all_loaded_data.clear()
w = int(writ.split(" ")[1])
loaded_data = dl.next_batch(TYPE='TRAIN', uid=w, tids=list(range(num_samples)))
all_loaded_data.append(loaded_data)
return char_grid(c1, c2, c3, c4, val)
'''
# for character grid
def choose_grid_chars(c1, c2, c3, c4):
    """Relabel the character-grid submit button with the four chosen characters.

    Returns a Gradio update setting the button text to
    ``"Blend <c1>, <c2>, <c3>, and <c4>!"``. ``gr.update`` is used instead of
    the component-specific ``gr.Button.update`` classmethod, which Gradio 4
    no longer provides.
    """
    return gr.update(value=f"Blend {c1}, {c2}, {c3}, and {c4}!")
def char_grid(c1, c2, c3, c4, grid_size=5):
    """Render a grid blending the four characters ``c1``..``c4``.

    Samples come from ``default_loaded_data``. ``grid_size`` is passed as the
    second argument of ``convenience.sample_character_grid``. It defaults to 5,
    the value this function always used, and can now be supplied by a UI
    control such as a "Grid Size" slider. It is coerced with ``int()`` in case
    such a control delivers a float like ``10.0``.

    Returns the sampled image converted to RGB.
    """
    image = convenience.sample_character_grid(
        [c1, c2, c3, c4], int(grid_size), net, [default_loaded_data], device
    ).convert("RGB")
    return image
| # for character blend | |
def interpolate_chars(c1, c2, weight):
    """Render a blend of characters ``c1`` and ``c2`` using ``default_loaded_data``.

    ``c1`` is weighted ``1 - weight`` and ``c2`` is weighted ``weight``.
    Returns the sampled image converted to RGB.
    """
    pair = [c1, c2]
    ratios = [1 - weight, weight]
    sample = convenience.sample_blended_chars(ratios, pair, net, [default_loaded_data], device)
    return sample.convert("RGB")
def choose_blend_chars(c1, c2):
    """Relabel the character-blend slider as ``"'<c1>' vs. '<c2>'"``.

    Returns a Gradio update. ``gr.update`` is used instead of the
    component-specific ``gr.Slider.update`` classmethod, which Gradio 4 no
    longer provides.
    """
    return gr.update(label=f"'{c1}' vs. '{c2}'")
| # for MDN | |
def mdn_sample(word, maxs, maxr):
    """Draw one randomised rendering of ``word`` via ``convenience.mdn_single_sample``.

    ``maxs`` and ``maxr`` are forwarded unchanged as that helper's second and
    third arguments. In the UI they come from the "Scale of Randomness" and
    "Maximum Randomness" sliders respectively. Samples use
    ``default_loaded_data``. Returns the image converted to RGB.
    """
    drawing = convenience.mdn_single_sample(word, maxs, maxr, net, [default_loaded_data], device)
    return drawing.convert("RGB")
# NOTE(review): disabled character-interpolation video handler, kept as a bare
# string literal (never executed). As written, it references an undefined
# `word_list`, discards the result of `char_interpolation_video`, and returns a
# hard-coded Colab Google-Drive path. It needs rework before being re-enabled.
"""
def char_vid(word):
#make word list
convenience.char_interpolation_video(word_list, 10, net, [default_loaded_data], device).convert('RGB')
vid_path = f"/content/drive/MyDrive/Colab Notebooks/Spring22/decoupled-style-descriptors-eb/results/abcdefg_video.mov"
return gr.Video.update(value=vid_path)
"""
with gr.Blocks() as demo:
    with gr.Tabs():
        # --- Tab: blend the handwriting styles of two writers -----------------
        with gr.TabItem("Blend Writers"):
            target_word = gr.Textbox(label="Target Word", value="hello world", max_lines=1)
            with gr.Row():
                # Split the writer ids between the two radio groups: even
                # positions on the left, odd positions on the right.
                left_ratio_options = [f"Style {wid}" for wid in writer_options[0::2]]
                right_ratio_options = [f"Style {wid}" for wid in writer_options[1::2]]
                with gr.Column():
                    writer1 = gr.Radio(left_ratio_options, value="Style 120", label="Style for first writer")
                with gr.Column():
                    writer2 = gr.Radio(right_ratio_options, value="Style 80", label="Style for second writer")
            with gr.Row():
                writer_slider = gr.Slider(0, 1, value=0.3, label="Style 120 vs. Style 80")
            with gr.Row():
                writer_submit = gr.Button("Submit")
            with gr.Row():
                # The initial preview matches the slider's default weight (0.3).
                writer_default_image = convenience.sample_blended_writers(
                    [0.7, 0.3], "hello world", net, all_loaded_data, device
                ).convert("RGB")
                writer_output = gr.Image(writer_default_image)
            # Re-render on submit, on slider movement, or on Enter in the textbox.
            for listen in (writer_submit.click, writer_slider.change, target_word.submit):
                listen(fn=interpolate_writers, inputs=[target_word, writer_slider], outputs=[writer_output])
            # Picking either writer reloads the pair and relabels the slider.
            for listen in (writer1.change, writer2.change):
                listen(fn=choose_blend_writers, inputs=[writer1, writer2], outputs=[writer_slider])

        # --- Tab: blend two characters in the default writer's style ----------
        with gr.TabItem("Blend Characters"):
            with gr.Row():
                with gr.Column():
                    char1 = gr.Dropdown(choices=avail_char_list, value="y", label="Character 1")
                with gr.Column():
                    char2 = gr.Dropdown(choices=avail_char_list, value="s", label="Character 2")
            with gr.Row():
                char_slider = gr.Slider(0, 1, value=0.3, label="'y' vs. 's'")
            with gr.Row():
                char_default_image = convenience.sample_blended_chars(
                    [0.7, 0.3], ["y", "s"], net, [default_loaded_data], device
                ).convert("RGB")
                char_output = gr.Image(char_default_image)
            char_slider.change(fn=interpolate_chars, inputs=[char1, char2, char_slider], outputs=[char_output])
            for listen in (char1.change, char2.change):
                listen(fn=choose_blend_chars, inputs=[char1, char2], outputs=[char_slider])

        # Disabled "Character Grid" tab (marked slow); kept as an inert string.
        """
        with gr.TabItem("Character Grid"): #slow
            with gr.Row():
                with gr.Column():
                    char1 = gr.Dropdown(choices=avail_char_list, value="y", label="Character 1")
                with gr.Column():
                    char2 = gr.Dropdown(choices=avail_char_list, value="s", label="Character 2")
                with gr.Column():
                    char3 = gr.Dropdown(choices=avail_char_list, value="u", label="Character 3")
                with gr.Column():
                    char4 = gr.Dropdown(choices=avail_char_list, value="n", label="Character 4")
            with gr.Row():
                submit_button = gr.Button(value="Blend y, s, u, and n!")
            '''
            with gr.Row():
                radio_options2 = ["Writer " + str(n) for n in writer_options]
                writer = gr.Radio(radio_options2, value="Writer 80", label="Style for Writer")
                writer.change(fn=choose_writer, inputs=[writer, char1, char2, char3, char4, slider2], outputs=[output])
            '''
            #slider2 = gr.Slider(2, 20, value=10, label="Grid Size", step=1)
            default_image = convenience.sample_character_grid(['y', 's', 'u', 'n'], 10, net, [default_loaded_data], device).convert("RGB")
            output = gr.Image(default_image)
            char1.change(fn=choose_grid_chars, inputs=[char1, char2, char3, char4], outputs=[submit_button])
            char2.change(fn=choose_grid_chars, inputs=[char1, char2, char3, char4], outputs=[submit_button])
            char3.change(fn=choose_grid_chars, inputs=[char1, char2, char3, char4], outputs=[submit_button])
            char4.change(fn=choose_grid_chars, inputs=[char1, char2, char3, char4], outputs=[submit_button])
            #slider2.change(fn=char_grid, inputs=[char1, char2, char3, char4, slider2], outputs=[output])
            submit_button.click(fn=char_grid, inputs=[char1, char2, char3, char4], outputs=[output])
        """

        # --- Tab: sample a word with adjustable randomness --------------------
        with gr.TabItem("Add Randomness"):
            mdn_word = gr.Textbox(label="Target Word", value="hello world", max_lines=1)
            '''
            with gr.Row():
                radio_options3 = ["Writer " + str(n) for n in writer_options]
                writer = gr.Radio(radio_options3, value="Writer 80", label="Style for Writer")
                writer.change(fn=new_writer_mdn, inputs=[writer, slider3, slider4], outputs=[output])
            '''
            with gr.Row():
                with gr.Column():
                    max_rand = gr.Slider(0, 1, value=1, label="Maximum Randomness")
                with gr.Column():
                    scale_rand = gr.Slider(0, 3, value=0.5, label="Scale of Randomness")
            with gr.Row():
                sample_button = gr.Button(value="Resample!")
            with gr.Row():
                # Initial sample uses the sliders' defaults (scale 0.5, max 1).
                default_im = convenience.mdn_single_sample(
                    "hello world", 0.5, 1, net, [default_loaded_data], device
                ).convert('RGB')
                mdn_output = gr.Image(default_im)
            for listen in (max_rand.change, scale_rand.change, sample_button.click, mdn_word.submit):
                listen(fn=mdn_sample, inputs=[mdn_word, scale_rand, max_rand], outputs=[mdn_output])

demo.launch()