Spaces:
Configuration error
Configuration error
File size: 6,738 Bytes
4c35f0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import dataclasses
from typing import Literal
from dotenv import load_dotenv
load_dotenv()
from accelerate import Accelerator
from transformers import HfArgumentParser
from PIL import Image
import json
import itertools
import torch
from uso.flux.pipeline import USOPipeline, preprocess_ref
from transformers import SiglipVisionModel, SiglipImageProcessor
from tqdm import tqdm
def horizontal_concat(images):
    """Paste images side by side, left to right, onto a single RGB canvas.

    The canvas is as wide as the sum of the input widths and as tall as the
    tallest input. Every image is pasted with its top edge at y=0, so the
    area below a shorter image keeps the fill that ``Image.new`` applies when
    no color is given.

    Args:
        images: Iterable of PIL images. ``None`` entries are skipped so that
            callers can pass reference lists that use ``None`` as a
            placeholder for a missing image.

    Returns:
        A new RGB ``PIL.Image.Image`` containing every input image.

    Raises:
        ValueError: If ``images`` contains no image at all.
    """
    # Materialise once: the input is walked twice (sizes, then pasting), and a
    # one-shot iterator would otherwise be exhausted after the first pass.
    tiles = [img for img in images if img is not None]
    if not tiles:
        raise ValueError("horizontal_concat() needs at least one image")

    widths, heights = zip(*(img.size for img in tiles))
    canvas = Image.new("RGB", (sum(widths), max(heights)))

    x_offset = 0
    for img, width in zip(tiles, widths):
        canvas.paste(img, (x_offset, 0))
        x_offset += width
    return canvas
@dataclasses.dataclass
class InferenceArgs:
    """Command-line options for the inference script.

    The fields are parsed by ``HfArgumentParser`` in the ``__main__`` guard and
    consumed by ``main``. Either ``prompt`` or ``eval_json_path`` must be set.
    """

    # Text prompt for single-task mode (used when eval_json_path is not given).
    # main() also writes each task's prompt here before saving the config JSON.
    prompt: str | None = None
    # Reference image paths for single-task mode. main() passes the first entry
    # through preprocess_ref and the remaining entries to the SigLIP processor
    # (when use_siglip is set). An empty string marks a missing reference.
    image_paths: list[str] | None = None
    # JSON file holding a list of task dicts. main() reads "prompt",
    # "image_paths" and an optional "save_dir" from each entry and resolves
    # image paths relative to this file's directory.
    eval_json_path: str | None = None
    # Forwarded to USOPipeline as its third positional argument.
    offload: bool = False
    # Number of images generated per task; the j-th one uses seed + j.
    num_images_per_prompt: int = 1
    # Forwarded to USOPipeline as its first positional argument.
    model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
    # Output size passed to the pipeline. When instruct_edit is set, main()
    # replaces both with values derived from the first reference image.
    width: int = 1024
    height: int = 1024
    # Forwarded to the pipeline call as num_steps.
    num_steps: int = 25
    # Forwarded to the pipeline call as guidance.
    guidance: float = 4
    # Base seed; generation j of a task runs with seed + j.
    seed: int = 3407
    # Root directory for the generated PNG files and their config JSON files.
    save_path: str = "output/inference"
    # Forwarded to USOPipeline as only_lora.
    only_lora: bool = True
    # When set, main() pastes the reference images to the right of the
    # generated image (via horizontal_concat) before saving.
    concat_refs: bool = False
    # Forwarded to USOPipeline as lora_rank.
    lora_rank: int = 128
    # Forwarded to the pipeline call as pe. The meaning of the four codes is
    # not visible in this file - presumably a position-encoding mode; TODO confirm.
    pe: Literal["d", "h", "w", "o"] = "d"
    # Size argument given to preprocess_ref for the first reference image; also
    # the divisor of the 1024 / content_ref scale used by instruct_edit.
    content_ref: int = 512
    # NOTE(review): not read by main() in this file - presumably consumed
    # elsewhere (e.g. by the pipeline via another route); confirm.
    ckpt_path: str | None = None
    # Load the SigLIP model/processor and attach the model to the pipeline as
    # pipeline.model.vision_encoder.
    use_siglip: bool = True
    # Derive the output size from the preprocessed first reference image
    # instead of using width/height.
    instruct_edit: bool = False
    # Forwarded to USOPipeline as hf_download. Kept False on purpose so that
    # weights are not downloaded automatically.
    hf_download: bool = False
def _load_tasks(args):
    """Return ``(task_dicts, data_root)`` for the run described by ``args``."""
    if args.eval_json_path is not None:
        with open(args.eval_json_path, "rt", encoding="utf-8") as f:
            return json.load(f), os.path.dirname(args.eval_json_path)
    return [{"prompt": args.prompt, "image_paths": args.image_paths}], ""


def _open_reference_images(image_paths, data_root):
    """Open each non-empty path under ``data_root`` as RGB; ``""`` becomes ``None``."""
    return [
        Image.open(os.path.join(data_root, path)).convert("RGB") if path != "" else None
        for path in (image_paths or [])
    ]


def main(args: InferenceArgs):
    """Generate images for one prompt or for every entry of an eval JSON file.

    Tasks are split across Accelerate processes with a strided slice
    (``tasks[process_index::num_processes]``). For every task,
    ``args.num_images_per_prompt`` images are generated (the j-th with seed
    ``args.seed + j``) and each image is saved as a PNG together with a JSON
    file holding the effective configuration.

    Args:
        args: Parsed inference options. ``args`` itself is not modified; the
            per-image prompt, image paths and size are written to the config
            JSON only.

    Raises:
        ValueError: If neither ``args.prompt`` nor ``args.eval_json_path`` is
            set, or if ``args.instruct_edit`` is set but a task has no usable
            first reference image.
    """
    # Validate before loading any model so a bad invocation fails immediately.
    # An explicit raise (rather than assert) also survives ``python -O``.
    if args.prompt is None and args.eval_json_path is None:
        raise ValueError("Please provide either prompt or eval_json_path")

    accelerator = Accelerator()

    # Optional SigLIP vision encoder. The checkpoint location comes from the
    # SIGLIP_PATH environment variable, with a hub model id as fallback.
    siglip_processor = None
    siglip_model = None
    if args.use_siglip:
        siglip_path = os.getenv("SIGLIP_PATH", "google/siglip-so400m-patch14-384")
        siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
        siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
        siglip_model.eval()
        siglip_model.to(accelerator.device)
        print("SigLIP model loaded successfully")

    pipeline = USOPipeline(
        args.model_type,
        accelerator.device,
        args.offload,
        only_lora=args.only_lora,
        lora_rank=args.lora_rank,
        hf_download=args.hf_download,
    )
    if siglip_model is not None:
        pipeline.model.vision_encoder = siglip_model

    all_tasks, data_root = _load_tasks(args)
    local_tasks = all_tasks[accelerator.process_index :: accelerator.num_processes]
    print(
        f"process: {accelerator.num_processes}/{accelerator.process_index}, "
        f"process images: {len(all_tasks)}/{len(local_tasks)}"
    )
    accelerator.wait_for_everyone()

    # Only the main process shows a bar; on the others tqdm is disabled and
    # update() is a no-op. The context manager closes the bar even on error.
    with tqdm(
        total=len(local_tasks) * args.num_images_per_prompt,
        desc="Generating Images",
        disable=not accelerator.is_main_process,
    ) as progress_bar:
        for (i, data_dict), j in itertools.product(
            enumerate(local_tasks), range(args.num_images_per_prompt)
        ):
            # A missing/None "image_paths" means "no reference images".
            # NOTE(review): assumes the pipeline accepts an empty ref list - confirm.
            image_paths = data_dict.get("image_paths")
            ref_imgs = _open_reference_images(image_paths, data_root)

            # Every reference after the first is fed to the SigLIP processor.
            siglip_inputs = None
            if siglip_processor is not None:
                with torch.no_grad():
                    siglip_inputs = [
                        siglip_processor(img, return_tensors="pt").to(pipeline.device)
                        for img in ref_imgs[1:]
                        if isinstance(img, Image.Image)
                    ]

            # Only the first reference goes through preprocess_ref.
            ref_imgs_pil = [
                preprocess_ref(img, args.content_ref)
                for img in ref_imgs[:1]
                if isinstance(img, Image.Image)
            ]

            width, height = args.width, args.height
            if args.instruct_edit:
                if not ref_imgs_pil:
                    raise ValueError(
                        "instruct_edit requires a non-empty first entry in image_paths"
                    )
                # Scale the preprocessed reference size by 1024 / content_ref.
                # Multiply before dividing and round so the result is an exact
                # int instead of a float that may sit just below the target.
                ref_width, ref_height = ref_imgs_pil[0].size
                width = round(ref_width * 1024 / args.content_ref)
                height = round(ref_height * 1024 / args.content_ref)

            image_gen = pipeline(
                prompt=data_dict["prompt"],
                width=width,
                height=height,
                guidance=args.guidance,
                num_steps=args.num_steps,
                seed=args.seed + j,
                ref_imgs=ref_imgs_pil,
                pe=args.pe,
                siglip_inputs=siglip_inputs,
            )
            if args.concat_refs:
                # Drop the None placeholders of missing references before pasting.
                image_gen = horizontal_concat(
                    [image_gen, *(img for img in ref_imgs if img is not None)]
                )

            if "save_dir" in data_dict:
                stem = os.path.join(args.save_path, f"{data_dict['save_dir']}_{j}")
            else:
                # Name by the task's index in the full list, not in this
                # process's slice; otherwise different processes would write
                # the same "<i>_<j>" files and overwrite each other.
                global_index = accelerator.process_index + i * accelerator.num_processes
                stem = os.path.join(args.save_path, f"{global_index}_{j}")
            image_save_path = stem + ".png"
            config_save_path = stem + ".json"

            output_dir = os.path.dirname(image_save_path)
            if output_dir:  # empty when saving into the current directory
                os.makedirs(output_dir, exist_ok=True)
            image_gen.save(image_save_path)

            # Save the options with this image's prompt, references and actual
            # size. Overriding existing keys keeps the original key order.
            config = {
                **vars(args),
                "prompt": data_dict["prompt"],
                "image_paths": image_paths,
                "width": width,
                "height": height,
            }
            with open(config_save_path, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=4)

            progress_bar.update(1)
if __name__ == "__main__":
    # Script entry point: build the CLI from InferenceArgs' fields, parse the
    # command line, and run inference with the single resulting dataclass.
    parsed = HfArgumentParser([InferenceArgs]).parse_args_into_dataclasses()
    main(parsed[0])
|