🐛 Bug: Fix the bug where the model is not persisted to the file after being automatically retrieved.
Browse files
main.py
CHANGED
@@ -18,7 +18,7 @@ from fastapi.exceptions import RequestValidationError
|
|
18 |
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
19 |
from request import get_payload
|
20 |
from response import fetch_response, fetch_response_stream
|
21 |
-
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict
|
22 |
|
23 |
from collections import defaultdict
|
24 |
from typing import List, Dict, Union
|
@@ -1120,8 +1120,6 @@ async def frontend_rate_limit_dependency(request: Request, x_api_key: str = Depe
|
|
1120 |
|
1121 |
xue_initialize(tailwind=True)
|
1122 |
|
1123 |
-
API_YAML_PATH = "./api.yaml"
|
1124 |
-
|
1125 |
data_table_columns = [
|
1126 |
# {"label": "Status", "value": "status", "sortable": True},
|
1127 |
{"label": "Provider", "value": "provider", "sortable": True},
|
@@ -1500,10 +1498,6 @@ def update_row_data(row_id, updated_data):
|
|
1500 |
index = int(row_id)
|
1501 |
app.state.config["providers"][index] = updated_data
|
1502 |
|
1503 |
-
def save_api_yaml():
|
1504 |
-
with open(API_YAML_PATH, "w", encoding="utf-8") as f:
|
1505 |
-
yaml.dump(app.state.config, f)
|
1506 |
-
|
1507 |
@frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
|
1508 |
async def submit_form(
|
1509 |
row_id: str,
|
@@ -1551,7 +1545,8 @@ async def submit_form(
|
|
1551 |
update_row_data(row_id, updated_data)
|
1552 |
|
1553 |
# 保存更新后的配置
|
1554 |
-
|
|
|
1555 |
|
1556 |
return await root()
|
1557 |
|
@@ -1564,7 +1559,8 @@ async def duplicate_row(row_id: str):
|
|
1564 |
app.state.config["providers"].insert(index + 1, new_data)
|
1565 |
|
1566 |
# 保存更新后的配置
|
1567 |
-
|
|
|
1568 |
|
1569 |
return await root()
|
1570 |
|
@@ -1574,7 +1570,8 @@ async def delete_row(row_id: str):
|
|
1574 |
del app.state.config["providers"][index]
|
1575 |
|
1576 |
# 保存更新后的配置
|
1577 |
-
|
|
|
1578 |
|
1579 |
return await root()
|
1580 |
|
|
|
18 |
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
19 |
from request import get_payload
|
20 |
from response import fetch_response, fetch_response_stream
|
21 |
+
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
|
22 |
|
23 |
from collections import defaultdict
|
24 |
from typing import List, Dict, Union
|
|
|
1120 |
|
1121 |
xue_initialize(tailwind=True)
|
1122 |
|
|
|
|
|
1123 |
data_table_columns = [
|
1124 |
# {"label": "Status", "value": "status", "sortable": True},
|
1125 |
{"label": "Provider", "value": "provider", "sortable": True},
|
|
|
1498 |
index = int(row_id)
|
1499 |
app.state.config["providers"][index] = updated_data
|
1500 |
|
|
|
|
|
|
|
|
|
1501 |
@frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
|
1502 |
async def submit_form(
|
1503 |
row_id: str,
|
|
|
1545 |
update_row_data(row_id, updated_data)
|
1546 |
|
1547 |
# 保存更新后的配置
|
1548 |
+
if not DISABLE_DATABASE:
|
1549 |
+
save_api_yaml(app.state.config)
|
1550 |
|
1551 |
return await root()
|
1552 |
|
|
|
1559 |
app.state.config["providers"].insert(index + 1, new_data)
|
1560 |
|
1561 |
# 保存更新后的配置
|
1562 |
+
if not DISABLE_DATABASE:
|
1563 |
+
save_api_yaml(app.state.config)
|
1564 |
|
1565 |
return await root()
|
1566 |
|
|
|
1570 |
del app.state.config["providers"][index]
|
1571 |
|
1572 |
# 保存更新后的配置
|
1573 |
+
if not DISABLE_DATABASE:
|
1574 |
+
save_api_yaml(app.state.config)
|
1575 |
|
1576 |
return await root()
|
1577 |
|
utils.py
CHANGED
@@ -63,7 +63,18 @@ def update_initial_model(api_url, api):
|
|
63 |
traceback.print_exc()
|
64 |
return []
|
65 |
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
for index, provider in enumerate(config_data['providers']):
|
68 |
if provider.get('project_id'):
|
69 |
provider['base_url'] = 'https://aiplatform.googleapis.com/'
|
@@ -78,7 +89,11 @@ def update_config(config_data):
|
|
78 |
provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(provider_api)
|
79 |
|
80 |
if not provider.get("model"):
|
81 |
-
|
|
|
|
|
|
|
|
|
82 |
|
83 |
if provider.get("tools") == None:
|
84 |
provider["tools"] = True
|
@@ -128,16 +143,12 @@ async def load_config(app=None):
|
|
128 |
follow_redirects=True, # 自动跟随重定向
|
129 |
)
|
130 |
|
131 |
-
from ruamel.yaml import YAML, YAMLError
|
132 |
-
yaml = YAML()
|
133 |
-
yaml.preserve_quotes = True
|
134 |
-
yaml.indent(mapping=2, sequence=4, offset=2)
|
135 |
try:
|
136 |
-
with open(
|
137 |
conf = yaml.load(file)
|
138 |
|
139 |
if conf:
|
140 |
-
config, api_keys_db, api_list = update_config(conf)
|
141 |
else:
|
142 |
logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
|
143 |
config, api_keys_db, api_list = {}, {}, []
|
@@ -166,7 +177,7 @@ async def load_config(app=None):
|
|
166 |
# 更新配置
|
167 |
# logger.info(config_data)
|
168 |
if config_data:
|
169 |
-
config, api_keys_db, api_list = update_config(config_data)
|
170 |
else:
|
171 |
logger.error(f"Error fetching or parsing config from {config_url}")
|
172 |
config, api_keys_db, api_list = {}, {}, []
|
|
|
63 |
traceback.print_exc()
|
64 |
return []
|
65 |
|
66 |
+
from ruamel.yaml import YAML, YAMLError
|
67 |
+
yaml = YAML()
|
68 |
+
yaml.preserve_quotes = True
|
69 |
+
yaml.indent(mapping=2, sequence=4, offset=2)
|
70 |
+
|
71 |
+
API_YAML_PATH = "./api.yaml"
|
72 |
+
|
73 |
+
def save_api_yaml(config_data):
|
74 |
+
with open(API_YAML_PATH, "w", encoding="utf-8") as f:
|
75 |
+
yaml.dump(config_data, f)
|
76 |
+
|
77 |
+
def update_config(config_data, use_config_url=False):
|
78 |
for index, provider in enumerate(config_data['providers']):
|
79 |
if provider.get('project_id'):
|
80 |
provider['base_url'] = 'https://aiplatform.googleapis.com/'
|
|
|
89 |
provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(provider_api)
|
90 |
|
91 |
if not provider.get("model"):
|
92 |
+
model_list = update_initial_model(provider['base_url'], provider['api'])
|
93 |
+
if model_list:
|
94 |
+
provider["model"] = model_list
|
95 |
+
if not use_config_url:
|
96 |
+
save_api_yaml(config_data)
|
97 |
|
98 |
if provider.get("tools") == None:
|
99 |
provider["tools"] = True
|
|
|
143 |
follow_redirects=True, # 自动跟随重定向
|
144 |
)
|
145 |
|
|
|
|
|
|
|
|
|
146 |
try:
|
147 |
+
with open(API_YAML_PATH, 'r', encoding='utf-8') as file:
|
148 |
conf = yaml.load(file)
|
149 |
|
150 |
if conf:
|
151 |
+
config, api_keys_db, api_list = update_config(conf, use_config_url=False)
|
152 |
else:
|
153 |
logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
|
154 |
config, api_keys_db, api_list = {}, {}, []
|
|
|
177 |
# 更新配置
|
178 |
# logger.info(config_data)
|
179 |
if config_data:
|
180 |
+
config, api_keys_db, api_list = update_config(config_data, use_config_url=True)
|
181 |
else:
|
182 |
logger.error(f"Error fetching or parsing config from {config_url}")
|
183 |
config, api_keys_db, api_list = {}, {}, []
|