banao-tech commited on
Commit
0958bb1
Β·
verified Β·
1 Parent(s): 0aa98a1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +29 -45
main.py CHANGED
@@ -40,7 +40,6 @@ default_command_live = True
40
  index_url = os.environ.get("INDEX_URL", "")
41
  re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
42
 
43
-
44
  def install_dependents(skip: bool = False):
45
  """
46
  Check and install dependencies
@@ -50,50 +49,31 @@ def install_dependents(skip: bool = False):
50
  if skip:
51
  return
52
 
53
- torch_index_url = os.environ.get(
54
- "TORCH_INDEX_URL", "https://download.pytorch.org/whl/cu121"
55
- )
56
-
57
- # Check if you need pip install
58
  if not requirements_check():
59
  run_pip("install -r requirements.txt", "requirements")
60
-
61
  if not check_torch_cuda():
62
- run_pip(
63
- f"install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}",
64
- desc="torch",
65
- )
66
-
67
 
68
  def preload_pipeline():
69
  """Preload pipeline"""
70
  logger.std_info("[Fooocus-API] Preloading pipeline ...")
71
  import modules.default_pipeline as _
72
 
73
-
74
  def prepare_environments(args) -> bool:
75
  """
76
  Prepare environments
77
  Args:
78
  args: command line arguments
79
  """
80
-
81
  if args.base_url is None or len(args.base_url.strip()) == 0:
82
  host = args.host
83
  if host == "0.0.0.0":
84
- host = "127.0.0.1"
85
  args.base_url = f"http://{host}:{args.port}"
86
 
87
  sys.argv = [sys.argv[0]]
88
 
89
- # Define preset folder paths but avoid runtime file operations
90
- origin_preset_folder = os.path.abspath(os.path.join(module_path, "presets"))
91
- preset_folder = os.path.abspath(os.path.join(script_path, "presets"))
92
- # Comment out file operations to avoid permission issues on Hugging Face Spaces
93
- # if os.path.exists(preset_folder):
94
- # shutil.rmtree(preset_folder)
95
- # shutil.copytree(origin_preset_folder, preset_folder)
96
-
97
  from modules import config
98
  from fooocusapi.configs import default
99
  from fooocusapi.utils.model_loader import download_models
@@ -106,17 +86,12 @@ def prepare_environments(args) -> bool:
106
  default.default_loras = config.default_loras
107
  default.default_cfg_scale = config.default_cfg_scale
108
  default.default_prompt_negative = config.default_prompt_negative
109
- default.default_aspect_ratio = default.get_aspect_ratio_value(
110
- config.default_aspect_ratio
111
- )
112
- default.available_aspect_ratios = [
113
- default.get_aspect_ratio_value(a) for a in config.available_aspect_ratios
114
- ]
115
 
116
  if not args.disable_preset_download:
117
  download_models()
118
 
119
- # Init task queue
120
  from fooocusapi import worker
121
  from fooocusapi.task_queue import TaskQueue
122
 
@@ -130,18 +105,13 @@ def prepare_environments(args) -> bool:
130
  logger.std_info(f"[Fooocus-API] Task queue size: {args.queue_size}")
131
  logger.std_info(f"[Fooocus-API] Queue history size: {args.queue_history}")
132
  logger.std_info(f"[Fooocus-API] Webhook url: {args.webhook_url}")
 
133
 
134
  return True
135
 
136
-
137
  def pre_setup():
138
- """
139
- Pre setup, for replicate
140
- """
141
  class Args(object):
142
- """
143
- Arguments object
144
- """
145
  host = "127.0.0.1"
146
  port = 7860
147
  base_url = None
@@ -160,24 +130,19 @@ def pre_setup():
160
  apikey = None
161
 
162
  print("[Pre Setup] Prepare environments")
163
-
164
  arguments = Args()
165
  sys.argv = [sys.argv[0]]
166
  sys.argv.append("--disable-image-log")
167
 
168
  install_dependents(arguments.skip_pip)
169
-
170
  prepare_environments(arguments)
171
 
172
- # Start task schedule thread
173
  from fooocusapi.worker import task_schedule_loop
174
-
175
  task_thread = Thread(target=task_schedule_loop, daemon=True)
176
  task_thread.start()
177
 
178
  print("[Pre Setup] Finished")
179
 
180
-
181
  if __name__ == "__main__":
182
  logger.std_info(f"[Fooocus API] Python {sys.version}")
183
  logger.std_info(f"[Fooocus API] Fooocus API version: {version}")
@@ -186,6 +151,7 @@ if __name__ == "__main__":
186
 
187
  parser = argparse.ArgumentParser()
188
  add_base_args(parser, True)
 
189
 
190
  args, _ = parser.parse_known_args()
191
  install_dependents(skip=args.skip_pip)
@@ -201,11 +167,29 @@ if __name__ == "__main__":
201
 
202
  # Start task schedule thread
203
  from fooocusapi.worker import task_schedule_loop
204
-
205
  task_schedule_thread = Thread(target=task_schedule_loop, daemon=True)
206
  task_schedule_thread.start()
207
 
208
- # Start api server
209
  from fooocusapi.api import start_app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
- start_app(args)
 
 
40
  index_url = os.environ.get("INDEX_URL", "")
41
  re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
42
 
 
43
  def install_dependents(skip: bool = False):
44
  """
45
  Check and install dependencies
 
49
  if skip:
50
  return
51
 
52
+ torch_index_url = os.environ.get("TORCH_INDEX_URL", "https://download.pytorch.org/whl/cu121")
 
 
 
 
53
  if not requirements_check():
54
  run_pip("install -r requirements.txt", "requirements")
 
55
  if not check_torch_cuda():
56
+ run_pip(f"install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}", desc="torch")
 
 
 
 
57
 
58
  def preload_pipeline():
59
  """Preload pipeline"""
60
  logger.std_info("[Fooocus-API] Preloading pipeline ...")
61
  import modules.default_pipeline as _
62
 
 
63
  def prepare_environments(args) -> bool:
64
  """
65
  Prepare environments
66
  Args:
67
  args: command line arguments
68
  """
 
69
  if args.base_url is None or len(args.base_url.strip()) == 0:
70
  host = args.host
71
  if host == "0.0.0.0":
72
+ host = "127.0.0.1" # For base_url display, use localhost
73
  args.base_url = f"http://{host}:{args.port}"
74
 
75
  sys.argv = [sys.argv[0]]
76
 
 
 
 
 
 
 
 
 
77
  from modules import config
78
  from fooocusapi.configs import default
79
  from fooocusapi.utils.model_loader import download_models
 
86
  default.default_loras = config.default_loras
87
  default.default_cfg_scale = config.default_cfg_scale
88
  default.default_prompt_negative = config.default_prompt_negative
89
+ default.default_aspect_ratio = default.get_aspect_ratio_value(config.default_aspect_ratio)
90
+ default.available_aspect_ratios = [default.get_aspect_ratio_value(a) for a in config.available_aspect_ratios]
 
 
 
 
91
 
92
  if not args.disable_preset_download:
93
  download_models()
94
 
 
95
  from fooocusapi import worker
96
  from fooocusapi.task_queue import TaskQueue
97
 
 
105
  logger.std_info(f"[Fooocus-API] Task queue size: {args.queue_size}")
106
  logger.std_info(f"[Fooocus-API] Queue history size: {args.queue_history}")
107
  logger.std_info(f"[Fooocus-API] Webhook url: {args.webhook_url}")
108
+ logger.std_info(f"[Fooocus-API] Base URL: {args.base_url}")
109
 
110
  return True
111
 
 
112
  def pre_setup():
113
+ """Pre setup, for replicate"""
 
 
114
  class Args(object):
 
 
 
115
  host = "127.0.0.1"
116
  port = 7860
117
  base_url = None
 
130
  apikey = None
131
 
132
  print("[Pre Setup] Prepare environments")
 
133
  arguments = Args()
134
  sys.argv = [sys.argv[0]]
135
  sys.argv.append("--disable-image-log")
136
 
137
  install_dependents(arguments.skip_pip)
 
138
  prepare_environments(arguments)
139
 
 
140
  from fooocusapi.worker import task_schedule_loop
 
141
  task_thread = Thread(target=task_schedule_loop, daemon=True)
142
  task_thread.start()
143
 
144
  print("[Pre Setup] Finished")
145
 
 
146
  if __name__ == "__main__":
147
  logger.std_info(f"[Fooocus API] Python {sys.version}")
148
  logger.std_info(f"[Fooocus API] Fooocus API version: {version}")
 
151
 
152
  parser = argparse.ArgumentParser()
153
  add_base_args(parser, True)
154
+ parser.set_defaults(host="0.0.0.0") # Default to 0.0.0.0 for broader access
155
 
156
  args, _ = parser.parse_known_args()
157
  install_dependents(skip=args.skip_pip)
 
167
 
168
  # Start task schedule thread
169
  from fooocusapi.worker import task_schedule_loop
 
170
  task_schedule_thread = Thread(target=task_schedule_loop, daemon=True)
171
  task_schedule_thread.start()
172
 
173
+ # Start API server with CORS
174
  from fooocusapi.api import start_app
175
+ from fastapi import FastAPI
176
+ from fastapi.middleware.cors import CORSMiddleware
177
+
178
+ app = FastAPI()
179
+
180
+ # Add CORS middleware
181
+ app.add_middleware(
182
+ CORSMiddleware,
183
+ allow_origins=["*"], # Allow all origins (adjust for production)
184
+ allow_credentials=True,
185
+ allow_methods=["*"],
186
+ allow_headers=["*"],
187
+ )
188
+
189
+ # Add config endpoint to expose base URL
190
+ @app.get("/config")
191
+ async def get_config():
192
+ return {"base_url": args.base_url}
193
 
194
+ # Mount the existing app routes
195
+ start_app(args, app=app)