Commit
·
dcc4583
1
Parent(s):
7e8823d
move assignment of CHECKPOINTS_PATH to `app.py`
Browse files
- app.py +5 -3
- download_checkpoints.py +9 -0
- helper.py +61 -59
- main.py +16 -10
app.py
CHANGED
@@ -31,15 +31,17 @@ if not check_environment():
|
|
31 |
|
32 |
|
33 |
from download_checkpoints import main as download_checkpoints
|
34 |
-
|
|
|
35 |
|
36 |
# download checkpoints
|
37 |
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
|
38 |
download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
|
39 |
|
40 |
-
|
41 |
import gradio as gr
|
42 |
|
|
|
|
|
43 |
# Define the Gradio Blocks interface
|
44 |
with gr.Blocks() as demo:
|
45 |
gr.Markdown(
|
@@ -81,7 +83,7 @@ with gr.Blocks() as demo:
|
|
81 |
output_file = gr.File(label="Download Results")
|
82 |
|
83 |
generate_button.click(
|
84 |
-
fn=
|
85 |
inputs=[
|
86 |
rgb_video_input,
|
87 |
hdmap_input,
|
|
|
31 |
|
32 |
|
33 |
from download_checkpoints import main as download_checkpoints
|
34 |
+
|
35 |
+
CHECKPOINTS_PATH = "/data/checkpoints"
|
36 |
|
37 |
# download checkpoints
|
38 |
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
|
39 |
download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
|
40 |
|
|
|
41 |
import gradio as gr
|
42 |
|
43 |
+
from helper import generate_video_fun
|
44 |
+
|
45 |
# Define the Gradio Blocks interface
|
46 |
with gr.Blocks() as demo:
|
47 |
gr.Markdown(
|
|
|
83 |
output_file = gr.File(label="Download Results")
|
84 |
|
85 |
generate_button.click(
|
86 |
+
fn=generate_video_fun(CHECKPOINTS_PATH),
|
87 |
inputs=[
|
88 |
rgb_video_input,
|
89 |
hdmap_input,
|
download_checkpoints.py
CHANGED
@@ -131,3 +131,12 @@ def main(hf_token: str = os.environ.get("HF_TOKEN"), output_dir: str = "./checkp
|
|
131 |
# Download each checkpoint
|
132 |
for checkpoint in checkpoint_vars:
|
133 |
download_checkpoint(checkpoint, output_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
# Download each checkpoint
|
132 |
for checkpoint in checkpoint_vars:
|
133 |
download_checkpoint(checkpoint, output_dir)
|
134 |
+
|
135 |
+
|
136 |
+
if __name__ == "__main__":
    # Script entry point: download every checkpoint ("all") into a
    # `checkpoints` directory that sits next to this file.
    import sys

    PWD = os.path.dirname(__file__)
    CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
    os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

    # The Hugging Face token is taken from the first command-line argument.
    # When it is omitted, fall back to the HF_TOKEN environment variable (the
    # same source `main` uses for its default) instead of crashing with an
    # IndexError on `sys.argv[1]`.
    hf_token = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("HF_TOKEN")
    main(hf_token=hf_token, output_dir=CHECKPOINTS_PATH, model="all")
|
helper.py
CHANGED
@@ -32,7 +32,6 @@ from cosmos_transfer1.utils.io import read_prompts_from_file, save_video
|
|
32 |
from gpu_info import stop_watcher, watch_gpu_memory
|
33 |
|
34 |
PWD = os.path.dirname(__file__)
|
35 |
-
CHECKPOINTS_PATH = "/data/checkpoints"
|
36 |
LOG_DIR = os.path.join(PWD, "logs")
|
37 |
os.makedirs(LOG_DIR, exist_ok=True)
|
38 |
|
@@ -370,68 +369,71 @@ def create_zip_for_download(filename, files_to_zip):
|
|
370 |
import gradio as gr
|
371 |
|
372 |
|
373 |
-
def
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
)
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
|
402 |
-
|
403 |
-
|
404 |
-
controlnet_specs_in={
|
405 |
-
"hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
|
406 |
-
"lidar": {"control_weight": 0.7, "input_control": lidar_video_input},
|
407 |
-
},
|
408 |
-
input_video_path=rgb_video_path,
|
409 |
-
checkpoint_dir=CHECKPOINTS_PATH,
|
410 |
-
prompt=prompt,
|
411 |
-
negative_prompt=negative_prompt,
|
412 |
-
sigma_max=80,
|
413 |
-
offload_text_encoder_model=True,
|
414 |
-
is_av_sample=True,
|
415 |
-
num_gpus=1,
|
416 |
-
seed=seed,
|
417 |
-
)
|
418 |
|
419 |
-
|
420 |
-
|
|
|
|
|
421 |
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
videos, prompts = inference(args, control_inputs, chunking)
|
426 |
|
427 |
-
|
428 |
-
|
429 |
-
log.info(f"Time taken: {end_time - start_time} s")
|
430 |
|
431 |
-
|
432 |
-
stop_watcher()
|
433 |
|
434 |
-
|
|
|
435 |
|
436 |
-
|
437 |
-
return video, create_zip_for_download(filename=logfile_path, files_to_zip=[video, logfile_path]), actual_seed
|
|
|
32 |
from gpu_info import stop_watcher, watch_gpu_memory
|
33 |
|
34 |
PWD = os.path.dirname(__file__)
|
|
|
35 |
LOG_DIR = os.path.join(PWD, "logs")
|
36 |
os.makedirs(LOG_DIR, exist_ok=True)
|
37 |
|
|
|
369 |
import gradio as gr
|
370 |
|
371 |
|
372 |
+
def generate_video_fun(checkpoints_path: str):
    """Build a video-generation callback bound to a checkpoint directory.

    Args:
        checkpoints_path: Directory passed to `parse_arguments` as
            `checkpoint_dir` on every call of the returned function.

    Returns:
        The `generate_video` closure, usable as a Gradio event handler or
        callable directly from a script.
    """

    def generate_video(
        rgb_video_path,
        hdmap_video_input,
        lidar_video_input,
        prompt,
        negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.",  # noqa: E501
        seed=42,
        randomize_seed=False,
        chunking=None,
        progress=gr.Progress(track_tqdm=True),
    ):
        """Run one generation and return the video, a zip bundle and the seed.

        Args:
            rgb_video_path: Path of the RGB input video. A missing or
                non-existent file is replaced by an empty string after a
                warning is logged.
            hdmap_video_input: Input control for the "hdmap" controlnet
                (fixed control weight 0.3).
            lidar_video_input: Input control for the "lidar" controlnet
                (fixed control weight 0.7).
            prompt: Text prompt forwarded to `parse_arguments`.
            negative_prompt: Negative prompt forwarded to `parse_arguments`.
            seed: Seed used when `randomize_seed` is false.
            randomize_seed: When true, draw a random seed in [0, 1000000].
            chunking: Forwarded to `inference`; `None` or a non-positive
                value disables it (passed on as `None`).
            progress: Not referenced in the body; presumably lets Gradio
                track tqdm progress - confirm against the Gradio docs.

        Returns:
            A tuple `(video, zip_result, actual_seed)` where `video` is the
            first item returned by `inference`, `zip_result` is the value of
            `create_zip_for_download` for the video and the log file, and
            `actual_seed` is the seed that was used.
        """
        # Each call writes to its own log file, named by a UTC+8 timestamp.
        _dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
        logfile_path = os.path.join(LOG_DIR, f"{_dt}.log")
        log_handler = log.init_dev_loguru_file(logfile_path)

        try:
            actual_seed = random.randint(0, 1000000) if randomize_seed else seed
            log.info(f"actual_seed: {actual_seed}")

            if rgb_video_path is None or not os.path.isfile(rgb_video_path):
                log.warning(f"File `{rgb_video_path}` does not exist")
                rgb_video_path = ""

            # add timer to calculate the generation time
            start_time = time.time()

            # parse generation configs
            args, control_inputs = parse_arguments(
                controlnet_specs_in={
                    "hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
                    "lidar": {"control_weight": 0.7, "input_control": lidar_video_input},
                },
                input_video_path=rgb_video_path,
                checkpoint_dir=checkpoints_path,
                prompt=prompt,
                negative_prompt=negative_prompt,
                sigma_max=80,
                offload_text_encoder_model=True,
                is_av_sample=True,
                num_gpus=1,
                # Use the seed that is logged and returned, so a randomized
                # run can be reproduced from the reported value.
                seed=actual_seed,
            )

            # watch gpu memory
            watch_gpu_memory(10, lambda x: log.debug(f"GPU memory (used, total): {x} (MiB)"))
            try:
                # start inference; `chunking` may legitimately be None (its
                # default), which must not be compared with an integer.
                if chunking is not None and chunking <= 0:
                    chunking = None
                videos, _ = inference(args, control_inputs, chunking)

                # print the generation time
                end_time = time.time()
                log.info(f"Time taken: {end_time - start_time} s")
            finally:
                # stop the watcher even when inference fails
                stop_watcher()

            video = videos[0]
        finally:
            # Detach the per-call log handler on success and on failure, and
            # before the log file is added to the zip bundle.
            log.logger.remove(log_handler)

        return video, create_zip_for_download(filename=logfile_path, files_to_zip=[video, logfile_path]), actual_seed

    return generate_video
|
|
main.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
-
from argparse import ArgumentParser
|
2 |
import os
|
3 |
import sys
|
4 |
-
|
5 |
|
6 |
if __name__ == "__main__":
|
7 |
parser = ArgumentParser()
|
8 |
-
parser.add_argument("--
|
|
|
9 |
parser.add_argument("--rgb", type=str, default=None, help="RGB video input")
|
10 |
parser.add_argument("--hdmap", type=str, required=True, help="HD Map video input")
|
11 |
parser.add_argument("--lidar", type=str, required=True, help="Lidar video input")
|
@@ -31,22 +31,28 @@ if __name__ == "__main__":
|
|
31 |
os.environ["LD_LIBRARY_PATH"] = "$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
|
32 |
os.environ["PATH"] = "$CUDA_HOME/bin:/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:$PATH"
|
33 |
|
34 |
-
if
|
35 |
from test_environment import main as check_environment
|
36 |
|
37 |
if not check_environment():
|
38 |
sys.exit(1)
|
39 |
else:
|
40 |
-
print("Skipping environment
|
41 |
|
42 |
-
from
|
43 |
-
|
|
|
44 |
|
45 |
# download checkpoints
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
|
50 |
rgb_video_path=args.rgb,
|
51 |
hdmap_video_input=args.hdmap,
|
52 |
lidar_video_input=args.lidar,
|
|
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
+
from argparse import ArgumentParser
|
4 |
|
5 |
if __name__ == "__main__":
|
6 |
parser = ArgumentParser()
|
7 |
+
parser.add_argument("--check-environment", action="store_true", help="Whether to do the environment checking")
|
8 |
+
parser.add_argument("--download-checkpoints", action="store_true", help="Whether to download the checkpoints")
|
9 |
parser.add_argument("--rgb", type=str, default=None, help="RGB video input")
|
10 |
parser.add_argument("--hdmap", type=str, required=True, help="HD Map video input")
|
11 |
parser.add_argument("--lidar", type=str, required=True, help="Lidar video input")
|
|
|
31 |
os.environ["LD_LIBRARY_PATH"] = "$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
|
32 |
os.environ["PATH"] = "$CUDA_HOME/bin:/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:$PATH"
|
33 |
|
34 |
+
if args.check_environment:
|
35 |
from test_environment import main as check_environment
|
36 |
|
37 |
if not check_environment():
|
38 |
sys.exit(1)
|
39 |
else:
|
40 |
+
print("Skipping the environment checking")
|
41 |
|
42 |
+
from helper import PWD
|
43 |
+
|
44 |
+
CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
|
45 |
|
46 |
# download checkpoints
|
47 |
+
if args.download_checkpoints:
|
48 |
+
from download_checkpoints import main as download_checkpoints
|
49 |
+
|
50 |
+
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
|
51 |
+
download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
|
52 |
+
|
53 |
+
from helper import generate_video_fun
|
54 |
|
55 |
+
generate_video_fun(CHECKPOINTS_PATH)(
|
56 |
rgb_video_path=args.rgb,
|
57 |
hdmap_video_input=args.hdmap,
|
58 |
lidar_video_input=args.lidar,
|