Upload DiffSketcher/run_painterly_render.py with huggingface_hub
Browse files
DiffSketcher/run_painterly_render.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Author: ximing
|
3 |
+
# Description: the main func of this project.
|
4 |
+
# Copyright (c) 2023, XiMing Xing.
|
5 |
+
# License: MIT License
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import argparse
|
10 |
+
from datetime import datetime
|
11 |
+
import random
|
12 |
+
from typing import Any, List
|
13 |
+
from functools import partial
|
14 |
+
|
15 |
+
from accelerate.utils import set_seed
|
16 |
+
import omegaconf
|
17 |
+
|
18 |
+
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
|
19 |
+
|
20 |
+
from libs.engine import merge_and_update_config
|
21 |
+
from libs.utils.argparse import accelerate_parser, base_data_parser
|
22 |
+
|
23 |
+
|
24 |
+
def render_batch_wrap(args: omegaconf.DictConfig,
|
25 |
+
seed_range: List,
|
26 |
+
pipeline: Any,
|
27 |
+
**pipe_args):
|
28 |
+
start_time = datetime.now()
|
29 |
+
for idx, seed in enumerate(seed_range):
|
30 |
+
args.seed = seed # update seed
|
31 |
+
print(f"\n-> [{idx}/{len(seed_range)}], "
|
32 |
+
f"current seed: {seed}, "
|
33 |
+
f"current time: {datetime.now() - start_time}\n")
|
34 |
+
pipe = pipeline(args)
|
35 |
+
pipe.painterly_rendering(**pipe_args)
|
36 |
+
|
37 |
+
|
38 |
+
def main(args, seed_range):
|
39 |
+
args.batch_size = 1 # rendering one SVG at a time
|
40 |
+
|
41 |
+
args.width = float(args.width)
|
42 |
+
|
43 |
+
render_batch_fn = partial(render_batch_wrap, args=args, seed_range=seed_range)
|
44 |
+
|
45 |
+
if args.task == "diffsketcher": # text2sketch
|
46 |
+
from pipelines.painter.diffsketcher_pipeline import DiffSketcherPipeline
|
47 |
+
|
48 |
+
if not args.render_batch:
|
49 |
+
pipe = DiffSketcherPipeline(args)
|
50 |
+
pipe.painterly_rendering(args.prompt)
|
51 |
+
else: # generate many SVG at once
|
52 |
+
render_batch_fn(pipeline=DiffSketcherPipeline, prompt=args.prompt)
|
53 |
+
|
54 |
+
elif args.task == "style-diffsketcher": # text2sketch + style transfer
|
55 |
+
from pipelines.painter.diffsketcher_stylized_pipeline import StylizedDiffSketcherPipeline
|
56 |
+
|
57 |
+
if not args.render_batch:
|
58 |
+
pipe = StylizedDiffSketcherPipeline(args)
|
59 |
+
pipe.painterly_rendering(args.prompt, args.style_file)
|
60 |
+
else: # generate many SVG at once
|
61 |
+
render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=args.prompt, style_fpath=args.style_file)
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == '__main__':
|
65 |
+
parser = argparse.ArgumentParser(
|
66 |
+
description="vary style and content painterly rendering",
|
67 |
+
parents=[accelerate_parser(), base_data_parser()]
|
68 |
+
)
|
69 |
+
# flag
|
70 |
+
parser.add_argument("-tk", "--task",
|
71 |
+
default="diffsketcher", type=str,
|
72 |
+
choices=['diffsketcher', 'style-diffsketcher'],
|
73 |
+
help="choose a method.")
|
74 |
+
# config
|
75 |
+
parser.add_argument("-c", "--config",
|
76 |
+
required=True, type=str,
|
77 |
+
default="",
|
78 |
+
help="YAML/YML file for configuration.")
|
79 |
+
parser.add_argument("-style", "--style_file",
|
80 |
+
default="", type=str,
|
81 |
+
help="the path of style img place.")
|
82 |
+
# prompt
|
83 |
+
parser.add_argument("-pt", "--prompt", default="A horse is drinking water by the lake", type=str)
|
84 |
+
parser.add_argument("-npt", "--negative_prompt", default="", type=str)
|
85 |
+
# DiffSVG
|
86 |
+
parser.add_argument("--print_timing", "-timing", action="store_true",
|
87 |
+
help="set print svg rendering timing.")
|
88 |
+
# diffuser
|
89 |
+
parser.add_argument("--download", action="store_true",
|
90 |
+
help="download models from huggingface automatically.")
|
91 |
+
parser.add_argument("--force_download", "-download", action="store_true",
|
92 |
+
help="force the models to be downloaded from huggingface.")
|
93 |
+
parser.add_argument("--resume_download", "-dpm_resume", action="store_true",
|
94 |
+
help="download the models again from the breakpoint.")
|
95 |
+
# rendering quantity
|
96 |
+
# like: python main.py -rdbz -srange 100 200
|
97 |
+
parser.add_argument("--render_batch", "-rdbz", action="store_true")
|
98 |
+
parser.add_argument("-srange", "--seed_range",
|
99 |
+
required=False, nargs='+',
|
100 |
+
help="Sampling quantity.")
|
101 |
+
# visual rendering process
|
102 |
+
parser.add_argument("-mv", "--make_video", action="store_true",
|
103 |
+
help="make a video of the rendering process.")
|
104 |
+
parser.add_argument("-frame_freq", "--video_frame_freq",
|
105 |
+
default=1, type=int,
|
106 |
+
help="video frame control.")
|
107 |
+
parser.add_argument("-framerate", "--video_frame_rate",
|
108 |
+
default=36, type=int,
|
109 |
+
help="by adjusting the frame rate, you can control the playback speed of the output video.")
|
110 |
+
|
111 |
+
args = parser.parse_args()
|
112 |
+
|
113 |
+
# set the random seed range
|
114 |
+
seed_range = None
|
115 |
+
if args.render_batch:
|
116 |
+
# random sampling without specifying a range
|
117 |
+
start_, end_ = 1, 1000000
|
118 |
+
if args.seed_range is not None: # specify range sequential sampling
|
119 |
+
seed_range_ = list(args.seed_range)
|
120 |
+
assert len(seed_range_) == 2 and int(seed_range_[1]) > int(seed_range_[0])
|
121 |
+
start_, end_ = int(seed_range_[0]), int(seed_range_[1])
|
122 |
+
seed_range = [i for i in range(start_, end_)]
|
123 |
+
else:
|
124 |
+
# a list of lengths 1000 sampled from the range start_ to end_ (e.g.: [1, 1000000])
|
125 |
+
numbers = list(range(start_, end_))
|
126 |
+
seed_range = random.sample(numbers, k=1000)
|
127 |
+
|
128 |
+
args = merge_and_update_config(args)
|
129 |
+
|
130 |
+
set_seed(args.seed)
|
131 |
+
main(args, seed_range)
|