Spaces:
Runtime error
Runtime error
File size: 14,052 Bytes
5fbd25d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 |
"""some utils for api"""
import random
from typing import List
import numpy
from fastapi import Response
from fastapi.security import APIKeyHeader
from fastapi import HTTPException, Security
from fooocusapi.models.common.base import EnhanceCtrlNets, ImagePrompt
from modules import constants, flags
from modules import config
from modules.sdxl_styles import legal_style_names
from fooocusapi.args import args
from fooocusapi.utils.img_utils import read_input_image
from fooocusapi.utils.file_utils import (
get_file_serve_url,
output_file_to_base64img,
output_file_to_bytesimg
)
from fooocusapi.utils.logger import logger
from fooocusapi.models.common.requests import (
CommonRequest as Text2ImgRequest
)
from fooocusapi.models.common.response import (
AsyncJobResponse,
AsyncJobStage,
GeneratedImageResult
)
from fooocusapi.models.requests_v1 import (
ImageEnhanceRequest, ImgInpaintOrOutpaintRequest,
ImgPromptRequest,
ImgUpscaleOrVaryRequest
)
from fooocusapi.models.requests_v2 import (
ImageEnhanceRequestJson, Text2ImgRequestWithPrompt,
ImgInpaintOrOutpaintRequestJson,
ImgUpscaleOrVaryRequestJson,
ImgPromptRequestJson
)
from fooocusapi.models.common.task import (
ImageGenerationResult,
GenerationFinishReason
)
from fooocusapi.configs.default import (
default_inpaint_engine_version,
default_sampler,
default_scheduler,
default_base_model_name,
default_refiner_model_name
)
from fooocusapi.parameters import ImageGenerationParams
from fooocusapi.task_queue import QueueTask
from modules.util import HWC3
# Extracts the client's API key from the "X-API-KEY" request header.
# auto_error=False: FastAPI does not reject requests that lack the header by
# itself; api_key_auth decides whether a key is required and whether it matches.
api_key_header = APIKeyHeader(auto_error=False, name="X-API-KEY")
def refresh_seed(seed_string: int | str | None) -> int:
    """
    Refresh and check seed number.

    :params seed_string: seed as an int or a numeric string. ``None``, a value
        that cannot be converted with ``int()``, ``-1`` (int or string) and any
        value outside [constants.MIN_SEED, constants.MAX_SEED] all select a
        random seed.
    :return: seed number
    """
    try:
        seed_value = int(seed_string)
    except (TypeError, ValueError, OverflowError):
        # TypeError: int(None) -- None explicitly means "random".
        # ValueError: non-numeric strings (and NaN).
        # OverflowError: int(float('inf')).
        return random.randint(constants.MIN_SEED, constants.MAX_SEED)
    # Check the converted value so "-1" is treated the same as -1.
    if seed_value == -1 or not constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
        return random.randint(constants.MIN_SEED, constants.MAX_SEED)
    return seed_value
def check_models_exist(file_name: str, model_type: str) -> str:
    """
    Validate one requested model / LoRA file name against the known files.

    ``None`` or ``'None'`` is passed through as ``'None'``. Otherwise the file
    lists in ``config`` are refreshed and the name is accepted if it appears in
    ``config.model_filenames`` or ``config.lora_filenames``. An unknown name
    logs a warning and is replaced by the configured default for ``'base'`` /
    ``'refiner'``, or by ``'None'`` for any other model_type.
    """
    if file_name is None or file_name == 'None':
        return 'None'

    config.update_files()
    known_files = config.model_filenames + config.lora_filenames
    if file_name in known_files:
        return file_name

    logger.std_warn(f"[Warning] Wrong {model_type} model input: {file_name}, using default")
    fallback_by_type = {
        'base': default_base_model_name,
        'refiner': default_refiner_model_name,
    }
    return fallback_by_type.get(model_type, 'None')
def api_key_auth(apikey: str = Security(api_key_header)):
    """
    Check if the API key is valid, API key is not required if no API key is set

    Args:
        apikey: value of the X-API-KEY request header; None when the header is
            absent (api_key_header is created with auto_error=False).
    returns:
        None if no API key is configured or the supplied key matches.
    raises:
        HTTPException: 403 when a key is configured and the supplied key is
            missing or does not match.
    """
    import secrets  # local import: constant-time comparison helper

    if args.apikey is None:
        return  # Skip API key check if no API key is set
    # The header is attacker-controlled. Compare in constant time so response
    # timing does not reveal how much of the key matched. Encode to bytes
    # because compare_digest only accepts ASCII-only str.
    if apikey is None or not secrets.compare_digest(
            apikey.encode("utf-8"), args.apikey.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")
def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
    """
    Convert Request to ImageGenerationParams

    Args:
        req: Request, Text2ImgRequest and classes inherited from Text2ImgRequest
    returns:
        ImageGenerationParams

    Note: ``req`` may be mutated. Image prompt entries get default
    cn_stop/cn_weight filled in. In ``req.advanced_params``, mixing flags may
    be switched on and invalid values reset to defaults.
    """
    # Request-family checks, computed once. They decide which type-specific
    # fields are read from ``req``.
    is_uov = isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson))
    is_inpaint = isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson))
    is_enhance = isinstance(req, (ImageEnhanceRequest, ImageEnhanceRequestJson))
    is_text2img_with_prompt = isinstance(req, Text2ImgRequestWithPrompt)
    accepts_image_prompts = isinstance(req, (
        ImgInpaintOrOutpaintRequestJson, ImgPromptRequest, ImgPromptRequestJson,
        ImgUpscaleOrVaryRequestJson, Text2ImgRequestWithPrompt))

    prompt = req.prompt
    negative_prompt = req.negative_prompt
    # Silently drop style names that are not in legal_style_names.
    style_selections = [s for s in req.style_selections if s in legal_style_names]
    performance_selection = req.performance_selection.value
    aspect_ratios_selection = req.aspect_ratios_selection
    image_number = req.image_number
    image_seed = refresh_seed(req.image_seed)
    sharpness = req.sharpness
    guidance_scale = req.guidance_scale
    base_model_name = check_models_exist(req.base_model_name, 'base')
    refiner_model_name = check_models_exist(req.refiner_model_name, 'refiner')
    refiner_switch = req.refiner_switch
    loras = [(lora.enabled, check_models_exist(lora.model_name, 'lora'), lora.weight)
             for lora in req.loras]

    # Upscale / vary inputs.
    uov_input_image = None
    if is_uov and not is_text2img_with_prompt:
        uov_input_image = read_input_image(req.input_image)
    uov_method = req.uov_method.value if is_uov else flags.disabled
    upscale_value = req.upscale_value if is_uov else None

    # Outpaint settings (only inpaint/outpaint requests carry them).
    if is_inpaint:
        outpaint_selections = [s.value for s in req.outpaint_selections]
        outpaint_distance_left = req.outpaint_distance_left
        outpaint_distance_right = req.outpaint_distance_right
        outpaint_distance_top = req.outpaint_distance_top
        outpaint_distance_bottom = req.outpaint_distance_bottom
    else:
        outpaint_selections = []
        outpaint_distance_left = outpaint_distance_right = 0
        outpaint_distance_top = outpaint_distance_bottom = 0

    # Defensive: normalise an empty refiner name to the 'None' sentinel.
    if refiner_model_name == '':
        refiner_model_name = 'None'

    # Inpaint image + mask. Without a supplied mask, use an all-zero mask
    # matching the first two dimensions of the input image.
    inpaint_input_image = dict(image=None, mask=None)
    inpaint_additional_prompt = None
    if is_inpaint and req.input_image is not None:
        inpaint_additional_prompt = req.inpaint_additional_prompt
        input_image = read_input_image(req.input_image)
        if req.input_mask is not None:
            input_mask = HWC3(read_input_image(req.input_mask))
        else:
            input_mask = HWC3(numpy.zeros(input_image.shape[:2], dtype=numpy.uint8))
        inpaint_input_image = {
            'image': input_image,
            'mask': input_mask
        }

    image_prompts = []
    if accepts_image_prompts:
        advanced = req.advanced_params
        # Mixing flags are switched on automatically when image prompts are
        # combined with an upscale/vary or inpaint input image. The flags live
        # on advanced_params, so only toggle them when it is present.
        can_mix = len(req.image_prompts) > 0 and advanced is not None
        if can_mix and uov_input_image is not None:
            print("[INFO] Mixing image prompt and vary upscale is set to True")
            advanced.mixing_image_prompt_and_vary_upscale = True
        elif can_mix and not is_text2img_with_prompt and req.input_image is not None:
            print("[INFO] Mixing image prompt and inpaint is set to True")
            advanced.mixing_image_prompt_and_inpaint = True
        for img_prompt in req.image_prompts:
            if img_prompt.cn_img is None:
                continue  # entries without an image are skipped
            cn_img = read_input_image(img_prompt.cn_img)
            cn_type = img_prompt.cn_type.value
            # Unset or zero stop/weight fall back to this control type's defaults.
            if img_prompt.cn_stop is None or img_prompt.cn_stop == 0:
                img_prompt.cn_stop = flags.default_parameters[cn_type][0]
            if img_prompt.cn_weight is None or img_prompt.cn_weight == 0:
                img_prompt.cn_weight = flags.default_parameters[cn_type][1]
            image_prompts.append((cn_img, img_prompt.cn_stop, img_prompt.cn_weight, cn_type))
    # Pad to config.default_controlnet_image_count with image-less placeholders
    # (tuples are immutable, so repeating one is safe).
    missing_prompts = config.default_controlnet_image_count - len(image_prompts)
    if missing_prompts > 0:
        image_prompts += [(None, 0.5, 0.6, 'ImagePrompt')] * missing_prompts

    if is_enhance:
        enhance_checkbox = True
        enhance_input_image = read_input_image(req.enhance_input_image)
        enhance_uov_method = req.enhance_uov_method
        enhance_uov_processing_order = req.enhance_uov_processing_order
        enhance_uov_prompt_type = req.enhance_uov_prompt_type
        save_final_enhanced_image_only = True
        enhance_ctrlnets = req.enhance_ctrlnets
    else:
        enhance_checkbox = False
        enhance_input_image = None
        enhance_uov_method = flags.disabled
        enhance_uov_processing_order = "Before First Enhancement"
        enhance_uov_prompt_type = "Original Prompts"
        save_final_enhanced_image_only = False
        # Build one independent EnhanceCtrlNets per tab. `[obj] * n` would make
        # every tab alias the same instance, so mutating one would change all.
        enhance_ctrlnets = [EnhanceCtrlNets() for _ in range(config.default_enhance_tabs)]

    # Validate advanced params in place. Unsupported values are replaced by
    # defaults (with a warning) instead of rejecting the request.
    advanced_params = req.advanced_params
    if advanced_params is not None:
        if advanced_params.refiner_swap_method not in ['joint', 'separate', 'vae']:
            logger.std_warn(f"[Warning] Wrong refiner_swap_method input: {advanced_params.refiner_swap_method}, using default")
            advanced_params.refiner_swap_method = 'joint'
        if advanced_params.sampler_name not in flags.sampler_list:
            logger.std_warn(f"[Warning] Wrong sampler_name input: {advanced_params.sampler_name}, using default")
            advanced_params.sampler_name = default_sampler
        if advanced_params.scheduler_name not in flags.scheduler_list:
            logger.std_warn(f"[Warning] Wrong scheduler_name input: {advanced_params.scheduler_name}, using default")
            advanced_params.scheduler_name = default_scheduler
        if advanced_params.inpaint_engine not in flags.inpaint_engine_versions:
            logger.std_warn(f"[Warning] Wrong inpaint_engine input: {advanced_params.inpaint_engine}, using default")
            advanced_params.inpaint_engine = default_inpaint_engine_version

    return ImageGenerationParams(
        prompt=prompt,
        negative_prompt=negative_prompt,
        style_selections=style_selections,
        performance_selection=performance_selection,
        aspect_ratios_selection=aspect_ratios_selection,
        image_number=image_number,
        image_seed=image_seed,
        sharpness=sharpness,
        guidance_scale=guidance_scale,
        base_model_name=base_model_name,
        refiner_model_name=refiner_model_name,
        refiner_switch=refiner_switch,
        loras=loras,
        uov_input_image=uov_input_image,
        uov_method=uov_method,
        upscale_value=upscale_value,
        outpaint_selections=outpaint_selections,
        outpaint_distance_left=outpaint_distance_left,
        outpaint_distance_right=outpaint_distance_right,
        outpaint_distance_top=outpaint_distance_top,
        outpaint_distance_bottom=outpaint_distance_bottom,
        inpaint_input_image=inpaint_input_image,
        inpaint_additional_prompt=inpaint_additional_prompt,
        enhance_input_image=enhance_input_image,
        enhance_checkbox=enhance_checkbox,
        enhance_uov_method=enhance_uov_method,
        enhance_uov_processing_order=enhance_uov_processing_order,
        enhance_uov_prompt_type=enhance_uov_prompt_type,
        save_final_enhanced_image_only=save_final_enhanced_image_only,
        enhance_ctrlnets=enhance_ctrlnets,
        read_wildcards_in_order=req.read_wildcards_in_order,
        image_prompts=image_prompts,
        advanced_params=advanced_params,
        save_meta=req.save_meta,
        meta_scheme=req.meta_scheme,
        save_name=req.save_name,
        save_extension=req.save_extension,
        require_base64=req.require_base64,
    )
def generate_async_output(
        task: QueueTask,
        require_step_preview: bool = False) -> AsyncJobResponse:
    """
    Describe the current state of a queued task as an AsyncJobResponse.

    Arguments:
        task: QueueTask to report on
        require_step_preview: when True, include task.task_step_preview
    Returns:
        AsyncJobResponse
    """
    # A task with start_mills == 0 has not started yet.
    stage = AsyncJobStage.waiting if task.start_mills == 0 else AsyncJobStage.running
    images = None

    if task.is_finished:
        if task.finish_with_error:
            stage = AsyncJobStage.error
        elif task.task_result is not None:
            stage = AsyncJobStage.success
            images = generate_image_result_output(
                task.task_result, task.req_param.require_base64)
        # A finished task with no error and no result keeps its
        # waiting/running stage.

    preview = task.task_step_preview if require_step_preview else None
    return AsyncJobResponse(
        job_id=task.job_id,
        job_type=task.task_type,
        job_stage=stage,
        job_progress=task.finish_progress,
        job_status=task.task_status,
        job_step_preview=preview,
        job_result=images,
    )
def generate_streaming_output(results: List[ImageGenerationResult]) -> Response:
    """
    Generate streaming output for image generation results.

    Only the first result is used.

    Args:
        results (List[ImageGenerationResult]): List of image generation results.
    Returns:
        Response: the first image as bytes with media type image/png, or an
        error status:
        - 409: queue full
        - 400: user cancel
        - 500: generation error, empty results, or no image bytes
          (e.g. unreadable output file)
    """
    if not results:
        return Response(status_code=500)
    result = results[0]

    error_statuses = (
        (GenerationFinishReason.queue_is_full, 409),
        (GenerationFinishReason.user_cancel, 400),
        (GenerationFinishReason.error, 500),
    )
    for reason, status_code in error_statuses:
        if result.finish_reason == reason:
            return Response(status_code=status_code, content=result.finish_reason.value)

    img_bytes = output_file_to_bytesimg(result.im)
    if img_bytes is None:
        # If the helper yields no bytes, Response(None, ...) would send a
        # 200 image/png with an empty body; report a server error instead.
        return Response(status_code=500)
    return Response(img_bytes, media_type='image/png')
def generate_image_result_output(
        results: List[ImageGenerationResult],
        require_base64: bool) -> List[GeneratedImageResult]:
    """
    Map generation results to GeneratedImageResult API objects.

    Arguments:
        results: List[ImageGenerationResult]
        require_base64: when True, embed each image as base64; otherwise the
            base64 field is None
    Returns:
        List[GeneratedImageResult], one per input result, in the same order
    """
    def to_generated(item: ImageGenerationResult) -> GeneratedImageResult:
        # The seed is exposed as a string.
        encoded = output_file_to_base64img(item.im) if require_base64 else None
        return GeneratedImageResult(
            base64=encoded,
            url=get_file_serve_url(item.im),
            seed=str(item.seed),
            finish_reason=item.finish_reason,
        )

    return [to_generated(item) for item in results]
|