Sanket17's picture
added all files
5fbd25d
raw
history blame
8.86 kB
"""Generate API V2 routes
"""
from typing import List
from fastapi import APIRouter, Depends, Header, Query
from fooocusapi.models.common.base import EnhanceCtrlNets, GenerateMaskRequest
from fooocusapi.utils.api_utils import api_key_auth
from fooocusapi.models.requests_v1 import ImagePrompt
from fooocusapi.models.requests_v2 import (
ImageEnhanceRequestJson, ImgInpaintOrOutpaintRequestJson,
ImgPromptRequestJson,
Text2ImgRequestWithPrompt,
ImgUpscaleOrVaryRequestJson
)
from fooocusapi.models.common.response import (
AsyncJobResponse,
GeneratedImageResult
)
from fooocusapi.utils.call_worker import (
call_worker,
generate_mask as gm
)
from fooocusapi.utils.img_utils import base64_to_stream
from fooocusapi.configs.default import img_generate_responses
# Router for the generation endpoints below; api_key_auth is attached as a
# router-level dependency, so it applies to every route registered on it.
secure_router = APIRouter(dependencies=[Depends(api_key_auth)])
@secure_router.post(
    path="/v2/generation/text-to-image-with-ip",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV2"])
def text_to_img_with_ip(
        req: Text2ImgRequestWithPrompt,
        accept: str = Header(None),
        accept_query: str | None = Query(
            default=None, alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
    """\nText to image with prompt\n
    Text to image with prompt
    Arguments:
        req {Text2ImgRequestWithPrompt} -- Text to image generation request
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # A non-empty `accept` query parameter takes precedence over the header.
    has_override = accept_query is not None and len(accept_query) > 0
    if has_override:
        accept = accept_query

    # Single placeholder instance, reused for every padded slot below.
    placeholder = ImagePrompt(cn_img=None)

    converted: List[ImagePrompt] = []
    for source in req.image_prompts:
        # The incoming prompt is updated in place first, then copied into a
        # fresh ImagePrompt carrying the converted image.
        source.cn_img = base64_to_stream(source.cn_img)
        converted.append(ImagePrompt(
            cn_img=source.cn_img,
            cn_stop=source.cn_stop,
            cn_weight=source.cn_weight,
            cn_type=source.cn_type))

    # Pad until the list holds five entries; a non-positive multiplier
    # produces an empty list, so longer lists are left untouched.
    converted.extend([placeholder] * (5 - len(converted)))

    req.image_prompts = converted
    return call_worker(req, accept)
@secure_router.post(
    path="/v2/generation/image-upscale-vary",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV2"])
def img_upscale_or_vary(
        req: ImgUpscaleOrVaryRequestJson,
        accept: str = Header(None),
        accept_query: str | None = Query(
            None, alias='accept', description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
    """\nImage upscale or vary\n
    Image upscale or vary
    Arguments:
        req {ImgUpscaleOrVaryRequestJson} -- Image upscale or vary request
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # A non-empty `accept` query parameter takes precedence over the header.
    has_override = accept_query is not None and len(accept_query) > 0
    if has_override:
        accept = accept_query

    # The input image is always converted for this endpoint (no None check).
    req.input_image = base64_to_stream(req.input_image)

    # Single placeholder instance, reused for every padded slot below.
    placeholder = ImagePrompt(cn_img=None)

    converted: List[ImagePrompt] = []
    for source in req.image_prompts:
        # The incoming prompt is updated in place first, then copied into a
        # fresh ImagePrompt carrying the converted image.
        source.cn_img = base64_to_stream(source.cn_img)
        converted.append(ImagePrompt(
            cn_img=source.cn_img,
            cn_stop=source.cn_stop,
            cn_weight=source.cn_weight,
            cn_type=source.cn_type))

    # Pad until the list holds five entries; a non-positive multiplier
    # produces an empty list, so longer lists are left untouched.
    converted.extend([placeholder] * (5 - len(converted)))

    req.image_prompts = converted
    return call_worker(req, accept)
@secure_router.post(
    path="/v2/generation/image-inpaint-outpaint",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV2"])
def img_inpaint_or_outpaint(
        req: ImgInpaintOrOutpaintRequestJson,
        accept: str = Header(None),
        accept_query: str | None = Query(
            None, alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
    """\nInpaint or outpaint\n
    Inpaint or outpaint
    Arguments:
        req {ImgInpaintOrOutpaintRequestJson} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # A non-empty `accept` query parameter takes precedence over the header.
    has_override = accept_query is not None and len(accept_query) > 0
    if has_override:
        accept = accept_query

    # The input image is always converted; the mask only when one was sent.
    req.input_image = base64_to_stream(req.input_image)
    mask = req.input_mask
    if mask is not None:
        req.input_mask = base64_to_stream(mask)

    # Single placeholder instance, reused for every padded slot below.
    placeholder = ImagePrompt(cn_img=None)

    converted: List[ImagePrompt] = []
    for source in req.image_prompts:
        # The incoming prompt is updated in place first, then copied into a
        # fresh ImagePrompt carrying the converted image.
        source.cn_img = base64_to_stream(source.cn_img)
        converted.append(ImagePrompt(
            cn_img=source.cn_img,
            cn_stop=source.cn_stop,
            cn_weight=source.cn_weight,
            cn_type=source.cn_type))

    # Pad until the list holds five entries; a non-positive multiplier
    # produces an empty list, so longer lists are left untouched.
    converted.extend([placeholder] * (5 - len(converted)))

    req.image_prompts = converted
    return call_worker(req, accept)
@secure_router.post(
    path="/v2/generation/image-prompt",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV2"])
def img_prompt(
        req: ImgPromptRequestJson,
        accept: str = Header(None),
        accept_query: str | None = Query(
            None, alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
    """\nImage prompt\n
    Image prompt generation
    Arguments:
        req {ImgPromptRequestJson} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # A non-empty `accept` query parameter takes precedence over the header.
    has_override = accept_query is not None and len(accept_query) > 0
    if has_override:
        accept = accept_query

    # Both the input image and the mask are optional on this endpoint and
    # are converted only when present.
    source_image = req.input_image
    if source_image is not None:
        req.input_image = base64_to_stream(source_image)
    mask = req.input_mask
    if mask is not None:
        req.input_mask = base64_to_stream(mask)

    # Single placeholder instance, reused for every padded slot below.
    placeholder = ImagePrompt(cn_img=None)

    converted: List[ImagePrompt] = []
    for source in req.image_prompts:
        # The incoming prompt is updated in place first, then copied into a
        # fresh ImagePrompt carrying the converted image.
        source.cn_img = base64_to_stream(source.cn_img)
        converted.append(ImagePrompt(
            cn_img=source.cn_img,
            cn_stop=source.cn_stop,
            cn_weight=source.cn_weight,
            cn_type=source.cn_type))

    # Pad until the list holds five entries; a non-positive multiplier
    # produces an empty list, so longer lists are left untouched.
    converted.extend([placeholder] * (5 - len(converted)))

    req.image_prompts = converted
    return call_worker(req, accept)
@secure_router.post(
    path="/v2/generation/image-enhance",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV2"])
def img_enhance(
        req: ImageEnhanceRequestJson,
        accept: str = Header(None),
        accept_query: str | None = Query(
            None, alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
    """\nImage enhance\n
    Image enhance generation
    Arguments:
        req {ImageEnhanceRequestJson} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # A non-empty `accept` query parameter takes precedence over the header.
    if accept_query is not None and len(accept_query) > 0:
        accept = accept_query

    # The image to enhance is optional; convert it only when provided.
    if req.enhance_input_image is not None:
        req.enhance_input_image = base64_to_stream(req.enhance_input_image)

    if len(req.enhance_ctrlnets) < 3:
        # List concatenation returns a new list, so the padded result has to
        # be assigned back to the request; as a bare expression statement it
        # was computed and thrown away, leaving the request unpadded.
        # NOTE(review): the guard (< 3) and the pad target (4) are kept as
        # written -- confirm the intended number of enhance control nets.
        missing = 4 - len(req.enhance_ctrlnets)
        # One fresh default per slot, so padded entries do not alias a
        # single shared instance.
        padding = [EnhanceCtrlNets() for _ in range(missing)]
        req.enhance_ctrlnets = req.enhance_ctrlnets + padding

    return call_worker(req, accept)
@secure_router.post(
    path="/v1/tools/generate_mask",
    summary="Generate mask endpoint",
    tags=["GenerateV1"])
async def generate_mask(mask_options: GenerateMaskRequest) -> str:
    """
    Generate mask endpoint

    Awaits the generate_mask helper from call_worker (imported as `gm`)
    with the request body and returns its result unchanged.
    """
    mask = await gm(request=mask_options)
    return mask