Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
""" Entry for Fooocus API. | |
Use for starting Fooocus API. | |
python main.py --help for more usage | |
@file: main.py | |
@author: Konie | |
@update: 2024-03-22 | |
""" | |
import argparse | |
import os | |
import re | |
import sys | |
from threading import Thread | |
from fooocusapi.utils.logger import logger | |
from fooocusapi.utils.tools import run_pip, check_torch_cuda, requirements_check | |
from fooocus_api_version import version | |
script_path = os.path.dirname(os.path.realpath(__file__))
module_path = os.path.join(script_path, "repositories/Fooocus")
# Make this project and the bundled Fooocus checkout importable; later code
# imports Fooocus packages such as `modules` from module_path.
sys.path.append(script_path)
sys.path.append(module_path)

logger.std_info("[System ARGV] " + str(sys.argv))


def _find_gpu_device_id(argv):
    """
    Return the value passed with ``--gpu-device-id`` in *argv*, or None.

    Both ``--gpu-device-id N`` and ``--gpu-device-id=N`` are recognised. When
    the flag appears more than once the last occurrence wins, matching
    argparse's default behaviour. A trailing flag with no value, or an empty
    ``=`` value, yields None instead of raising.
    """
    device_id = None
    for position, token in enumerate(argv):
        if token == "--gpu-device-id":
            device_id = argv[position + 1] if position + 1 < len(argv) else None
        elif token.startswith("--gpu-device-id="):
            device_id = token.partition("=")[2] or None
    return device_id


# Read straight from sys.argv at import time so CUDA_VISIBLE_DEVICES is set
# as early as possible, before torch is expected to initialise CUDA.
# NOTE(review): assumes none of the imports above initialise CUDA — confirm.
_gpu_device_id = _find_gpu_device_id(sys.argv)
if _gpu_device_id is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(_gpu_device_id)
    logger.std_info(f"[Fooocus] Set device to: {_gpu_device_id}")

# Let ops that MPS does not implement fall back to the CPU, and disable the
# MPS allocator's high-watermark memory limit (a ratio of 0.0 disables it).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

python = sys.executable
default_command_live = True
# NOTE(review): INDEX_URL is read here but not referenced elsewhere in this
# file — confirm whether other modules use it before removing.
index_url = os.environ.get("INDEX_URL", "")
# Matches a requirement spec of the form `name` or `name==version`.
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
def install_dependents(skip: bool = False):
    """
    Make sure the runtime Python dependencies are installed.

    Runs ``pip install -r requirements.txt`` when ``requirements_check()``
    reports a problem. Then installs torch 2.1.0 / torchvision 0.16.0 from the
    index given by ``TORCH_INDEX_URL`` (default: the cu121 wheel index) when
    ``check_torch_cuda()`` fails.

    Args:
        skip: when True, return immediately without any checks or pip calls.
    """
    if skip:
        return

    wheel_index = os.environ.get("TORCH_INDEX_URL", "https://download.pytorch.org/whl/cu121")
    logger.std_info(f"[Fooocus-API] Using torch index URL: {wheel_index}")

    # Only invoke pip for requirements.txt when the check says it is needed.
    requirements_ok = requirements_check()
    if not requirements_ok:
        logger.std_info("[Fooocus-API] Installing requirements.txt...")
        run_pip("install -r requirements.txt", "requirements")

    if check_torch_cuda():
        logger.std_info("[Fooocus-API] PyTorch with CUDA already installed")
        return

    logger.std_info("[Fooocus-API] Installing PyTorch with CUDA support...")
    torch_command = (
        "install torch==2.1.0 torchvision==0.16.0 "
        f"--extra-index-url {wheel_index}"
    )
    run_pip(torch_command, desc="torch")
def preload_pipeline():
    """
    Import the Fooocus default pipeline module ahead of the first request.

    The PyTorch version and CUDA availability (and, when available, the
    current CUDA device) are logged first, so a crash during the import can be
    correlated with the device setup. Any exception is logged and re-raised.

    Raises:
        Exception: whatever the torch or pipeline import raises.
    """
    logger.std_info("[Fooocus-API] Preloading pipeline ...")
    try:
        import torch

        cuda_available = torch.cuda.is_available()
        logger.std_info(
            f"[Fooocus-API] PyTorch version: {torch.__version__}, CUDA available: {cuda_available}"
        )
        if cuda_available:
            device = torch.cuda.current_device()
            # Name the device that is actually current rather than hard-coding
            # index 0, so both logged values refer to the same GPU.
            logger.std_info(
                f"[Fooocus-API] CUDA device: {device}, {torch.cuda.get_device_name(device)}"
            )

        # The import itself is the only preload step performed here; the
        # module object is not used afterwards.
        import modules.default_pipeline  # noqa: F401

        logger.std_info("[Fooocus-API] Pipeline module imported successfully")
    except Exception as e:
        logger.std_error(f"[Fooocus-API] Pipeline preload failed: {str(e)}")
        raise
def _display_host(host):
    """
    Return *host* in a form usable as the host part of an ``http://`` URL.

    Wildcard bind addresses are replaced by the matching loopback address
    (``0.0.0.0`` -> ``127.0.0.1``, ``::`` -> ``::1``), and IPv6 literals are
    wrapped in square brackets as URLs require. Other values pass through.
    """
    if host == "0.0.0.0":
        return "127.0.0.1"
    if host == "::":
        host = "::1"
    if isinstance(host, str) and ":" in host and not host.startswith("["):
        return f"[{host}]"
    return host


def prepare_environments(args) -> bool:
    """
    Prepare the runtime environment for the API server.

    Fills in ``args.base_url`` when it is empty, clears ``sys.argv``, copies
    the Fooocus config defaults into ``fooocusapi.configs.default``, downloads
    models unless ``args.disable_preset_download`` is truthy, and creates the
    worker task queue.

    Args:
        args: command line arguments. This function reads ``base_url``,
            ``host``, ``port``, ``disable_preset_download`` (optional,
            default False), ``queue_size``, ``queue_history``, ``webhook_url``
            and ``persistent``, and may set ``base_url``.

    Returns:
        True once preparation has finished.
    """
    if not args.base_url or not args.base_url.strip():
        args.base_url = f"http://{_display_host(args.host)}:{args.port}"

    # NOTE(review): argv is cleared before the Fooocus `modules` imports
    # below, presumably so Fooocus's own argument parsing does not see the
    # API's command-line flags — confirm.
    sys.argv = [sys.argv[0]]

    origin_preset_folder = os.path.abspath(os.path.join(module_path, "presets"))
    preset_folder = os.path.abspath(os.path.join(script_path, "presets"))
    logger.std_info(f"[Fooocus-API] Origin preset folder: {origin_preset_folder}")
    logger.std_info(f"[Fooocus-API] Local preset folder: {preset_folder}")
    # The preset folders are only logged. Replacing the local folder with a
    # fresh copy of the Fooocus presets (shutil.rmtree + shutil.copytree) is
    # deliberately disabled to avoid permission issues on Hugging Face Spaces.

    from modules import config
    from fooocusapi.configs import default
    from fooocusapi.utils.model_loader import download_models

    # Mirror the Fooocus config defaults into the API's default settings.
    default.default_inpaint_engine_version = config.default_inpaint_engine_version
    default.default_styles = config.default_styles
    default.default_base_model_name = config.default_base_model_name
    default.default_refiner_model_name = config.default_refiner_model_name
    default.default_refiner_switch = config.default_refiner_switch
    default.default_loras = config.default_loras
    default.default_cfg_scale = config.default_cfg_scale
    default.default_prompt_negative = config.default_prompt_negative
    default.default_aspect_ratio = default.get_aspect_ratio_value(config.default_aspect_ratio)
    default.available_aspect_ratios = [
        default.get_aspect_ratio_value(a) for a in config.available_aspect_ratios
    ]

    # Hand-built argument objects (e.g. the one in pre_setup) may not define
    # this attribute; treat a missing value as "download enabled".
    if not getattr(args, "disable_preset_download", False):
        logger.std_info("[Fooocus-API] Downloading models...")
        download_models()
        logger.std_info("[Fooocus-API] Model download completed")

    # Init task queue
    from fooocusapi import worker
    from fooocusapi.task_queue import TaskQueue

    worker.worker_queue = TaskQueue(
        queue_size=args.queue_size,
        history_size=args.queue_history,
        webhook_url=args.webhook_url,
        persistent=args.persistent,
    )
    logger.std_info(f"[Fooocus-API] Task queue size: {args.queue_size}")
    logger.std_info(f"[Fooocus-API] Queue history size: {args.queue_history}")
    logger.std_info(f"[Fooocus-API] Webhook url: {args.webhook_url}")
    logger.std_info(f"[Fooocus-API] Base URL: {args.base_url}")
    return True
def pre_setup():
    """
    Prepare the API without going through the command-line entry point.

    Intended for hosted setups such as Replicate or Hugging Face Spaces. It
    builds a fixed argument object, skips pip installation, prepares the
    environment and starts the task scheduler in a daemon thread.
    """

    class Args(object):
        """
        Fixed argument set used in place of parsed command-line arguments.
        """

        host = "127.0.0.1"
        port = 7860
        base_url = None
        sync_repo = "skip"
        disable_image_log = True
        skip_pip = True
        preload_pipeline = True
        queue_size = 100
        queue_history = 0
        preset = "default"
        webhook_url = None
        persistent = False
        always_gpu = False
        all_in_fp16 = False
        gpu_device_id = None
        apikey = None
        # Read by prepare_environments(); without it that call raises
        # AttributeError. False keeps the model download enabled.
        disable_preset_download = False

    logger.std_info("[Pre Setup] Preparing environments")
    arguments = Args()
    # NOTE(review): prepare_environments() resets sys.argv to [argv[0]]
    # before importing Fooocus modules, which appears to discard this flag
    # — confirm whether it is still needed.
    sys.argv = [sys.argv[0], "--disable-image-log"]

    install_dependents(arguments.skip_pip)
    prepare_environments(arguments)

    from fooocusapi.worker import task_schedule_loop

    task_thread = Thread(target=task_schedule_loop, daemon=True)
    task_thread.start()
    logger.std_info("[Pre Setup] Finished")
if __name__ == "__main__":
    logger.std_info(f"[Fooocus API] Python {sys.version}")
    logger.std_info(f"[Fooocus API] Fooocus API version: {version}")

    from fooocusapi.base_args import add_base_args

    # Early parse of the command line. It is used only before `fooocusapi.args`
    # is imported (for --skip-pip) and to carry the host default over to it.
    parser = argparse.ArgumentParser()
    add_base_args(parser, True)
    parser.set_defaults(host="0.0.0.0")  # Default to 0.0.0.0 for broader access
    cli_args, _ = parser.parse_known_args()

    install_dependents(skip=cli_args.skip_pip)

    from fooocusapi.args import args

    # `set_defaults` above only affects this local parser, not the separate
    # `args` object imported from `fooocusapi.args`. Copy the host across so
    # the 0.0.0.0 default (or an explicit --host) is what the server uses.
    args.host = cli_args.host

    if prepare_environments(args):
        sys.argv = [sys.argv[0]]

        # Preload the pipeline in the background; preload_pipeline logs its
        # own failures before re-raising them inside that thread.
        preload_pipeline_thread = Thread(target=preload_pipeline, daemon=True)
        preload_pipeline_thread.start()

        # Start task schedule thread
        from fooocusapi.worker import task_schedule_loop

        task_schedule_thread = Thread(target=task_schedule_loop, daemon=True)
        task_schedule_thread.start()

        # Start the API server; failures are logged and re-raised.
        from fooocusapi.api import start_app

        try:
            logger.std_info("[Fooocus-API] Starting API server...")
            start_app(args)
        except Exception as e:
            logger.std_error(f"[Fooocus-API] Failed to start API server: {str(e)}")
            raise