Sanket17's picture
added all files
5fbd25d
raw
history blame
14.1 kB
"""some utils for api"""
import random
from typing import List
import numpy
from fastapi import Response
from fastapi.security import APIKeyHeader
from fastapi import HTTPException, Security
from fooocusapi.models.common.base import EnhanceCtrlNets, ImagePrompt
from modules import constants, flags
from modules import config
from modules.sdxl_styles import legal_style_names
from fooocusapi.args import args
from fooocusapi.utils.img_utils import read_input_image
from fooocusapi.utils.file_utils import (
get_file_serve_url,
output_file_to_base64img,
output_file_to_bytesimg
)
from fooocusapi.utils.logger import logger
from fooocusapi.models.common.requests import (
CommonRequest as Text2ImgRequest
)
from fooocusapi.models.common.response import (
AsyncJobResponse,
AsyncJobStage,
GeneratedImageResult
)
from fooocusapi.models.requests_v1 import (
ImageEnhanceRequest, ImgInpaintOrOutpaintRequest,
ImgPromptRequest,
ImgUpscaleOrVaryRequest
)
from fooocusapi.models.requests_v2 import (
ImageEnhanceRequestJson, Text2ImgRequestWithPrompt,
ImgInpaintOrOutpaintRequestJson,
ImgUpscaleOrVaryRequestJson,
ImgPromptRequestJson
)
from fooocusapi.models.common.task import (
ImageGenerationResult,
GenerationFinishReason
)
from fooocusapi.configs.default import (
default_inpaint_engine_version,
default_sampler,
default_scheduler,
default_base_model_name,
default_refiner_model_name
)
from fooocusapi.parameters import ImageGenerationParams
from fooocusapi.task_queue import QueueTask
from modules.util import HWC3
# Security scheme that reads the API key from the "X-API-KEY" request header.
# It is consumed by api_key_auth() below through Security(api_key_header).
# NOTE(review): auto_error=False presumably makes FastAPI hand over None for a
# missing header instead of rejecting the request itself -- confirm in the
# FastAPI docs.
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False)
def refresh_seed(seed_string: int | str | None) -> int:
    """
    Refresh and check seed number.

    A freshly drawn random seed is returned when the input is ``None``,
    cannot be converted to an integer, equals ``-1``, or lies outside the
    range ``[constants.MIN_SEED, constants.MAX_SEED]``. Otherwise the input
    is returned as an ``int``.

    :params seed_string: seed, str or int. None means random
    :return: seed number
    """
    random_seed = random.randint(constants.MIN_SEED, constants.MAX_SEED)
    try:
        seed_value = int(seed_string)
    except (TypeError, ValueError, OverflowError):
        # int(None) raises TypeError, int("abc") raises ValueError and
        # int(float("inf")) raises OverflowError; all of them mean
        # "no usable seed was supplied".
        return random_seed
    # Compare the parsed value so that both -1 and "-1" act as the
    # "give me a random seed" sentinel.
    if seed_value == -1 or not constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
        return random_seed
    return seed_value
def check_models_exist(file_name: str, model_type: str) -> str:
    """
    Validate a model file name against the files known to ``config``.

    Args:
        file_name: requested model file name; ``None`` or ``'None'`` means
            that no model was requested
        model_type: ``'base'``, ``'refiner'`` or anything else (e.g. ``'lora'``)
    returns:
        ``file_name`` when it is listed in ``config.model_filenames`` or
        ``config.lora_filenames``; otherwise the default base / refiner model
        name for those two types and ``'None'`` for every other type
    """
    if file_name is None or file_name == 'None':
        return 'None'

    # Refresh the file lists held by config before looking the name up.
    config.update_files()
    known_files = config.model_filenames + config.lora_filenames
    if file_name in known_files:
        return file_name

    logger.std_warn(f"[Warning] Wrong {model_type} model input: {file_name}, using default")
    if model_type == 'base':
        fallback = default_base_model_name
    elif model_type == 'refiner':
        fallback = default_refiner_model_name
    else:
        fallback = 'None'
    return fallback
def api_key_auth(apikey: str = Security(api_key_header)):
    """
    Check if the API key is valid, API key is not required if no API key is set

    The supplied key is compared with the configured one (taken as text) in
    constant time, so response timing does not reveal how many leading
    characters of a guess were correct.

    Args:
        apikey: API key taken from the request header, None when it is absent
    returns:
        None if API key is not set or the key matches
    raises:
        HTTPException: status 403 when a key is configured and the supplied
            key is missing or different
    """
    # Function-scope import keeps this block self-contained.
    import secrets

    if args.apikey is None:
        return  # Skip API key check if no API key is set

    # A missing header can never match a configured key.
    if not isinstance(apikey, str):
        raise HTTPException(status_code=403, detail="Forbidden")

    # compare_digest() only accepts ASCII-only str or bytes-like objects;
    # encoding both sides avoids a TypeError on non-ASCII header values.
    supplied = apikey.encode("utf-8")
    expected = str(args.apikey).encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Forbidden")
def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
    """
    Convert Request to ImageGenerationParams

    Fields shared by every request are copied over (the seed and the model
    names are validated first). Fields that only exist on specific request
    classes (upscale/vary, inpaint/outpaint, image prompt, enhance) are read
    when ``req`` is an instance of such a class and replaced by neutral
    defaults otherwise.

    Note that ``req`` is modified in place: unset or zero ``cn_stop`` /
    ``cn_weight`` values of image prompts are filled from
    ``flags.default_parameters``, invalid advanced parameters are replaced by
    their defaults, and the "mixing" flags of ``req.advanced_params`` may be
    switched on.

    Args:
        req: Request, Text2ImgRequest and classes inherited from Text2ImgRequest
    returns:
        ImageGenerationParams
    """
    # Request-kind checks that are needed several times below.
    is_uov_req = isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson))
    is_inpaint_req = isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson))
    is_enhance_req = isinstance(req, (ImageEnhanceRequest, ImageEnhanceRequestJson))

    # ---- common fields that need validation --------------------------------
    style_selections = [s for s in req.style_selections if s in legal_style_names]
    image_seed = refresh_seed(req.image_seed)
    base_model_name = check_models_exist(req.base_model_name, 'base')
    refiner_model_name = check_models_exist(req.refiner_model_name, 'refiner')
    if refiner_model_name == '':
        refiner_model_name = 'None'
    loras = [
        (lora.enabled, check_models_exist(lora.model_name, 'lora'), lora.weight)
        for lora in req.loras
    ]

    # ---- upscale / vary -----------------------------------------------------
    uov_input_image = None
    uov_method = flags.disabled
    upscale_value = None
    if is_uov_req:
        if not isinstance(req, Text2ImgRequestWithPrompt):
            uov_input_image = read_input_image(req.input_image)
        uov_method = req.uov_method.value
        upscale_value = req.upscale_value

    # ---- outpaint -----------------------------------------------------------
    outpaint_selections = []
    outpaint_distance_left = 0
    outpaint_distance_right = 0
    outpaint_distance_top = 0
    outpaint_distance_bottom = 0
    if is_inpaint_req:
        outpaint_selections = [s.value for s in req.outpaint_selections]
        outpaint_distance_left = req.outpaint_distance_left
        outpaint_distance_right = req.outpaint_distance_right
        outpaint_distance_top = req.outpaint_distance_top
        outpaint_distance_bottom = req.outpaint_distance_bottom

    # ---- inpaint ------------------------------------------------------------
    inpaint_input_image = {'image': None, 'mask': None}
    inpaint_additional_prompt = None
    if is_inpaint_req and req.input_image is not None:
        inpaint_additional_prompt = req.inpaint_additional_prompt
        input_image = read_input_image(req.input_image)
        # Default mask: all zeros with the first two dimensions of the image.
        inpaint_image_size = input_image.shape[:2]
        input_mask = HWC3(numpy.zeros(inpaint_image_size, dtype=numpy.uint8))
        if req.input_mask is not None:
            input_mask = HWC3(read_input_image(req.input_mask))
        inpaint_input_image = {
            'image': input_image,
            'mask': input_mask
        }

    # ---- image prompts (control nets) --------------------------------------
    image_prompts = []
    if isinstance(req, (ImgInpaintOrOutpaintRequestJson, ImgPromptRequest, ImgPromptRequestJson,
                        ImgUpscaleOrVaryRequestJson, Text2ImgRequestWithPrompt)):
        has_image_prompts = len(req.image_prompts) > 0
        # Automatically enable the matching "mixing" flag. advanced_params is
        # treated as optional further down, so it is checked here as well
        # instead of failing with an AttributeError on None.
        if has_image_prompts and uov_input_image is not None:
            if req.advanced_params is not None:
                print("[INFO] Mixing image prompt and vary upscale is set to True")
                req.advanced_params.mixing_image_prompt_and_vary_upscale = True
        elif (has_image_prompts
              and not isinstance(req, Text2ImgRequestWithPrompt)
              and req.input_image is not None):
            if req.advanced_params is not None:
                print("[INFO] Mixing image prompt and inpaint is set to True")
                req.advanced_params.mixing_image_prompt_and_inpaint = True

        for img_prompt in req.image_prompts:
            if img_prompt.cn_img is None:
                continue
            cn_img = read_input_image(img_prompt.cn_img)
            # None and 0 both mean "use the default for this control type".
            if img_prompt.cn_stop is None or img_prompt.cn_stop == 0:
                img_prompt.cn_stop = flags.default_parameters[img_prompt.cn_type.value][0]
            if img_prompt.cn_weight is None or img_prompt.cn_weight == 0:
                img_prompt.cn_weight = flags.default_parameters[img_prompt.cn_type.value][1]
            image_prompts.append(
                (cn_img, img_prompt.cn_stop, img_prompt.cn_weight, img_prompt.cn_type.value))

    # Pad with empty placeholder prompts up to the configured count.
    missing_prompts = config.default_controlnet_image_count - len(image_prompts)
    if missing_prompts > 0:
        dp = (None, 0.5, 0.6, 'ImagePrompt')
        image_prompts += [dp] * missing_prompts

    # ---- enhance ------------------------------------------------------------
    if is_enhance_req:
        enhance_checkbox = True
        enhance_input_image = read_input_image(req.enhance_input_image)
        enhance_uov_method = req.enhance_uov_method
        enhance_uov_processing_order = req.enhance_uov_processing_order
        enhance_uov_prompt_type = req.enhance_uov_prompt_type
        save_final_enhanced_image_only = True
        enhance_ctrlnets = req.enhance_ctrlnets
    else:
        enhance_checkbox = False
        enhance_input_image = None
        enhance_uov_method = flags.disabled
        enhance_uov_processing_order = "Before First Enhancement"
        enhance_uov_prompt_type = "Original Prompts"
        save_final_enhanced_image_only = False
        # One independent default object per tab; multiplying a one-element
        # list would put the very same instance into every slot.
        enhance_ctrlnets = [EnhanceCtrlNets() for _ in range(config.default_enhance_tabs)]

    # ---- advanced parameters -----------------------------------------------
    advanced_params = None
    if req.advanced_params is not None:
        adp = req.advanced_params
        if adp.refiner_swap_method not in ('joint', 'separate', 'vae'):
            print(f"[Warning] Wrong refiner_swap_method input: {adp.refiner_swap_method}, using default")
            adp.refiner_swap_method = 'joint'
        if adp.sampler_name not in flags.sampler_list:
            print(f"[Warning] Wrong sampler_name input: {adp.sampler_name}, using default")
            adp.sampler_name = default_sampler
        if adp.scheduler_name not in flags.scheduler_list:
            print(f"[Warning] Wrong scheduler_name input: {adp.scheduler_name}, using default")
            adp.scheduler_name = default_scheduler
        if adp.inpaint_engine not in flags.inpaint_engine_versions:
            print(f"[Warning] Wrong inpaint_engine input: {adp.inpaint_engine}, using default")
            adp.inpaint_engine = default_inpaint_engine_version
        advanced_params = adp

    return ImageGenerationParams(
        prompt=req.prompt,
        negative_prompt=req.negative_prompt,
        style_selections=style_selections,
        performance_selection=req.performance_selection.value,
        aspect_ratios_selection=req.aspect_ratios_selection,
        image_number=req.image_number,
        image_seed=image_seed,
        sharpness=req.sharpness,
        guidance_scale=req.guidance_scale,
        base_model_name=base_model_name,
        refiner_model_name=refiner_model_name,
        refiner_switch=req.refiner_switch,
        loras=loras,
        uov_input_image=uov_input_image,
        uov_method=uov_method,
        upscale_value=upscale_value,
        outpaint_selections=outpaint_selections,
        outpaint_distance_left=outpaint_distance_left,
        outpaint_distance_right=outpaint_distance_right,
        outpaint_distance_top=outpaint_distance_top,
        outpaint_distance_bottom=outpaint_distance_bottom,
        inpaint_input_image=inpaint_input_image,
        inpaint_additional_prompt=inpaint_additional_prompt,
        enhance_input_image=enhance_input_image,
        enhance_checkbox=enhance_checkbox,
        enhance_uov_method=enhance_uov_method,
        enhance_uov_processing_order=enhance_uov_processing_order,
        enhance_uov_prompt_type=enhance_uov_prompt_type,
        save_final_enhanced_image_only=save_final_enhanced_image_only,
        enhance_ctrlnets=enhance_ctrlnets,
        read_wildcards_in_order=req.read_wildcards_in_order,
        image_prompts=image_prompts,
        advanced_params=advanced_params,
        save_meta=req.save_meta,
        meta_scheme=req.meta_scheme,
        save_name=req.save_name,
        save_extension=req.save_extension,
        require_base64=req.require_base64,
    )
def generate_async_output(
        task: QueueTask,
        require_step_preview: bool = False) -> AsyncJobResponse:
    """
    Build the response that describes the current state of an async job.

    Arguments:
        task: QueueTask to report on
        require_step_preview: include the task's step preview when True
    Returns:
        AsyncJobResponse
    """
    # A start time of 0 means "waiting"; any other value means "running".
    stage = AsyncJobStage.waiting if task.start_mills == 0 else AsyncJobStage.running
    images = None

    if task.is_finished:
        if task.finish_with_error:
            stage = AsyncJobStage.error
        elif task.task_result is not None:
            stage = AsyncJobStage.success
            images = generate_image_result_output(task.task_result, task.req_param.require_base64)
        # Finished without error and without result: the stage stays as is.

    preview = task.task_step_preview if require_step_preview else None
    return AsyncJobResponse(
        job_id=task.job_id,
        job_type=task.task_type,
        job_stage=stage,
        job_progress=task.finish_progress,
        job_status=task.task_status,
        job_step_preview=preview,
        job_result=images)
def generate_streaming_output(results: List[ImageGenerationResult]) -> Response:
    """
    Turn image generation results into a single HTTP response.

    Only the first result is looked at. An empty list gives status 500, a
    failed generation gives a status that depends on its finish reason, and
    anything else gives the image bytes.

    Args:
        results (List[ImageGenerationResult]): List of image generation results.
    Returns:
        Response: Streaming response object, bytes image.
    """
    if len(results) == 0:
        return Response(status_code=500)

    first = results[0]
    # Finish reasons that stand for a failure and the status code to send.
    failure_statuses = (
        (GenerationFinishReason.queue_is_full, 409),
        (GenerationFinishReason.user_cancel, 400),
        (GenerationFinishReason.error, 500),
    )
    for reason, status in failure_statuses:
        if first.finish_reason == reason:
            return Response(status_code=status, content=first.finish_reason.value)

    return Response(output_file_to_bytesimg(first.im), media_type='image/png')
def generate_image_result_output(
        results: List[ImageGenerationResult],
        require_base64: bool) -> List[GeneratedImageResult]:
    """
    Convert generation results into API response items.

    Arguments:
        results: List[ImageGenerationResult]
        require_base64: when True the image is embedded as base64, otherwise
            the base64 field is None
    Returns:
        List[GeneratedImageResult], one entry per input result, same order
    """
    output = []
    for item in results:
        encoded = output_file_to_base64img(item.im) if require_base64 else None
        output.append(
            GeneratedImageResult(
                base64=encoded,
                url=get_file_serve_url(item.im),
                seed=str(item.seed),
                finish_reason=item.finish_reason,
            )
        )
    return output