banao-tech commited on
Commit
c7bb665
Β·
verified Β·
1 Parent(s): 0958bb1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +55 -30
main.py CHANGED
@@ -50,15 +50,37 @@ def install_dependents(skip: bool = False):
50
  return
51
 
52
  torch_index_url = os.environ.get("TORCH_INDEX_URL", "https://download.pytorch.org/whl/cu121")
 
 
 
53
  if not requirements_check():
 
54
  run_pip("install -r requirements.txt", "requirements")
 
55
  if not check_torch_cuda():
56
- run_pip(f"install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}", desc="torch")
 
 
 
 
 
 
57
 
58
  def preload_pipeline():
59
- """Preload pipeline"""
60
  logger.std_info("[Fooocus-API] Preloading pipeline ...")
61
- import modules.default_pipeline as _
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def prepare_environments(args) -> bool:
64
  """
@@ -69,11 +91,21 @@ def prepare_environments(args) -> bool:
69
  if args.base_url is None or len(args.base_url.strip()) == 0:
70
  host = args.host
71
  if host == "0.0.0.0":
72
- host = "127.0.0.1" # For base_url display, use localhost
73
  args.base_url = f"http://{host}:{args.port}"
74
 
75
  sys.argv = [sys.argv[0]]
76
 
 
 
 
 
 
 
 
 
 
 
77
  from modules import config
78
  from fooocusapi.configs import default
79
  from fooocusapi.utils.model_loader import download_models
@@ -90,8 +122,11 @@ def prepare_environments(args) -> bool:
90
  default.available_aspect_ratios = [default.get_aspect_ratio_value(a) for a in config.available_aspect_ratios]
91
 
92
  if not args.disable_preset_download:
 
93
  download_models()
 
94
 
 
95
  from fooocusapi import worker
96
  from fooocusapi.task_queue import TaskQueue
97
 
@@ -110,8 +145,13 @@ def prepare_environments(args) -> bool:
110
  return True
111
 
112
  def pre_setup():
113
- """Pre setup, for replicate"""
 
 
114
  class Args(object):
 
 
 
115
  host = "127.0.0.1"
116
  port = 7860
117
  base_url = None
@@ -129,7 +169,7 @@ def pre_setup():
129
  gpu_device_id = None
130
  apikey = None
131
 
132
- print("[Pre Setup] Prepare environments")
133
  arguments = Args()
134
  sys.argv = [sys.argv[0]]
135
  sys.argv.append("--disable-image-log")
@@ -141,7 +181,7 @@ def pre_setup():
141
  task_thread = Thread(target=task_schedule_loop, daemon=True)
142
  task_thread.start()
143
 
144
- print("[Pre Setup] Finished")
145
 
146
  if __name__ == "__main__":
147
  logger.std_info(f"[Fooocus API] Python {sys.version}")
@@ -161,7 +201,7 @@ if __name__ == "__main__":
161
  if prepare_environments(args):
162
  sys.argv = [sys.argv[0]]
163
 
164
- # Load pipeline in new thread
165
  preload_pipeline_thread = Thread(target=preload_pipeline, daemon=True)
166
  preload_pipeline_thread.start()
167
 
@@ -170,26 +210,11 @@ if __name__ == "__main__":
170
  task_schedule_thread = Thread(target=task_schedule_loop, daemon=True)
171
  task_schedule_thread.start()
172
 
173
- # Start API server with CORS
174
  from fooocusapi.api import start_app
175
- from fastapi import FastAPI
176
- from fastapi.middleware.cors import CORSMiddleware
177
-
178
- app = FastAPI()
179
-
180
- # Add CORS middleware
181
- app.add_middleware(
182
- CORSMiddleware,
183
- allow_origins=["*"], # Allow all origins (adjust for production)
184
- allow_credentials=True,
185
- allow_methods=["*"],
186
- allow_headers=["*"],
187
- )
188
-
189
- # Add config endpoint to expose base URL
190
- @app.get("/config")
191
- async def get_config():
192
- return {"base_url": args.base_url}
193
-
194
- # Mount the existing app routes
195
- start_app(args, app=app)
 
50
  return
51
 
52
  torch_index_url = os.environ.get("TORCH_INDEX_URL", "https://download.pytorch.org/whl/cu121")
53
+ logger.std_info(f"[Fooocus-API] Using torch index URL: {torch_index_url}")
54
+
55
+ # Check if you need pip install
56
  if not requirements_check():
57
+ logger.std_info("[Fooocus-API] Installing requirements.txt...")
58
  run_pip("install -r requirements.txt", "requirements")
59
+
60
  if not check_torch_cuda():
61
+ logger.std_info("[Fooocus-API] Installing PyTorch with CUDA support...")
62
+ run_pip(
63
+ f"install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}",
64
+ desc="torch",
65
+ )
66
+ else:
67
+ logger.std_info("[Fooocus-API] PyTorch with CUDA already installed")
68
 
69
  def preload_pipeline():
70
+ """Preload pipeline with detailed error handling"""
71
  logger.std_info("[Fooocus-API] Preloading pipeline ...")
72
+ try:
73
+ import torch
74
+ logger.std_info(f"[Fooocus-API] PyTorch version: {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
75
+ if torch.cuda.is_available():
76
+ logger.std_info(f"[Fooocus-API] CUDA device: {torch.cuda.current_device()}, {torch.cuda.get_device_name(0)}")
77
+
78
+ import modules.default_pipeline as pipeline
79
+ logger.std_info("[Fooocus-API] Pipeline module imported successfully")
80
+ # Add more granular steps here if needed to isolate crash
81
+ except Exception as e:
82
+ logger.std_error(f"[Fooocus-API] Pipeline preload failed: {str(e)}")
83
+ raise
84
 
85
  def prepare_environments(args) -> bool:
86
  """
 
91
  if args.base_url is None or len(args.base_url.strip()) == 0:
92
  host = args.host
93
  if host == "0.0.0.0":
94
+ host = "127.0.0.1" # For base_url display
95
  args.base_url = f"http://{host}:{args.port}"
96
 
97
  sys.argv = [sys.argv[0]]
98
 
99
+ # Define preset folder paths but avoid runtime file operations
100
+ origin_preset_folder = os.path.abspath(os.path.join(module_path, "presets"))
101
+ preset_folder = os.path.abspath(os.path.join(script_path, "presets"))
102
+ logger.std_info(f"[Fooocus-API] Origin preset folder: {origin_preset_folder}")
103
+ logger.std_info(f"[Fooocus-API] Local preset folder: {preset_folder}")
104
+ # Comment out file operations to avoid permission issues on Hugging Face Spaces
105
+ # if os.path.exists(preset_folder):
106
+ # shutil.rmtree(preset_folder)
107
+ # shutil.copytree(origin_preset_folder, preset_folder)
108
+
109
  from modules import config
110
  from fooocusapi.configs import default
111
  from fooocusapi.utils.model_loader import download_models
 
122
  default.available_aspect_ratios = [default.get_aspect_ratio_value(a) for a in config.available_aspect_ratios]
123
 
124
  if not args.disable_preset_download:
125
+ logger.std_info("[Fooocus-API] Downloading models...")
126
  download_models()
127
+ logger.std_info("[Fooocus-API] Model download completed")
128
 
129
+ # Init task queue
130
  from fooocusapi import worker
131
  from fooocusapi.task_queue import TaskQueue
132
 
 
145
  return True
146
 
147
  def pre_setup():
148
+ """
149
+ Pre setup, for replicate or Hugging Face Spaces
150
+ """
151
  class Args(object):
152
+ """
153
+ Arguments object
154
+ """
155
  host = "127.0.0.1"
156
  port = 7860
157
  base_url = None
 
169
  gpu_device_id = None
170
  apikey = None
171
 
172
+ logger.std_info("[Pre Setup] Preparing environments")
173
  arguments = Args()
174
  sys.argv = [sys.argv[0]]
175
  sys.argv.append("--disable-image-log")
 
181
  task_thread = Thread(target=task_schedule_loop, daemon=True)
182
  task_thread.start()
183
 
184
+ logger.std_info("[Pre Setup] Finished")
185
 
186
  if __name__ == "__main__":
187
  logger.std_info(f"[Fooocus API] Python {sys.version}")
 
201
  if prepare_environments(args):
202
  sys.argv = [sys.argv[0]]
203
 
204
+ # Load pipeline in new thread with error handling
205
  preload_pipeline_thread = Thread(target=preload_pipeline, daemon=True)
206
  preload_pipeline_thread.start()
207
 
 
210
  task_schedule_thread = Thread(target=task_schedule_loop, daemon=True)
211
  task_schedule_thread.start()
212
 
213
+ # Start API server using original Fooocus API call
214
  from fooocusapi.api import start_app
215
+ try:
216
+ logger.std_info("[Fooocus-API] Starting API server...")
217
+ start_app(args)
218
+ except Exception as e:
219
+ logger.std_error(f"[Fooocus-API] Failed to start API server: {str(e)}")
220
+ raise