import argparse
import logging
import math
import os
import random
import shutil
from datetime import timedelta
from pathlib import Path, PosixPath
from typing import List, Optional, Tuple, Union

from PIL import Image
from diffusers.utils import (
    check_min_version,
    convert_unet_state_dict_to_peft,
    export_to_video,
    is_wandb_available,
    load_image,
)
from torchvision.transforms import ToPILImage
import torch

from utils.utils import load_model_from_config, load_segmented_safe_weights, control_weight_files
from models.cogvideox_transformer_3d_control import Control3DModel, Controled_CogVideoXTransformer3DModel
from models.pipeline_cogvideox_image2video import (
    Controled_CogVideoXImageToVideoPipeline,
    Controled_Memory_CogVideoXImageToVideoPipeline,
)
from models.global_local_memory_module import global_local_memory

import diffusers
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    # CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)
from lineart_extractor.annotator.lineart import LineartDetector
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import cast_training_params, free_memory
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
import torchvision.transforms as TT
import numpy as np

from videoxl.model.builder import load_pretrained_model
from videoxl.mm_utils import tokenizer_image_token, process_images, transform_input_id
from videoxl.constants import IMAGE_TOKEN_INDEX, TOKEN_PERFRAME

try:
    import decord
except ImportError:
    raise ImportError(
        "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
    )
decord.bridge.set_bridge("torch")

from utils.autoreg_video_save_function import autoreg_video_save
from decord import VideoReader, cpu
from einops import rearrange
import gc


def _resize_for_rectangle_crop(arr, height, width, video_reshape_mode):
    image_size = height, width
    reshape_mode = video_reshape_mode
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    image_size = height, width
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr


def get_frame_length(frame_path):
    video_reader = decord.VideoReader(uri=frame_path.as_posix())
    video_num_frames = len(video_reader)
    return video_num_frames


# NOTE: the two helpers below read the module-level `args` (parsed in __main__) for the target resolution.
def proccess_frame(frame_path, frames_start, frames_end):
    video_reader = decord.VideoReader(uri=frame_path.as_posix())
    video_num_frames = len(video_reader)

    start_frame = frames_start
    end_frame = frames_end
    indices = list(range(start_frame, end_frame))
    frames = video_reader.get_batch(indices)
    # frames = frames[start_frame: end_frame]
    selected_num_frames = frames.shape[0]
    print("selected_num_frames", selected_num_frames)

    # Choose first (4k + 1) frames as this is how many is required by the VAE
    remainder = (3 + (selected_num_frames % 4)) % 4
    if remainder != 0:
        frames = frames[:-remainder]
    selected_num_frames = frames.shape[0]
    assert (selected_num_frames - 1) % 4 == 0

    # Training transforms
    frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
    frames = _resize_for_rectangle_crop(frames, height=args.height, width=args.width, video_reshape_mode="center")
    final_frames = frames.contiguous()
    return final_frames


def proccess_image(frames):
    # Training transforms
    frames = frames.unsqueeze(0).permute(0, 3, 1, 2)  # [F, C, H, W]
    frames = _resize_for_rectangle_crop(frames, height=args.height, width=args.width, video_reshape_mode="center")
    final_frames = frames.contiguous()
    return final_frames


def encode_sketch(video, pipe):
    video = video.to(pipe.vae.device, dtype=pipe.vae.dtype).unsqueeze(0)
    video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    latent_dist = pipe.vae.encode(video).latent_dist
    return latent_dist


def process_sketch(sketch, linear_detector, pipe):
    sketch = sketch.to("cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        sketch = linear_detector(sketch, coarse=False)
    # Binarize the detected line art, invert it, and normalize to [-1, 1] for the VAE.
    sketch = (sketch > 0.78).float()
    sketch = 1 - sketch
    sketch = sketch.repeat(1, 3, 1, 1)
    sketch = (sketch - 0.5) / 0.5
    sketch = sketch.contiguous()
    sketch = sketch.to(pipe.vae.device, dtype=pipe.vae.dtype).unsqueeze(0)
    sketch = sketch.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    image = sketch[:, :, :1].clone()
    with torch.no_grad():
        sketch = pipe.vae.encode(sketch).latent_dist
        sketches_first_frame = pipe.vae.encode(image).latent_dist
    sketch = sketch.sample() * pipe.vae.config.scaling_factor
    sketches_first_frame = sketches_first_frame.sample() * pipe.vae.config.scaling_factor
    sketch = sketch.permute(0, 2, 1, 3, 4)
    sketch = sketch.to(memory_format=torch.contiguous_format)
    sketches_first_frame = sketches_first_frame.permute(0, 2, 1, 3, 4)
    sketches_first_frame = sketches_first_frame.to(memory_format=torch.contiguous_format)
    return sketch, sketches_first_frame


def process_sketch_image(sketch, linear_detector, pipe):
    sketch = torch.tensor(np.array(sketch))
    sketch = proccess_image(sketch)
    sketch = sketch.to("cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        sketch = linear_detector(sketch, coarse=False)
    sketch = (sketch > 0.78).float()
    sketch = 1 - sketch
    sketch = sketch.repeat(1, 3, 1, 1)
    sketch = (sketch - 0.5) / 0.5
    sketch = sketch.contiguous()
    sketch = sketch.to(pipe.vae.device, dtype=pipe.vae.dtype).unsqueeze(0)
    sketch = sketch.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    with torch.no_grad():
        sketch = pipe.vae.encode(sketch).latent_dist
    sketch = sketch.sample() * pipe.vae.config.scaling_factor
    sketch = sketch.permute(0, 2, 1, 3, 4)
    sketch = sketch.to(memory_format=torch.contiguous_format)
    return sketch


def log_validation(
    pipe,
    args,
    pipeline_args,
    device,
    use_glm=False,
    global_memory=None,
    local_memory=None,
    glm=None,
    past_latents=None,
):
    scheduler_args = {}
    idx = pipeline_args.pop("segment", None)
    video_key = pipeline_args.pop("video_key", None)
    # The first segment (idx == 0) is generated without memory clipping.
    clip_memory = False if idx == 0 else True
    print("clip_memory", clip_memory)

    if "variance_type" in pipe.scheduler.config:
        variance_type = pipe.scheduler.config.variance_type
        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"
        scheduler_args["variance_type"] = variance_type

    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
    pipe = pipe.to(device)

    generator = torch.Generator(device=device).manual_seed(args.seed) if args.seed else None

    videos = []
    os.makedirs(os.path.join(args.output_dir, video_key), exist_ok=True)
    video_tensor_path = os.path.join(args.output_dir, video_key)
    print(video_tensor_path, "video_tensor_path")

    with torch.no_grad():
        for _ in range(args.num_validation_videos):
            frames_output, past_latents = pipe(
                **pipeline_args,
                generator=generator,
                output_type="pt",
                num_inference_steps=50,
                use_glm=use_glm,
                global_memory=global_memory,
                local_memory=local_memory,
                glm=glm,
                video_tensor_path=video_tensor_path,
                past_latents=past_latents[:, -4:-2] if (past_latents is not None) else None,
                clip_memory=clip_memory,
            )
            pt_images = frames_output.frames[0]  # TODO here we can choose if we need the first frame or not
            pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
            image_np = VaeImageProcessor.pt_to_numpy(pt_images)
            image_pil = VaeImageProcessor.numpy_to_pil(image_np)
            videos.append(image_pil)

    phase_name = f"inference_{idx}"
    video_filenames = []
    for i, video in enumerate(videos):
        final_output_dir = os.path.join(args.output_dir, video_key)
        os.makedirs(final_output_dir, exist_ok=True)
        filename = os.path.join(final_output_dir, f"{phase_name}_video.mp4")
        export_to_video(video, filename, fps=args.fps)
        video_filenames.append(filename)
        autoreg_video_save(base_path=final_output_dir, suffix="inference_{}_video.mp4", num_videos=idx + 1)

    # Return frame 65 of the first generated video.
    return videos[0][65]


def save_segments(total_frames, segment_length, overlap):
    # Split the video into fixed-length segments that overlap by `overlap` frames,
    # e.g. total_frames=200, segment_length=81, overlap=16 -> [(0, 81), (65, 146)].
    start_frame = 0
    segments = []
    while start_frame + segment_length <= total_frames:
        end_frame = start_frame + segment_length
        segments.append((start_frame, end_frame))
        start_frame = end_frame - overlap
    return segments


def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    load_dtype = torch.bfloat16

    transformer = Controled_CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
    )
    control_config_path = "model_json/control_model_15_small.json"
"model_json/control_model_15_small.json" transformer_control_config = load_model_from_config(control_config_path) transformer_control = Control3DModel(**transformer_control_config) control_weight_files=[args.control_weght] transformer_control = load_segmented_safe_weights(transformer_control, control_weight_files) transformer_control = transformer_control.to(load_dtype) linear_detector=LineartDetector("cuda", dtype=torch.bfloat16) gen_kwargs = {"do_sample": True, "temperature": 1, "top_p": None, "num_beams": 1, "use_cache": True, "max_new_tokens": 2} # try: # video_tokenizer, video_model, clip_image_processor, _ = load_pretrained_model(args.llm_model_path, None, "llava_qwen", device_map="cuda",attn_implementation="flash_attention_2") # except: video_tokenizer, video_model, clip_image_processor, _ = load_pretrained_model(args.llm_model_path, None, "llava_qwen", device_map="cuda",attn_implementation="sdpa") video_model.config.beacon_ratio=[8] # you can delete this line to realize random compression of {2,4,8} ratio vllm_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nCan you describe the scene and color in anime?<|im_end|>\n<|im_start|>assistant\n" input_ids = tokenizer_image_token(vllm_prompt, video_tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(video_model.device) video_model.to( dtype=torch.bfloat16) glm=global_local_memory() glm_weight_files=[args.glm_weight] glm = load_segmented_safe_weights(glm,glm_weight_files) glm=glm.to(load_dtype) glm=glm.to("cuda") print("successful load glm") pipe = Controled_Memory_CogVideoXImageToVideoPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, transformer=transformer, transformer_control=transformer_control ).to("cuda") pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) del transformer,transformer_control gc.collect() torch.cuda.empty_cache() #pipe.enable_sequential_cpu_offload() if args.enable_slicing: pipe.vae.enable_slicing() if args.enable_tiling: pipe.vae.enable_tiling() #pipe = pipe.to("cuda") import json with open('test_json/long_testset.json',"r") as json_file: video_info=json.load(json_file) for video_key,value in video_info.items(): print('------------') print(video_key) validation_prompt=value['prompt'] video_path=PosixPath(value['video_path']) reference_image_path=str(value["reference_image"]) use_glm=False i=0 global_image=None frame_path = video_path video_num_frames=get_frame_length(frame_path) segments=save_segments(total_frames=video_num_frames,segment_length=args.max_num_frames,overlap=16) print(segments) '''''' past_latents=None for seg_idx,segment in enumerate(segments): print(seg_idx) print(segment) videos = proccess_frame(frame_path, frames_start=segment[0], frames_end=segment[1]) #print(segment) sketches,sketches_first_frame = process_sketch(videos,linear_detector,pipe) torch.cuda.empty_cache() print("sketches!!!",sketches.shape) validation_prompt = validation_prompt+" High quality, masterpiece, best quality, highres, ultra-detailed, fantastic." 
            to_pil = ToPILImage()
            if global_image is None:
                print("------------------")
                print(reference_image_path)
                print("------------------")
                if reference_image_path != "0":
                    image = Image.open(reference_image_path).convert("RGB")
                    global_image = image
                    sketches_first_frame = process_sketch_image(global_image, linear_detector, pipe)
                else:
                    image = to_pil(videos[0]).convert("RGB")
                    global_image = image
                    sketches_first_frame = process_sketch_image(global_image, linear_detector, pipe)
            else:
                image = global_image

            pipeline_args = {
                "image": image,
                "prompt": validation_prompt,
                "guidance_scale": args.guidance_scale,
                "use_dynamic_cfg": args.use_dynamic_cfg,
                "height": args.height,
                "width": args.width,
                "sketches": sketches,
                "sketches_first_frame": sketches_first_frame,
                "num_frames": args.max_num_frames,
                "segment": seg_idx,
                "video_key": video_key,
            }

            # load the video and process the video
            # Build global (all sampled frames) and local (last 20 frames) KV-cache memories
            # from the previously generated segments using the Video-XL model.
            if use_glm:
                auto_path = os.path.join(args.output_dir, video_key, "autoreg_video_1.mp4")
                vr = VideoReader(auto_path, ctx=cpu(0))
                total_frame_num = len(vr)
                if total_frame_num > 650:
                    max_frame = 650
                else:
                    max_frame = total_frame_num
                uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frame, dtype=int)
                frame_idx = uniform_sampled_frames.tolist()
                frames = vr.get_batch(frame_idx).numpy()
                print(frames.shape)
                global_videos = clip_image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(
                    video_model.device, dtype=torch.bfloat16
                )
                local_videos = global_videos[-20:]
                beacon_skip_first = (input_ids == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[1].item()
                with torch.inference_mode():
                    num_tokens = TOKEN_PERFRAME * global_videos.shape[0]
                    beacon_skip_last = beacon_skip_first + num_tokens
                    video_model.generate(input_ids, images=[global_videos], modalities=["video"], beacon_skip_first=beacon_skip_first, beacon_skip_last=beacon_skip_last, **gen_kwargs)
                    indices = [-9, -5, -1]
                    global_memory = torch.cat(
                        [
                            torch.cat([rearrange(video_model.past_key_values[i][0], "b c h w -> b h (c w)") for i in indices], dim=0).unsqueeze(0),
                            torch.cat([rearrange(video_model.past_key_values[i][1], "b c h w -> b h (c w)") for i in indices], dim=0).unsqueeze(0),
                        ],
                        dim=0,
                    ).unsqueeze(0)
                    video_model.clear_past_key_values()
                    video_model.memory.reset()
                    print(global_memory.shape)
                    torch.cuda.empty_cache()

                    num_tokens = TOKEN_PERFRAME * local_videos.shape[0]
                    beacon_skip_last = beacon_skip_first + num_tokens
                    video_model.generate(input_ids, images=[local_videos], modalities=["video"], beacon_skip_first=beacon_skip_first, beacon_skip_last=beacon_skip_last, **gen_kwargs)
                    indices = [-9, -5, -1]
                    local_memory = torch.cat(
                        [
                            torch.cat([rearrange(video_model.past_key_values[i][0], "b c h w -> b h (c w)") for i in indices], dim=0).unsqueeze(0),
                            torch.cat([rearrange(video_model.past_key_values[i][1], "b c h w -> b h (c w)") for i in indices], dim=0).unsqueeze(0),
                        ],
                        dim=0,
                    ).unsqueeze(0)
                    video_model.clear_past_key_values()
                    video_model.memory.reset()
                    del global_videos, local_videos
                    torch.cuda.empty_cache()
                    print(local_memory.shape)
            else:
                global_memory = None
                local_memory = None

            last_image = log_validation(
                pipe=pipe,
                args=args,
                pipeline_args=pipeline_args,
                device="cuda",
                use_glm=use_glm,
                global_memory=global_memory,
                local_memory=local_memory,
                glm=glm,
                past_latents=past_latents,
            )
            torch.cuda.empty_cache()
            # All segments after the first one condition on the memory of what was generated so far.
            use_glm = True


def get_args():
    parser = argparse.ArgumentParser(
        description="Inference script for controllable CogVideoX image-to-video generation with global/local memory."
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=6,
        help="The guidance scale to use while sampling validation videos.",
    )

    # Model information
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--llm_model_path",
        type=str,
        default=None,
        required=True,
        help="Path to the pretrained Video-XL (llava_qwen) model used to build the global/local memory.",
    )
    parser.add_argument(
        "--control_weght",
        type=str,
        default=None,
        required=True,
        help="Path to the control-branch (Control3DModel) weights.",
    )
    parser.add_argument(
        "--glm_weight",
        type=str,
        default=None,
        required=True,
        help="Path to the global_local_memory module weights.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        default=False,
        help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--num_validation_videos",
        type=int,
        default=1,
        help="Number of videos that should be generated during validation per `validation_prompt`.",
    )

    # Training information
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="cogvideox-i2v-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="All input videos are resized to this height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=720,
        help="All input videos are resized to this width.",
    )
    parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
    parser.add_argument(
        "--max_num_frames", type=int, default=81, help="All input videos will be truncated to this many frames."
    )
    parser.add_argument(
        "--enable_slicing",
        action="store_true",
        default=False,
        help="Whether or not to use VAE slicing for saving memory.",
    )
    parser.add_argument(
        "--enable_tiling",
        action="store_true",
        default=False,
        help="Whether or not to use VAE tiling for saving memory.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
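
# Example invocation (a sketch only; the script name and every path below are placeholders,
# not files shipped with this code, so substitute your own base model and checkpoints):
#
#   python infer_long_video.py \
#       --pretrained_model_name_or_path /path/to/CogVideoX-I2V \
#       --llm_model_path /path/to/VideoXL \
#       --control_weght /path/to/control_branch_weights \
#       --glm_weight /path/to/glm_weights \
#       --output_dir outputs/long_video \
#       --height 480 --width 720 --fps 8 --max_num_frames 81 \
#       --enable_slicing --enable_tiling
#
# Prompts, source video paths, and reference images are read from test_json/long_testset.json,
# which maps each video key to {"prompt": ..., "video_path": ..., "reference_image": ...}.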