jree423 commited on
Commit
c5799f5
·
verified ·
1 Parent(s): 87fe160

Upload DiffSketcher/run_painterly_render.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. DiffSketcher/run_painterly_render.py +131 -0
DiffSketcher/run_painterly_render.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Author: ximing
3
+ # Description: the main func of this project.
4
+ # Copyright (c) 2023, XiMing Xing.
5
+ # License: MIT License
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ from datetime import datetime
11
+ import random
12
+ from typing import Any, List
13
+ from functools import partial
14
+
15
+ from accelerate.utils import set_seed
16
+ import omegaconf
17
+
18
+ sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
19
+
20
+ from libs.engine import merge_and_update_config
21
+ from libs.utils.argparse import accelerate_parser, base_data_parser
22
+
23
+
24
+ def render_batch_wrap(args: omegaconf.DictConfig,
25
+ seed_range: List,
26
+ pipeline: Any,
27
+ **pipe_args):
28
+ start_time = datetime.now()
29
+ for idx, seed in enumerate(seed_range):
30
+ args.seed = seed # update seed
31
+ print(f"\n-> [{idx}/{len(seed_range)}], "
32
+ f"current seed: {seed}, "
33
+ f"current time: {datetime.now() - start_time}\n")
34
+ pipe = pipeline(args)
35
+ pipe.painterly_rendering(**pipe_args)
36
+
37
+
38
+ def main(args, seed_range):
39
+ args.batch_size = 1 # rendering one SVG at a time
40
+
41
+ args.width = float(args.width)
42
+
43
+ render_batch_fn = partial(render_batch_wrap, args=args, seed_range=seed_range)
44
+
45
+ if args.task == "diffsketcher": # text2sketch
46
+ from pipelines.painter.diffsketcher_pipeline import DiffSketcherPipeline
47
+
48
+ if not args.render_batch:
49
+ pipe = DiffSketcherPipeline(args)
50
+ pipe.painterly_rendering(args.prompt)
51
+ else: # generate many SVG at once
52
+ render_batch_fn(pipeline=DiffSketcherPipeline, prompt=args.prompt)
53
+
54
+ elif args.task == "style-diffsketcher": # text2sketch + style transfer
55
+ from pipelines.painter.diffsketcher_stylized_pipeline import StylizedDiffSketcherPipeline
56
+
57
+ if not args.render_batch:
58
+ pipe = StylizedDiffSketcherPipeline(args)
59
+ pipe.painterly_rendering(args.prompt, args.style_file)
60
+ else: # generate many SVG at once
61
+ render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=args.prompt, style_fpath=args.style_file)
62
+
63
+
64
+ if __name__ == '__main__':
65
+ parser = argparse.ArgumentParser(
66
+ description="vary style and content painterly rendering",
67
+ parents=[accelerate_parser(), base_data_parser()]
68
+ )
69
+ # flag
70
+ parser.add_argument("-tk", "--task",
71
+ default="diffsketcher", type=str,
72
+ choices=['diffsketcher', 'style-diffsketcher'],
73
+ help="choose a method.")
74
+ # config
75
+ parser.add_argument("-c", "--config",
76
+ required=True, type=str,
77
+ default="",
78
+ help="YAML/YML file for configuration.")
79
+ parser.add_argument("-style", "--style_file",
80
+ default="", type=str,
81
+ help="the path of style img place.")
82
+ # prompt
83
+ parser.add_argument("-pt", "--prompt", default="A horse is drinking water by the lake", type=str)
84
+ parser.add_argument("-npt", "--negative_prompt", default="", type=str)
85
+ # DiffSVG
86
+ parser.add_argument("--print_timing", "-timing", action="store_true",
87
+ help="set print svg rendering timing.")
88
+ # diffuser
89
+ parser.add_argument("--download", action="store_true",
90
+ help="download models from huggingface automatically.")
91
+ parser.add_argument("--force_download", "-download", action="store_true",
92
+ help="force the models to be downloaded from huggingface.")
93
+ parser.add_argument("--resume_download", "-dpm_resume", action="store_true",
94
+ help="download the models again from the breakpoint.")
95
+ # rendering quantity
96
+ # like: python main.py -rdbz -srange 100 200
97
+ parser.add_argument("--render_batch", "-rdbz", action="store_true")
98
+ parser.add_argument("-srange", "--seed_range",
99
+ required=False, nargs='+',
100
+ help="Sampling quantity.")
101
+ # visual rendering process
102
+ parser.add_argument("-mv", "--make_video", action="store_true",
103
+ help="make a video of the rendering process.")
104
+ parser.add_argument("-frame_freq", "--video_frame_freq",
105
+ default=1, type=int,
106
+ help="video frame control.")
107
+ parser.add_argument("-framerate", "--video_frame_rate",
108
+ default=36, type=int,
109
+ help="by adjusting the frame rate, you can control the playback speed of the output video.")
110
+
111
+ args = parser.parse_args()
112
+
113
+ # set the random seed range
114
+ seed_range = None
115
+ if args.render_batch:
116
+ # random sampling without specifying a range
117
+ start_, end_ = 1, 1000000
118
+ if args.seed_range is not None: # specify range sequential sampling
119
+ seed_range_ = list(args.seed_range)
120
+ assert len(seed_range_) == 2 and int(seed_range_[1]) > int(seed_range_[0])
121
+ start_, end_ = int(seed_range_[0]), int(seed_range_[1])
122
+ seed_range = [i for i in range(start_, end_)]
123
+ else:
124
+ # a list of lengths 1000 sampled from the range start_ to end_ (e.g.: [1, 1000000])
125
+ numbers = list(range(start_, end_))
126
+ seed_range = random.sample(numbers, k=1000)
127
+
128
+ args = merge_and_update_config(args)
129
+
130
+ set_seed(args.seed)
131
+ main(args, seed_range)