Sanket17's picture
added all files
5fbd25d
"""Common models"""
from typing import List, Tuple
from enum import Enum
from fastapi import UploadFile
from fastapi.exceptions import RequestValidationError
from pydantic import (
ValidationError,
ConfigDict,
BaseModel,
TypeAdapter,
Field
)
from pydantic_core import InitErrorDetails
from fooocusapi.configs.default import default_loras
class PerformanceSelection(str, Enum):
    """Performance preset choices.

    Because the enum also subclasses ``str``, each member compares equal to
    its string value, e.g. ``PerformanceSelection.speed == "Speed"``.
    """
    speed = "Speed"
    quality = "Quality"
    extreme_speed = "Extreme Speed"
    lightning = "Lightning"
    hyper_sd = "Hyper-SD"
class Lora(BaseModel):
    """One LoRA entry: an on/off flag, the model file name and a weight.

    ``weight`` defaults to 0.5 and is constrained to the range [-2, 2].
    """
    # pydantic v2 treats the ``model_`` prefix as a protected namespace and
    # warns about a field called ``model_name``; replacing the protected
    # namespaces with unrelated prefixes keeps that field name warning-free.
    model_config = ConfigDict(
        protected_namespaces=('protect_me_', 'also_protect_'),
    )

    enabled: bool
    model_name: str
    weight: float = Field(le=2, ge=-2, default=0.5)
# Validator/serializer for a JSON array of Lora objects.
LoraList = TypeAdapter(List[Lora])

# Build Lora models from the configured defaults. Each entry is indexed as
# (enabled, model_name, weight). Entries whose model name is the literal
# string 'None' are skipped (presumably empty placeholder slots).
# The check must be on index 1 (model_name); index 0 is the enabled flag,
# which is never equal to the string 'None'.
default_loras_model = [
    Lora(
        enabled=lora[0],
        model_name=lora[1],
        weight=lora[2])
    for lora in default_loras
    if lora[1] != 'None'
]

# JSON bytes of the default lora list.
default_loras_json = LoraList.dump_json(default_loras_model)
class MaskModel(str, Enum):
    """Models that can be selected to produce an inpaint mask.

    Members subclass ``str``, so they compare equal to their values.
    Note that several values use hyphens while the member names use
    underscores (e.g. ``isnet_general_use`` -> ``"isnet-general-use"``).
    """
    u2net = 'u2net'
    u2netp = 'u2netp'
    u2net_human_seg = 'u2net_human_seg'
    u2net_cloth_seg = 'u2net_cloth_seg'
    silueta = 'silueta'
    isnet_general_use = 'isnet-general-use'
    isnet_anime = 'isnet-anime'
    sam = 'sam'
class EnhanceCtrlNets(BaseModel):
    """Settings for a single enhance pass.

    Fields prefixed ``enhance_mask_`` control how the mask is produced;
    fields prefixed ``enhance_inpaint_`` control the inpaint step.
    Numeric fields carry ``ge``/``le`` bounds that pydantic enforces.
    """
    # Switch and prompts for this pass.
    enhance_enabled: bool = Field(
        description="Enable enhance control nets",
        default=False,
    )
    enhance_mask_dino_prompt: str = Field(
        description="Mask dino prompt, this is necessary, error if no value. usual values: face, eye, mouth, hair, hand, body",
        default="face",
    )
    enhance_prompt: str = Field(description="Prompt", default="")
    enhance_negative_prompt: str = Field(description="Negative prompt", default="")

    # Mask generation.
    enhance_mask_model: MaskModel = Field(
        description="Mask model",
        default=MaskModel.sam,
    )
    enhance_mask_cloth_category: str = Field(
        description="Mask cloth category",
        default="full",
    )
    enhance_mask_sam_model: str = Field(
        description="one of vit_b vit_h vit_l",
        default="vit_b",
    )
    enhance_mask_text_threshold: float = Field(
        description="Mask text threshold",
        default=0.25, ge=0, le=1,
    )
    enhance_mask_box_threshold: float = Field(
        description="Mask box threshold",
        default=0.3, ge=0, le=1,
    )
    enhance_mask_sam_max_detections: int = Field(
        description="Mask sam max detections, Set to 0 to detect all",
        default=0, ge=0, le=10,
    )

    # Inpaint step.
    enhance_inpaint_disable_initial_latent: bool = Field(
        description="Inpaint disable initial latent",
        default=False,
    )
    enhance_inpaint_engine: str = Field(
        description="Inpaint engine",
        default="v2.6",
    )
    enhance_inpaint_strength: float = Field(
        description="Inpaint strength",
        default=1, ge=0, le=1,
    )
    enhance_inpaint_respective_field: float = Field(
        description="Inpaint respective field",
        default=0.618, ge=0, le=1,
    )
    enhance_inpaint_erode_or_dilate: float = Field(
        description="Inpaint erode or dilate",
        default=0, ge=-64, le=64,
    )

    # Mask inversion (declared last to keep the original field order).
    enhance_mask_invert: bool = Field(
        description="Inpaint mask invert",
        default=False,
    )
class GenerateMaskRequest(BaseModel):
    """Parameters for a generate-mask request.

    ``image`` is the only required field; every other field has a default.
    Numeric fields carry ``ge``/``le`` bounds that pydantic enforces.
    """
    image: str = Field(description="Image url or base64")
    mask_model: MaskModel = Field(
        description="Mask model",
        default=MaskModel.isnet_general_use,
    )
    cloth_category: str = Field(description="Mask cloth category", default="full")
    dino_prompt_text: str = Field(
        description="Detection prompt, Use singular whenever possible",
        default="",
    )
    sam_model: str = Field(description="one of vit_b vit_h vit_l", default="vit_b")
    box_threshold: float = Field(
        description="Mask box threshold",
        default=0.3, ge=0, le=1,
    )
    text_threshold: float = Field(
        description="Mask text threshold",
        default=0.25, ge=0, le=1,
    )
    sam_max_detections: int = Field(
        description="Mask sam max detections, Set to 0 to detect all",
        default=0, ge=0, le=10,
    )
    dino_erode_or_dilate: float = Field(
        description="Mask dino erode or dilate",
        default=0, ge=-64, le=64,
    )
    dino_debug: bool = Field(description="Mask dino debug", default=False)
class UpscaleOrVaryMethod(str, Enum):
    """Choices for the upscale / vary operation.

    Values are the exact strings accepted; members compare equal to them.
    """
    disabled = "Disabled"
    # Variation strengths.
    subtle_variation = "Vary (Subtle)"
    strong_variation = "Vary (Strong)"
    # Upscale factors.
    upscale_15 = "Upscale (1.5x)"
    upscale_2 = "Upscale (2x)"
    upscale_fast = "Upscale (Fast 2x)"
    upscale_custom = "Upscale (Custom)"
class OutpaintExpansion(str, Enum):
    """Sides an image can be outpainted towards.

    Lookup by value is case sensitive: ``OutpaintExpansion("Left")`` works,
    ``OutpaintExpansion("left")`` raises ``ValueError``.
    """
    left = "Left"
    right = "Right"
    top = "Top"
    bottom = "Bottom"
class ControlNetType(str, Enum):
    """ControlNet kinds selectable for an image prompt.

    Member names carry a ``cn_`` prefix; values are the accepted strings.
    """
    cn_ip = 'ImagePrompt'
    cn_ip_face = 'FaceSwap'
    cn_canny = 'PyraCanny'
    cn_cpds = 'CPDS'
class ImagePrompt(BaseModel):
    """An image prompt: an optional uploaded image plus its ControlNet settings.

    ``cn_stop`` is bounded to [0, 1] and ``cn_weight`` to [0, 2]; both may
    also be None. ``cn_type`` defaults to ``ControlNetType.cn_ip``.
    """
    # Uploaded image; absent by default.
    cn_img: UploadFile | None = None
    cn_stop: float | None = Field(ge=0, le=1, default=0.5)
    cn_weight: float | None = Field(
        description="None for default value",
        ge=0, le=2, default=0.6,
    )
    cn_type: ControlNetType = Field(default=ControlNetType.cn_ip)
class DescribeImageType(str, Enum):
    """Kind of image passed to image-to-prompt (describe)."""
    photo = "Photo"
    anime = "Anime"
class ImageMetaScheme(str, Enum):
    """How image metadata is serialized when images are saved.

    Attributes:
        Fooocus: metadata stored as JSON
        A111: metadata stored as a plain string
    """
    Fooocus = "fooocus"
    A111 = "a111"
def style_selection_parser(style_selections: str | List[str]) -> List[str]:
"""
Parse style selections, Convert to list
Args:
style_selections: str, comma separated Fooocus style selections
e.g. Fooocus V2, Fooocus Enhance, Fooocus Sharp
Returns:
List[str]
"""
style_selection_arr: List[str] = []
if style_selections is None or len(style_selections) == 0:
return []
for part in style_selections:
if len(part) > 0:
for s in part.split(','):
style = s.strip()
style_selection_arr.append(style)
return style_selection_arr
def lora_parser(loras: str) -> List[Lora]:
    """
    Parse lora config, Convert to list
    Args:
        loras: a JSON array of lora objects, e.g.
            '[{"enabled": true, "model_name": "x.safetensors", "weight": 0.5}]'
    Returns:
        List[Lora]; an empty list when ``loras`` is None or empty
    Raises:
        RequestValidationError: when ``loras`` is not valid JSON or does not
            match the Lora schema. The pydantic error list is passed through
            as the exception's errors.
    """
    if not loras:
        return []
    try:
        return LoraList.validate_json(loras)
    except ValidationError as ve:
        # RequestValidationError requires the error sequence as its argument.
        # The cause after ``from`` must be an exception, so chain the
        # ValidationError itself rather than its error list.
        raise RequestValidationError(errors=ve.errors()) from ve
def outpaint_selections_parser(outpaint_selections: str | list[str]) -> List[OutpaintExpansion]:
    """
    Parse outpaint selections, Convert to list
    Args:
        outpaint_selections: a comma separated string of Left, Right, Top,
            Bottom, or a list of such strings.
            e.g. "Left, Right, Top, Bottom" or ["Left", "Top,Bottom"]
            Whitespace around each name and empty entries are ignored;
            names are matched case sensitively.
    Returns:
        List[OutpaintExpansion]; empty when the input is None or empty
    Raises:
        RequestValidationError: when an entry is not an OutpaintExpansion value
    """
    if not outpaint_selections:
        return []
    # A bare string must be wrapped; iterating it directly would yield
    # single characters instead of comma separated names.
    if isinstance(outpaint_selections, str):
        parts = [outpaint_selections]
    else:
        parts = outpaint_selections

    outpaint_selections_arr: List[OutpaintExpansion] = []
    for part in parts:
        for token in part.split(','):
            name = token.strip()
            if not name:
                continue
            try:
                outpaint_selections_arr.append(OutpaintExpansion(name))
            except ValueError as exc:
                err = InitErrorDetails(
                    type='enum',
                    # One-element tuple: tuple('outpaint_selections') would
                    # split the field name into characters.
                    loc=('outpaint_selections',),
                    input=outpaint_selections,
                    ctx={
                        'expected': "str, comma separated Left, Right, Top, Bottom"
                    })
                # RequestValidationError takes a sequence of error dicts; the
                # cause after ``from`` must be an exception, not the dict.
                raise RequestValidationError(errors=[err]) from exc
    return outpaint_selections_arr
def image_prompt_parser(image_prompts_config: List[Tuple]) -> List[ImagePrompt]:
    """
    Build ImagePrompt models from raw configuration tuples.
    Args:
        image_prompts_config: List[Tuple]; each tuple is unpacked positionally
            as (cn_img, cn_stop, cn_weight, cn_type).
            e.g. [(upload_file, 0.5, 0.6, 'ImagePrompt'), (None, 0.5, 0.6, 'FaceSwap')]
    Returns:
        List[ImagePrompt], in input order; empty when the config is None or empty
    """
    if image_prompts_config is None or len(image_prompts_config) == 0:
        return []
    return [
        ImagePrompt(
            cn_img=img,
            cn_stop=stop_at,
            cn_weight=weight,
            cn_type=kind,
        )
        for img, stop_at, weight, kind in image_prompts_config
    ]