Spaces:
Runtime error
Runtime error
"""some utils for api""" | |
import random | |
from typing import List | |
import numpy | |
from fastapi import Response | |
from fastapi.security import APIKeyHeader | |
from fastapi import HTTPException, Security | |
from fooocusapi.models.common.base import EnhanceCtrlNets, ImagePrompt | |
from modules import constants, flags | |
from modules import config | |
from modules.sdxl_styles import legal_style_names | |
from fooocusapi.args import args | |
from fooocusapi.utils.img_utils import read_input_image | |
from fooocusapi.utils.file_utils import ( | |
get_file_serve_url, | |
output_file_to_base64img, | |
output_file_to_bytesimg | |
) | |
from fooocusapi.utils.logger import logger | |
from fooocusapi.models.common.requests import ( | |
CommonRequest as Text2ImgRequest | |
) | |
from fooocusapi.models.common.response import ( | |
AsyncJobResponse, | |
AsyncJobStage, | |
GeneratedImageResult | |
) | |
from fooocusapi.models.requests_v1 import ( | |
ImageEnhanceRequest, ImgInpaintOrOutpaintRequest, | |
ImgPromptRequest, | |
ImgUpscaleOrVaryRequest | |
) | |
from fooocusapi.models.requests_v2 import ( | |
ImageEnhanceRequestJson, Text2ImgRequestWithPrompt, | |
ImgInpaintOrOutpaintRequestJson, | |
ImgUpscaleOrVaryRequestJson, | |
ImgPromptRequestJson | |
) | |
from fooocusapi.models.common.task import ( | |
ImageGenerationResult, | |
GenerationFinishReason | |
) | |
from fooocusapi.configs.default import ( | |
default_inpaint_engine_version, | |
default_sampler, | |
default_scheduler, | |
default_base_model_name, | |
default_refiner_model_name | |
) | |
from fooocusapi.parameters import ImageGenerationParams | |
from fooocusapi.task_queue import QueueTask | |
from modules.util import HWC3 | |
# Reads the API key from the "X-API-KEY" request header.
# auto_error=False: FastAPI does not reject a request with a missing header by
# itself; the dependency receives None instead and api_key_auth decides.
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False)
def refresh_seed(seed_string: int | str | None) -> int:
    """
    Refresh and check seed number.

    :params seed_string: seed, as an int or a numeric string. None, -1, a
        non-numeric string, or any value outside
        [constants.MIN_SEED, constants.MAX_SEED] means "use a random seed".
    :return: seed number within [constants.MIN_SEED, constants.MAX_SEED]
    """
    try:
        seed_value = int(seed_string)
    except (TypeError, ValueError):
        # int(None) raises TypeError and int("abc") raises ValueError;
        # both mean "no usable seed supplied".
        return random.randint(constants.MIN_SEED, constants.MAX_SEED)
    # Compare the parsed value so that "-1" (a string) is treated like -1.
    if seed_value == -1 or not constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
        return random.randint(constants.MIN_SEED, constants.MAX_SEED)
    return seed_value
def check_models_exist(file_name: str, model_type: str) -> str:
    """
    Validate one model file name against the files known to ``config``.

    Args:
        file_name: requested model file name; None or 'None' means "no model".
        model_type: 'base', 'refiner' or anything else (e.g. 'lora'); this
            only selects the fallback and appears in the warning text.
    Returns:
        'None' when no model was requested; ``file_name`` when it is in
        ``config.model_filenames`` or ``config.lora_filenames`` (the file lists
        are refreshed via ``config.update_files()`` first). Otherwise a warning
        is logged and the default base / refiner name is returned for those
        types, or 'None' for any other type.
    """
    if file_name in (None, 'None'):
        return 'None'

    config.update_files()
    known_files = config.model_filenames + config.lora_filenames
    if file_name in known_files:
        return file_name

    logger.std_warn(f"[Warning] Wrong {model_type} model input: {file_name}, using default")
    fallback_by_type = {
        'base': default_base_model_name,
        'refiner': default_refiner_model_name,
    }
    return fallback_by_type.get(model_type, 'None')
def api_key_auth(apikey: str = Security(api_key_header)):
    """
    Check if the API key is valid, API key is not required if no API key is set

    Args:
        apikey: API key taken from the X-API-KEY header; None when the header
            is absent (the header scheme is declared with auto_error=False).
    returns:
        None when no server API key is configured or the key matches.
    raises:
        HTTPException(403) when a server API key is configured and the supplied
        key is missing or does not match.
    """
    import secrets  # stdlib; constant-time comparison for secrets

    if args.apikey is None:
        return  # Skip API key check if no API key is set
    # compare_digest avoids leaking how many leading characters matched via
    # response timing; compare bytes so non-ASCII keys do not raise TypeError.
    if apikey is None or not secrets.compare_digest(
            apikey.encode("utf-8"), str(args.apikey).encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")
def _sanitize_advanced_params(adp) -> None:
    """Reset unsupported values on an advanced-params object to defaults, in place."""
    if adp.refiner_swap_method not in ('joint', 'separate', 'vae'):
        logger.std_warn(f"[Warning] Wrong refiner_swap_method input: {adp.refiner_swap_method}, using default")
        adp.refiner_swap_method = 'joint'
    if adp.sampler_name not in flags.sampler_list:
        logger.std_warn(f"[Warning] Wrong sampler_name input: {adp.sampler_name}, using default")
        adp.sampler_name = default_sampler
    if adp.scheduler_name not in flags.scheduler_list:
        logger.std_warn(f"[Warning] Wrong scheduler_name input: {adp.scheduler_name}, using default")
        adp.scheduler_name = default_scheduler
    if adp.inpaint_engine not in flags.inpaint_engine_versions:
        logger.std_warn(f"[Warning] Wrong inpaint_engine input: {adp.inpaint_engine}, using default")
        adp.inpaint_engine = default_inpaint_engine_version


def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
    """
    Convert Request to ImageGenerationParams

    Along the way the request is validated and normalised:
      * style names not in ``legal_style_names`` are dropped;
      * the seed is resolved through ``refresh_seed``;
      * base / refiner / LoRA names are checked with ``check_models_exist``;
      * fields that only exist on some request types (upscale/vary,
        inpaint/outpaint, image prompts, enhance) get neutral defaults for
        other request types;
      * the image prompt list is padded to ``config.default_controlnet_image_count``;
      * unsupported advanced parameters are reset to defaults.

    Note: ``req`` is mutated in place. Entries of ``req.image_prompts`` get
    default ``cn_stop`` / ``cn_weight`` filled in, and ``req.advanced_params``
    is sanitised and may have its mixing flags switched on.

    Args:
        req: Request, Text2ImgRequest and classes inherited from Text2ImgRequest
    returns:
        ImageGenerationParams
    """
    # Request-type flags, computed once instead of re-testing isinstance per field.
    is_uov = isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson))
    is_inpaint = isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson))
    is_enhance = isinstance(req, (ImageEnhanceRequest, ImageEnhanceRequestJson))
    is_t2i_with_prompt = isinstance(req, Text2ImgRequestWithPrompt)
    accepts_image_prompts = isinstance(req, (
        ImgInpaintOrOutpaintRequestJson, ImgPromptRequest, ImgPromptRequestJson,
        ImgUpscaleOrVaryRequestJson, Text2ImgRequestWithPrompt))

    style_selections = [s for s in req.style_selections if s in legal_style_names]
    image_seed = refresh_seed(req.image_seed)
    base_model_name = check_models_exist(req.base_model_name, 'base')
    refiner_model_name = check_models_exist(req.refiner_model_name, 'refiner')
    if refiner_model_name == '':
        refiner_model_name = 'None'
    loras = [
        (lora.enabled, check_models_exist(lora.model_name, 'lora'), lora.weight)
        for lora in req.loras
    ]

    # Upscale / vary fields.
    uov_input_image = None
    uov_method = flags.disabled
    upscale_value = None
    if is_uov:
        if not is_t2i_with_prompt:
            uov_input_image = read_input_image(req.input_image)
        uov_method = req.uov_method.value
        upscale_value = req.upscale_value

    # Inpaint / outpaint fields.
    outpaint_selections = []
    outpaint_distance_left = 0
    outpaint_distance_right = 0
    outpaint_distance_top = 0
    outpaint_distance_bottom = 0
    inpaint_input_image = dict(image=None, mask=None)
    inpaint_additional_prompt = None
    if is_inpaint:
        outpaint_selections = [s.value for s in req.outpaint_selections]
        outpaint_distance_left = req.outpaint_distance_left
        outpaint_distance_right = req.outpaint_distance_right
        outpaint_distance_top = req.outpaint_distance_top
        outpaint_distance_bottom = req.outpaint_distance_bottom
        if req.input_image is not None:
            inpaint_additional_prompt = req.inpaint_additional_prompt
            input_image = read_input_image(req.input_image)
            if req.input_mask is not None:
                input_mask = HWC3(read_input_image(req.input_mask))
            else:
                # No mask supplied: use an all-zero mask the size of the image.
                input_mask = HWC3(numpy.zeros(input_image.shape[:2], dtype=numpy.uint8))
            inpaint_input_image = {
                'image': input_image,
                'mask': input_mask
            }

    # Image prompts (ControlNet inputs).
    image_prompts = []
    if accepts_image_prompts:
        # Auto-enable mixing when image prompts are combined with an input image.
        # advanced_params may be None (it is checked further below), so guard it.
        if len(req.image_prompts) > 0 and req.advanced_params is not None:
            if uov_input_image is not None:
                print("[INFO] Mixing image prompt and vary upscale is set to True")
                req.advanced_params.mixing_image_prompt_and_vary_upscale = True
            elif not is_t2i_with_prompt and req.input_image is not None:
                print("[INFO] Mixing image prompt and inpaint is set to True")
                req.advanced_params.mixing_image_prompt_and_inpaint = True
        for img_prompt in req.image_prompts:
            if img_prompt.cn_img is None:
                continue
            cn_img = read_input_image(img_prompt.cn_img)
            # None or 0 means "not set": take the per-type default from flags.
            if img_prompt.cn_stop is None or img_prompt.cn_stop == 0:
                img_prompt.cn_stop = flags.default_parameters[img_prompt.cn_type.value][0]
            if img_prompt.cn_weight is None or img_prompt.cn_weight == 0:
                img_prompt.cn_weight = flags.default_parameters[img_prompt.cn_type.value][1]
            image_prompts.append(
                (cn_img, img_prompt.cn_stop, img_prompt.cn_weight, img_prompt.cn_type.value))
    missing_prompts = config.default_controlnet_image_count - len(image_prompts)
    if missing_prompts > 0:
        # Pad with empty slots; tuples are immutable so repeating one is safe.
        image_prompts += [(None, 0.5, 0.6, 'ImagePrompt')] * missing_prompts

    # Enhance fields.
    if is_enhance:
        enhance_checkbox = True
        enhance_input_image = read_input_image(req.enhance_input_image)
        enhance_uov_method = req.enhance_uov_method
        enhance_uov_processing_order = req.enhance_uov_processing_order
        enhance_uov_prompt_type = req.enhance_uov_prompt_type
        save_final_enhanced_image_only = True
        enhance_ctrlnets = req.enhance_ctrlnets
    else:
        enhance_checkbox = False
        enhance_input_image = None
        enhance_uov_method = flags.disabled
        enhance_uov_processing_order = "Before First Enhancement"
        enhance_uov_prompt_type = "Original Prompts"
        save_final_enhanced_image_only = False
        # One independent object per tab: `[EnhanceCtrlNets()] * n` would make
        # every tab share (and co-mutate) a single instance.
        enhance_ctrlnets = [EnhanceCtrlNets() for _ in range(config.default_enhance_tabs)]

    advanced_params = req.advanced_params
    if advanced_params is not None:
        _sanitize_advanced_params(advanced_params)

    return ImageGenerationParams(
        prompt=req.prompt,
        negative_prompt=req.negative_prompt,
        style_selections=style_selections,
        performance_selection=req.performance_selection.value,
        aspect_ratios_selection=req.aspect_ratios_selection,
        image_number=req.image_number,
        image_seed=image_seed,
        sharpness=req.sharpness,
        guidance_scale=req.guidance_scale,
        base_model_name=base_model_name,
        refiner_model_name=refiner_model_name,
        refiner_switch=req.refiner_switch,
        loras=loras,
        uov_input_image=uov_input_image,
        uov_method=uov_method,
        upscale_value=upscale_value,
        outpaint_selections=outpaint_selections,
        outpaint_distance_left=outpaint_distance_left,
        outpaint_distance_right=outpaint_distance_right,
        outpaint_distance_top=outpaint_distance_top,
        outpaint_distance_bottom=outpaint_distance_bottom,
        inpaint_input_image=inpaint_input_image,
        inpaint_additional_prompt=inpaint_additional_prompt,
        enhance_input_image=enhance_input_image,
        enhance_checkbox=enhance_checkbox,
        enhance_uov_method=enhance_uov_method,
        enhance_uov_processing_order=enhance_uov_processing_order,
        enhance_uov_prompt_type=enhance_uov_prompt_type,
        save_final_enhanced_image_only=save_final_enhanced_image_only,
        enhance_ctrlnets=enhance_ctrlnets,
        read_wildcards_in_order=req.read_wildcards_in_order,
        image_prompts=image_prompts,
        advanced_params=advanced_params,
        save_meta=req.save_meta,
        meta_scheme=req.meta_scheme,
        save_name=req.save_name,
        save_extension=req.save_extension,
        require_base64=req.require_base64,
    )
def generate_async_output(
        task: QueueTask,
        require_step_preview: bool = False) -> AsyncJobResponse:
    """
    Build the async-job status response for a queued task.

    Stage selection: ``waiting`` while ``task.start_mills == 0``, otherwise
    ``running``. Once ``task.is_finished``, the stage becomes ``error`` if the
    task finished with an error, or ``success`` (with the image results
    attached) if it has a ``task_result``. A finished task with neither keeps
    the waiting/running stage.

    Arguments:
        task: QueueTask
        require_step_preview: include ``task.task_step_preview`` when True
    Returns:
        AsyncJobResponse
    """
    stage = AsyncJobStage.waiting if task.start_mills == 0 else AsyncJobStage.running
    payload = None
    if task.is_finished and task.finish_with_error:
        stage = AsyncJobStage.error
    elif task.is_finished and task.task_result is not None:
        stage = AsyncJobStage.success
        payload = generate_image_result_output(task.task_result, task.req_param.require_base64)

    preview = task.task_step_preview if require_step_preview else None
    return AsyncJobResponse(
        job_id=task.job_id,
        job_type=task.task_type,
        job_stage=stage,
        job_progress=task.finish_progress,
        job_status=task.task_status,
        job_step_preview=preview,
        job_result=payload)
def generate_streaming_output(results: List[ImageGenerationResult]) -> Response:
    """
    Build a raw-bytes HTTP response from the first image generation result.

    Args:
        results (List[ImageGenerationResult]): image generation results; only
            the first entry is used.
    Returns:
        Response: status 500 with no body when ``results`` is empty; status
        409 / 400 / 500 carrying the finish reason's value when the first
        result finished with ``queue_is_full`` / ``user_cancel`` / ``error``;
        otherwise the bytes from ``output_file_to_bytesimg`` served as
        ``image/png``.
    """
    if len(results) == 0:
        return Response(status_code=500)

    first = results[0]
    # Checked in order; compared with == (not dict lookup) like the finish reasons elsewhere.
    failure_statuses = (
        (GenerationFinishReason.queue_is_full, 409),
        (GenerationFinishReason.user_cancel, 400),
        (GenerationFinishReason.error, 500),
    )
    for reason, status in failure_statuses:
        if first.finish_reason == reason:
            return Response(status_code=status, content=first.finish_reason.value)

    return Response(output_file_to_bytesimg(first.im), media_type='image/png')
def generate_image_result_output(
        results: List[ImageGenerationResult],
        require_base64: bool) -> List[GeneratedImageResult]:
    """
    Convert internal generation results into API response items.

    Arguments:
        results: List[ImageGenerationResult]
        require_base64: when True, each item also carries the image encoded by
            ``output_file_to_base64img``; otherwise ``base64`` is None.
    Returns:
        List[GeneratedImageResult], one per input result, in the same order,
        with the serve URL from ``get_file_serve_url`` and the seed as a string.
    """
    converted: List[GeneratedImageResult] = []
    for entry in results:
        encoded = output_file_to_base64img(entry.im) if require_base64 else None
        converted.append(GeneratedImageResult(
            base64=encoded,
            url=get_file_serve_url(entry.im),
            seed=str(entry.seed),
            finish_reason=entry.finish_reason,
        ))
    return converted