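"""Image-to-multi-view inference with MV-Adapter on SDXL.

Loads an SDXL base model, attaches the MV-Adapter, and generates `num_views`
orthogonal views of a single reference image, with optional BiRefNet-based
background removal. See the bottom of the file for an example invocation.
"""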
import argparse
import logging

import numpy as np
import torch
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
from mvadapter.utils import (
    get_orthogonal_camera,
    get_plucker_embeds_from_cameras_ortho,
    make_image_grid,
)

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    # Load the VAE and UNet if provided
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    # Prepare the pipeline
    pipe: MVAdapterI2MVSDXLPipeline = MVAdapterI2MVSDXLPipeline.from_pretrained(
        base_model, **pipe_kwargs
    )

    # Load the scheduler if provided, wrapped with shifted SNR
    scheduler_class = None
    if scheduler == "ddpm":
        scheduler_class = DDPMScheduler
    elif scheduler == "lcm":
        scheduler_class = LCMScheduler
    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )
    pipe.init_custom_adapter(num_views=num_views)
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_i2mv_sdxl.safetensors"
    )

    pipe.to(device=device, dtype=dtype)
    pipe.cond_encoder.to(device=device, dtype=dtype)

    # Load LoRA weights if provided
    if lora_model is not None:
        model_, name_ = lora_model.rsplit("/", 1)
        pipe.load_lora_weights(model_, weight_name=name_)

    # Enable VAE slicing for lower memory usage
    pipe.enable_vae_slicing()

    return pipe
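
# Example (sketch, using the argparse defaults defined below; model downloads
# are not verified here):
#
#   pipe = prepare_pipeline(
#       base_model="stabilityai/stable-diffusion-xl-base-1.0",
#       vae_model="madebyollin/sdxl-vae-fp16-fix",
#       unet_model=None,
#       lora_model=None,
#       adapter_path="huanngzh/mv-adapter",
#       scheduler=None,
#       num_views=6,
#       device="cuda",
#       dtype=torch.float16,
#   )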
def remove_bg(image: Image.Image, net, transform, device, mask=None):
    """
    Makes the image background transparent, either from a provided mask or by
    predicting one with the segmentation network.

    Args:
        image (PIL.Image.Image): The input image.
        net: Pre-trained segmentation network; used only when `mask` is None.
        transform: Preprocessing transform for the network input.
        device: Device used for network inference.
        mask (PIL.Image.Image or np.ndarray, optional): Foreground mask, with
            values either in [0, 255] or binary 0/1. If None, a mask is
            predicted with `net`.

    Returns:
        PIL.Image.Image: The image with a transparent background.
    """
    if mask is None:
        # Predict a foreground mask with the segmentation network,
        # which expects a 3-channel input
        input_images = transform(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            preds = net(input_images)[-1].sigmoid().cpu()
        mask = transforms.ToPILImage()(preds[0].squeeze())
    elif isinstance(mask, np.ndarray):
        mask = np.squeeze(mask)  # Drop a trailing channel axis if present
        if mask.max() <= 1:  # Scale binary / normalized masks to 0-255
            mask = mask * 255
        mask = Image.fromarray(mask.astype(np.uint8))
    # Resize the mask to the image size and write it into the alpha channel
    mask = mask.resize(image.size, Image.LANCZOS)
    image.putalpha(mask)
    return image
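
# Example (sketch; file names are illustrative): apply a precomputed binary
# mask without running the network.
#
#   img = Image.open("input.png").convert("RGB")
#   fg = remove_bg(img, None, None, "cpu", mask=np.load("mask.npy"))
#   fg.save("foreground.png")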
def preprocess_image(image: Image.Image, height, width):
    # Convert to RGBA so the padded output below always has four channels
    image_np = np.array(image.convert("RGBA"))
    # Binary foreground mask from the alpha channel (all true for opaque images)
    alpha = image_np[..., 3] > 0
    H, W = alpha.shape

    # Get the bounding box of the foreground
    y, x = np.where(alpha)
    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
    image_center = image_np[y0:y1, x0:x1]

    # Resize so the longer side fills 90% of the target size
    H, W, _ = image_center.shape
    if H > W:
        W = int(W * (height * 0.9) / H)
        H = int(height * 0.9)
    else:
        H = int(H * (width * 0.9) / W)
        W = int(width * 0.9)
    image_center = np.array(Image.fromarray(image_center).resize((W, H)))

    # Pad to the target height and width, centering the crop
    start_h = (height - H) // 2
    start_w = (width - W) // 2
    padded_image = np.zeros((height, width, 4), dtype=np.uint8)
    padded_image[start_h : start_h + H, start_w : start_w + W] = image_center

    # Convert back to a PIL image
    return Image.fromarray(padded_image)
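
# For example, a 1000x500 (width x height) foreground crop targeted at 768x768
# is resized to 691x345 (longer side = 768 * 0.9) and centered in a
# transparent 768x768 canvas.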
def run_pipeline(
    pipe,
    num_views,
    text,
    image,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    remove_bg_fn=None,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    lora_scale=1.0,
    device="cuda",
):
    # Prepare orthographic cameras around the object.
    # Note: the elevation and azimuth lists below assume the default six views.
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 0, 0],
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
        device=device,
    )

    # Encode the cameras as Plucker ray maps, rescaled to [0, 1] as control images
    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
        cameras.c2w, [1.1] * num_views, width
    )
    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)
| # Prepare image | |
| # reference_image = Image.open(image) if isinstance(image, str) else image | |
| # if remove_bg_fn is not None: | |
| # reference_image = remove_bg_fn(reference_image) | |
| # reference_image = preprocess_image(reference_image, height, width) | |
| # elif reference_image.mode == "RGBA": | |
| # reference_image = preprocess_image(reference_image, height, width) | |
    # Prepare the reference image
    reference_image = Image.open(image) if isinstance(image, str) else image
    logging.info(f"Initial reference_image mode: {reference_image.mode}")
    if remove_bg_fn is not None:
        logging.info("Using remove_bg_fn")
        reference_image = remove_bg_fn(reference_image)
        reference_image = preprocess_image(reference_image, height, width)
    elif reference_image.mode == "RGBA":
        logging.info("Image is RGBA, preprocessing directly")
        reference_image = preprocess_image(reference_image, height, width)
    logging.info(f"Final reference_image mode: {reference_image.mode}")

    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference_image,
        reference_conditioning_scale=reference_conditioning_scale,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images

    return images, reference_image
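
# Example (sketch; assumes `pipe` from prepare_pipeline, an illustrative input
# path, and an arbitrary seed):
#
#   images, ref = run_pipeline(
#       pipe,
#       num_views=6,
#       text="high quality",
#       image="input.png",
#       height=768,
#       width=768,
#       num_inference_steps=50,
#       guidance_scale=3.0,
#       seed=42,
#   )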
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
    )
    parser.add_argument(
        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
    )
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument("--scheduler", type=str, default=None)
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)
    # Device
    parser.add_argument("--device", type=str, default="cuda")
    # Inference
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--text", type=str, default="high quality")
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    parser.add_argument("--output", type=str, default="output.png")
    # Extra
    parser.add_argument("--remove_bg", action="store_true", help="Remove background")
    args = parser.parse_args()

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )

    if args.remove_bg:
        # Load BiRefNet to predict foreground masks when none are supplied
        birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        birefnet.to(args.device)
        transform_image = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
    else:
        remove_bg_fn = None

    images, reference_image = run_pipeline(
        pipe,
        num_views=args.num_views,
        text=args.text,
        image=args.image,
        height=768,
        width=768,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        lora_scale=args.lora_scale,
        reference_conditioning_scale=args.reference_conditioning_scale,
        negative_prompt=args.negative_prompt,
        device=args.device,
        remove_bg_fn=remove_bg_fn,
    )
    make_image_grid(images, rows=1).save(args.output)
    reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")
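
# Example invocation (sketch; the script and input file names are illustrative):
#
#   python inference_i2mv_sdxl.py --image input.png --text "high quality" --remove_bg
#
# This writes the multi-view grid to output.png and the preprocessed reference
# image to output_reference.png.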