harry900000 commited on
Commit
dcc4583
·
1 Parent(s): 7e8823d

move assignment of CHECKPOINTS_PATH to `app.py`

Browse files
Files changed (4) hide show
  1. app.py +5 -3
  2. download_checkpoints.py +9 -0
  3. helper.py +61 -59
  4. main.py +16 -10
app.py CHANGED
@@ -31,15 +31,17 @@ if not check_environment():
31
 
32
 
33
  from download_checkpoints import main as download_checkpoints
34
- from helper import CHECKPOINTS_PATH, generate_video
 
35
 
36
  # download checkpoints
37
  os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
38
  download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
39
 
40
-
41
  import gradio as gr
42
 
 
 
43
  # Define the Gradio Blocks interface
44
  with gr.Blocks() as demo:
45
  gr.Markdown(
@@ -81,7 +83,7 @@ with gr.Blocks() as demo:
81
  output_file = gr.File(label="Download Results")
82
 
83
  generate_button.click(
84
- fn=generate_video,
85
  inputs=[
86
  rgb_video_input,
87
  hdmap_input,
 
31
 
32
 
33
  from download_checkpoints import main as download_checkpoints
34
+
35
+ CHECKPOINTS_PATH = "/data/checkpoints"
36
 
37
  # download checkpoints
38
  os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
39
  download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
40
 
 
41
  import gradio as gr
42
 
43
+ from helper import generate_video_fun
44
+
45
  # Define the Gradio Blocks interface
46
  with gr.Blocks() as demo:
47
  gr.Markdown(
 
83
  output_file = gr.File(label="Download Results")
84
 
85
  generate_button.click(
86
+ fn=generate_video_fun(CHECKPOINTS_PATH),
87
  inputs=[
88
  rgb_video_input,
89
  hdmap_input,
download_checkpoints.py CHANGED
@@ -131,3 +131,12 @@ def main(hf_token: str = os.environ.get("HF_TOKEN"), output_dir: str = "./checkp
131
  # Download each checkpoint
132
  for checkpoint in checkpoint_vars:
133
  download_checkpoint(checkpoint, output_dir)
 
 
 
 
 
 
 
 
 
 
131
  # Download each checkpoint
132
  for checkpoint in checkpoint_vars:
133
  download_checkpoint(checkpoint, output_dir)
134
+
135
+
136
+ if __name__ == "__main__":
137
+ import sys
138
+
139
+ PWD = os.path.dirname(__file__)
140
+ CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
141
+ os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
142
+ main(hf_token=sys.argv[1], output_dir=CHECKPOINTS_PATH, model="all")
helper.py CHANGED
@@ -32,7 +32,6 @@ from cosmos_transfer1.utils.io import read_prompts_from_file, save_video
32
  from gpu_info import stop_watcher, watch_gpu_memory
33
 
34
  PWD = os.path.dirname(__file__)
35
- CHECKPOINTS_PATH = "/data/checkpoints"
36
  LOG_DIR = os.path.join(PWD, "logs")
37
  os.makedirs(LOG_DIR, exist_ok=True)
38
 
@@ -370,68 +369,71 @@ def create_zip_for_download(filename, files_to_zip):
370
  import gradio as gr
371
 
372
 
373
- def generate_video(
374
- rgb_video_path,
375
- hdmap_video_input,
376
- lidar_video_input,
377
- prompt,
378
- negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.", # noqa: E501
379
- seed=42,
380
- randomize_seed=False,
381
- chunking=None,
382
- progress=gr.Progress(track_tqdm=True),
383
- ):
384
- _dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
385
- logfile_path = os.path.join(LOG_DIR, f"{_dt}.log")
386
- log_handler = log.init_dev_loguru_file(logfile_path)
387
-
388
- if randomize_seed:
389
- actual_seed = random.randint(0, 1000000)
390
- else:
391
- actual_seed = seed
392
-
393
- log.info(f"actual_seed: {actual_seed}")
394
-
395
- if rgb_video_path is None or not os.path.isfile(rgb_video_path):
396
- log.warning(f"File `{rgb_video_path}` does not exist")
397
- rgb_video_path = ""
398
-
399
- # add timer to calculate the generation time
400
- start_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
- # parse generation configs
403
- args, control_inputs = parse_arguments(
404
- controlnet_specs_in={
405
- "hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
406
- "lidar": {"control_weight": 0.7, "input_control": lidar_video_input},
407
- },
408
- input_video_path=rgb_video_path,
409
- checkpoint_dir=CHECKPOINTS_PATH,
410
- prompt=prompt,
411
- negative_prompt=negative_prompt,
412
- sigma_max=80,
413
- offload_text_encoder_model=True,
414
- is_av_sample=True,
415
- num_gpus=1,
416
- seed=seed,
417
- )
418
 
419
- # watch gpu memory
420
- watcher = watch_gpu_memory(10, lambda x: log.debug(f"GPU memory (used, total): {x} (MiB)"))
 
 
421
 
422
- # start inference
423
- if chunking <= 0:
424
- chunking = None
425
- videos, prompts = inference(args, control_inputs, chunking)
426
 
427
- # print the generation time
428
- end_time = time.time()
429
- log.info(f"Time taken: {end_time - start_time} s")
430
 
431
- # stop the watcher
432
- stop_watcher()
433
 
434
- video = videos[0]
 
435
 
436
- log.logger.remove(log_handler)
437
- return video, create_zip_for_download(filename=logfile_path, files_to_zip=[video, logfile_path]), actual_seed
 
32
  from gpu_info import stop_watcher, watch_gpu_memory
33
 
34
  PWD = os.path.dirname(__file__)
 
35
  LOG_DIR = os.path.join(PWD, "logs")
36
  os.makedirs(LOG_DIR, exist_ok=True)
37
 
 
369
  import gradio as gr
370
 
371
 
372
+ def generate_video_fun(checkpoints_path: str):
373
+ def generate_video(
374
+ rgb_video_path,
375
+ hdmap_video_input,
376
+ lidar_video_input,
377
+ prompt,
378
+ negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.", # noqa: E501
379
+ seed=42,
380
+ randomize_seed=False,
381
+ chunking=None,
382
+ progress=gr.Progress(track_tqdm=True),
383
+ ):
384
+ _dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
385
+ logfile_path = os.path.join(LOG_DIR, f"{_dt}.log")
386
+ log_handler = log.init_dev_loguru_file(logfile_path)
387
+
388
+ if randomize_seed:
389
+ actual_seed = random.randint(0, 1000000)
390
+ else:
391
+ actual_seed = seed
392
+
393
+ log.info(f"actual_seed: {actual_seed}")
394
+
395
+ if rgb_video_path is None or not os.path.isfile(rgb_video_path):
396
+ log.warning(f"File `{rgb_video_path}` does not exist")
397
+ rgb_video_path = ""
398
+
399
+ # add timer to calculate the generation time
400
+ start_time = time.time()
401
+
402
+ # parse generation configs
403
+ args, control_inputs = parse_arguments(
404
+ controlnet_specs_in={
405
+ "hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
406
+ "lidar": {"control_weight": 0.7, "input_control": lidar_video_input},
407
+ },
408
+ input_video_path=rgb_video_path,
409
+ checkpoint_dir=checkpoints_path,
410
+ prompt=prompt,
411
+ negative_prompt=negative_prompt,
412
+ sigma_max=80,
413
+ offload_text_encoder_model=True,
414
+ is_av_sample=True,
415
+ num_gpus=1,
416
+ seed=seed,
417
+ )
418
 
419
+ # watch gpu memory
420
+ watcher = watch_gpu_memory(10, lambda x: log.debug(f"GPU memory (used, total): {x} (MiB)"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
422
+ # start inference
423
+ if chunking <= 0:
424
+ chunking = None
425
+ videos, prompts = inference(args, control_inputs, chunking)
426
 
427
+ # print the generation time
428
+ end_time = time.time()
429
+ log.info(f"Time taken: {end_time - start_time} s")
 
430
 
431
+ # stop the watcher
432
+ stop_watcher()
 
433
 
434
+ video = videos[0]
 
435
 
436
+ log.logger.remove(log_handler)
437
+ return video, create_zip_for_download(filename=logfile_path, files_to_zip=[video, logfile_path]), actual_seed
438
 
439
+ return generate_video
 
main.py CHANGED
@@ -1,11 +1,11 @@
1
- from argparse import ArgumentParser
2
  import os
3
  import sys
4
-
5
 
6
  if __name__ == "__main__":
7
  parser = ArgumentParser()
8
- parser.add_argument("--skip-check-environment", action="store_true", help="Whether to skip environment checking")
 
9
  parser.add_argument("--rgb", type=str, default=None, help="RGB video input")
10
  parser.add_argument("--hdmap", type=str, required=True, help="HD Map video input")
11
  parser.add_argument("--lidar", type=str, required=True, help="Lidar video input")
@@ -31,22 +31,28 @@ if __name__ == "__main__":
31
  os.environ["LD_LIBRARY_PATH"] = "$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
32
  os.environ["PATH"] = "$CUDA_HOME/bin:/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:$PATH"
33
 
34
- if not args.skip_check_environment:
35
  from test_environment import main as check_environment
36
 
37
  if not check_environment():
38
  sys.exit(1)
39
  else:
40
- print("Skipping environment check")
41
 
42
- from download_checkpoints import main as download_checkpoints
43
- from helper import CHECKPOINTS_PATH, generate_video
 
44
 
45
  # download checkpoints
46
- os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
47
- download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
 
 
 
 
 
48
 
49
- generate_video(
50
  rgb_video_path=args.rgb,
51
  hdmap_video_input=args.hdmap,
52
  lidar_video_input=args.lidar,
 
 
1
  import os
2
  import sys
3
+ from argparse import ArgumentParser
4
 
5
  if __name__ == "__main__":
6
  parser = ArgumentParser()
7
+ parser.add_argument("--check-environment", action="store_true", help="Whether to do the environment checking")
8
+ parser.add_argument("--download-checkpoints", action="store_true", help="Whether to download the checkpoints")
9
  parser.add_argument("--rgb", type=str, default=None, help="RGB video input")
10
  parser.add_argument("--hdmap", type=str, required=True, help="HD Map video input")
11
  parser.add_argument("--lidar", type=str, required=True, help="Lidar video input")
 
31
  os.environ["LD_LIBRARY_PATH"] = "$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
32
  os.environ["PATH"] = "$CUDA_HOME/bin:/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:$PATH"
33
 
34
+ if args.check_environment:
35
  from test_environment import main as check_environment
36
 
37
  if not check_environment():
38
  sys.exit(1)
39
  else:
40
+ print("Skipping the environment checking")
41
 
42
+ from helper import PWD
43
+
44
+ CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
45
 
46
  # download checkpoints
47
+ if args.download_checkpoints:
48
+ from download_checkpoints import main as download_checkpoints
49
+
50
+ os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
51
+ download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
52
+
53
+ from helper import generate_video_fun
54
 
55
+ generate_video_fun(CHECKPOINTS_PATH)(
56
  rgb_video_path=args.rgb,
57
  hdmap_video_input=args.hdmap,
58
  lidar_video_input=args.lidar,