Spaces:
Runtime error
Runtime error
"""Common models""" | |
from typing import List, Tuple | |
from enum import Enum | |
from fastapi import UploadFile | |
from fastapi.exceptions import RequestValidationError | |
from pydantic import ( | |
ValidationError, | |
ConfigDict, | |
BaseModel, | |
TypeAdapter, | |
Field | |
) | |
from pydantic_core import InitErrorDetails | |
from fooocusapi.configs.default import default_loras | |
class PerformanceSelection(str, Enum):
    """Performance selection"""
    # Mixes in `str`, so each member compares equal to its value string
    # (e.g. PerformanceSelection.speed == 'Speed').
    speed = 'Speed'
    quality = 'Quality'
    extreme_speed = 'Extreme Speed'
    lightning = 'Lightning'
    hyper_sd = 'Hyper-SD'
class Lora(BaseModel):
    """Common params lora model"""
    enabled: bool
    model_name: str
    # Optional; defaults to 0.5 and is validated to lie in [-2, 2].
    weight: float = Field(default=0.5, ge=-2, le=2)
    # NOTE(review): presumably overridden so that the `model_name` field does
    # not clash with pydantic's default protected "model_" namespace. The two
    # prefixes below look like placeholders copied from the pydantic docs
    # example; `protected_namespaces=()` would express the intent directly —
    # confirm before changing.
    model_config = ConfigDict(
        protected_namespaces=('protect_me_', 'also_protect_')
    )
# Validator/serializer for a JSON list of Lora objects (used by lora_parser
# and to serialize the defaults below).
LoraList = TypeAdapter(List[Lora])

# Convert the configured default loras into Lora models. Each config entry is
# read positionally as (enabled, model_name, weight). Entries whose model name
# is the placeholder 'None' are skipped. The filter must look at index 1
# (model_name); index 0 is the `enabled` flag, which never equals 'None'.
default_loras_model = [
    Lora(enabled=lora[0], model_name=lora[1], weight=lora[2])
    for lora in default_loras
    if lora[1] != 'None'
]

# JSON (bytes) form of the defaults.
default_loras_json = LoraList.dump_json(default_loras_model)
class MaskModel(str, Enum):
    """Inpaint mask model"""
    # NOTE(review): the u2net/silueta/isnet values look like background-removal
    # segmentation model names and `sam` like Segment Anything — confirm
    # against the mask-generation code that consumes them.
    u2net = "u2net"
    u2netp = "u2netp"
    u2net_human_seg = "u2net_human_seg"
    u2net_cloth_seg = "u2net_cloth_seg"
    silueta = "silueta"
    isnet_general_use = "isnet-general-use"
    isnet_anime = "isnet-anime"
    sam = "sam"
class EnhanceCtrlNets(BaseModel):
    """Enhance control net params: mask detection (dino / sam) and inpaint settings"""
    enhance_enabled: bool = Field(default=False, description="Enable enhance control nets")
    # Detection prompt; defaults to "face".
    enhance_mask_dino_prompt: str = Field(default="face", description="Mask dino prompt, this is necessary, error if no value. usual values: face, eye, mouth, hair, hand, body")
    enhance_prompt: str = Field(default="", description="Prompt")
    enhance_negative_prompt: str = Field(default="", description="Negative prompt")
    enhance_mask_model: MaskModel = Field(default=MaskModel.sam, description="Mask model")
    enhance_mask_cloth_category: str = Field(default="full", description="Mask cloth category")
    # Free-form string; the description lists vit_b / vit_h / vit_l, but no
    # validation restricts it to those values.
    enhance_mask_sam_model: str = Field(default="vit_b", description="one of vit_b vit_h vit_l")
    enhance_mask_text_threshold: float = Field(default=0.25, ge=0, le=1, description="Mask text threshold")
    enhance_mask_box_threshold: float = Field(default=0.3, ge=0, le=1, description="Mask box threshold")
    # 0 means "detect all"; otherwise capped at 10.
    enhance_mask_sam_max_detections: int = Field(default=0, ge=0, le=10, description="Mask sam max detections, Set to 0 to detect all")
    enhance_inpaint_disable_initial_latent: bool = Field(default=False, description="Inpaint disable initial latent")
    enhance_inpaint_engine: str = Field(default="v2.6", description="Inpaint engine")
    enhance_inpaint_strength: float = Field(default=1, ge=0, le=1, description="Inpaint strength")
    enhance_inpaint_respective_field: float = Field(default=0.618, ge=0, le=1, description="Inpaint respective field")
    # Signed amount in [-64, 64]; the sign presumably selects erode vs dilate — confirm.
    enhance_inpaint_erode_or_dilate: float = Field(default=0, ge=-64, le=64, description="Inpaint erode or dilate")
    enhance_mask_invert: bool = Field(default=False, description="Inpaint mask invert")
class GenerateMaskRequest(BaseModel):
    """
    generate mask request
    """
    # The only required field: an image URL or base64 payload.
    image: str = Field(description="Image url or base64")
    mask_model: MaskModel = Field(default=MaskModel.isnet_general_use, description="Mask model")
    cloth_category: str = Field(default="full", description="Mask cloth category")
    dino_prompt_text: str = Field(default="", description="Detection prompt, Use singular whenever possible")
    # Free-form string; not validated against the listed values.
    sam_model: str = Field(default="vit_b", description="one of vit_b vit_h vit_l")
    box_threshold: float = Field(default=0.3, ge=0, le=1, description="Mask box threshold")
    text_threshold: float = Field(default=0.25, ge=0, le=1, description="Mask text threshold")
    # 0 means "detect all"; otherwise capped at 10.
    sam_max_detections: int = Field(default=0, ge=0, le=10, description="Mask sam max detections, Set to 0 to detect all")
    dino_erode_or_dilate: float = Field(default=0, ge=-64, le=64, description="Mask dino erode or dilate")
    dino_debug: bool = Field(default=False, description="Mask dino debug")
class UpscaleOrVaryMethod(str, Enum):
    """Upscale or Vary method"""
    # Values are human-readable labels (e.g. 'Vary (Subtle)'), not identifiers;
    # requests must send these exact strings.
    disabled = 'Disabled'
    subtle_variation = 'Vary (Subtle)'
    strong_variation = 'Vary (Strong)'
    upscale_15 = 'Upscale (1.5x)'
    upscale_2 = 'Upscale (2x)'
    upscale_fast = 'Upscale (Fast 2x)'
    upscale_custom = 'Upscale (Custom)'
class OutpaintExpansion(str, Enum):
    """Outpaint expansion"""
    # Case-sensitive values; outpaint_selections_parser converts each
    # comma-separated piece of its input with OutpaintExpansion(piece).
    left = 'Left'
    right = 'Right'
    top = 'Top'
    bottom = 'Bottom'
class ControlNetType(str, Enum):
    """ControlNet Type"""
    # Used as ImagePrompt.cn_type; cn_ip ("ImagePrompt") is the default there.
    cn_ip = "ImagePrompt"
    cn_ip_face = "FaceSwap"
    cn_canny = "PyraCanny"
    cn_cpds = "CPDS"
class ImagePrompt(BaseModel):
    """Common params object ImagePrompt"""
    # Uploaded image for this control net; None when no image was supplied.
    cn_img: UploadFile | None = Field(default=None)
    # Defaults to 0.5 when omitted, but an explicit None is also accepted.
    cn_stop: float | None = Field(default=0.5, ge=0, le=1)
    # Defaults to 0.6 when omitted; an explicit None means "use the default"
    # per the field description (resolution happens elsewhere).
    cn_weight: float | None = Field(default=0.6, ge=0, le=2, description="None for default value")
    cn_type: ControlNetType = Field(default=ControlNetType.cn_ip)
class DescribeImageType(str, Enum):
    """Image type for image to prompt"""
    # Values are capitalized display strings; compare against 'Photo' / 'Anime'.
    photo = 'Photo'
    anime = 'Anime'
class ImageMetaScheme(str, Enum):
    """Scheme for save image meta
    Attributes:
        Fooocus: json format
        A111: string
    """
    # Lower-case values are what requests send ('fooocus' / 'a111').
    Fooocus = 'fooocus'
    A111 = 'a111'
def style_selection_parser(style_selections: str | List[str]) -> List[str]: | |
""" | |
Parse style selections, Convert to list | |
Args: | |
style_selections: str, comma separated Fooocus style selections | |
e.g. Fooocus V2, Fooocus Enhance, Fooocus Sharp | |
Returns: | |
List[str] | |
""" | |
style_selection_arr: List[str] = [] | |
if style_selections is None or len(style_selections) == 0: | |
return [] | |
for part in style_selections: | |
if len(part) > 0: | |
for s in part.split(','): | |
style = s.strip() | |
style_selection_arr.append(style) | |
return style_selection_arr | |
def lora_parser(loras: str) -> List[Lora]:
    """
    Parse lora config, Convert to list
    Args:
        loras: a json string for loras, e.g.
            '[{"enabled": true, "model_name": "x.safetensors", "weight": 0.5}]'
    Returns:
        List[Lora]; an empty list when loras is None or empty
    Raises:
        RequestValidationError: if the JSON does not validate as List[Lora]
    """
    if not loras:
        return []
    try:
        return LoraList.validate_json(loras)
    except ValidationError as ve:
        # RequestValidationError must be constructed with the error list.
        # Raising the bare class (missing `errors`), or using a list as the
        # `from` cause, turns a client input error into a TypeError.
        raise RequestValidationError(ve.errors()) from ve
def outpaint_selections_parser(outpaint_selections: str | list[str]) -> List[OutpaintExpansion]:
    """
    Parse outpaint selections, Convert to list
    Args:
        outpaint_selections: str or list[str]; each entry is a comma separated
            list of Left, Right, Top, Bottom (surrounding whitespace ignored)
            e.g. "Left, Right, Top, Bottom" or ["Left", "Top, Bottom"]
    Returns:
        List[OutpaintExpansion] in input order; None or an empty input gives []
    Raises:
        RequestValidationError: if any piece is not a valid OutpaintExpansion value
    """
    if not outpaint_selections:
        return []
    # A bare string is one comma-separated entry. Iterating it directly would
    # yield single characters, none of which is a valid expansion.
    if isinstance(outpaint_selections, str):
        parts = [outpaint_selections]
    else:
        parts = outpaint_selections

    outpaint_selections_arr: List[OutpaintExpansion] = []
    for part in parts:
        if not part:
            continue
        for s in part.split(','):
            # "Left, Right" splits into "Left" and " Right"; strip before
            # converting and skip empty pieces left by stray commas.
            value = s.strip()
            if not value:
                continue
            try:
                outpaint_selections_arr.append(OutpaintExpansion(value))
            except ValueError as exc:
                err_detail = InitErrorDetails(
                    type='enum',
                    loc=('outpaint_selections',),
                    input=outpaint_selections,
                    ctx={
                        'expected': "str, comma separated Left, Right, Top, Bottom"
                    })
                raise RequestValidationError([err_detail]) from exc
    return outpaint_selections_arr
def image_prompt_parser(image_prompts_config: List[Tuple]) -> List[ImagePrompt]:
    """
    Build ImagePrompt models from positional config tuples.
    Args:
        image_prompts_config: List[Tuple]; each tuple is
            (cn_img, cn_stop, cn_weight, cn_type), e.g.
            (upload_file, 0.5, 0.6, 'ImagePrompt')
    Returns:
        List[ImagePrompt], one per tuple in input order; [] when the input
        is None or empty
    Raises:
        ValueError: if a tuple does not have exactly four items
    """
    if image_prompts_config is None or len(image_prompts_config) == 0:
        return []
    return [
        ImagePrompt(cn_img=img, cn_stop=stop, cn_weight=weight, cn_type=kind)
        for img, stop, weight, kind in image_prompts_config
    ]