yym68686 commited on
Commit
7fb5f96
·
1 Parent(s): 031b517

🐛 Bug: Fix the bug where the model is not persisted to the file after being automatically retrieved.

Browse files
Files changed (2) hide show
  1. main.py +7 -10
  2. utils.py +20 -9
main.py CHANGED
@@ -18,7 +18,7 @@ from fastapi.exceptions import RequestValidationError
18
  from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
19
  from request import get_payload
20
  from response import fetch_response, fetch_response_stream
21
- from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict
22
 
23
  from collections import defaultdict
24
  from typing import List, Dict, Union
@@ -1120,8 +1120,6 @@ async def frontend_rate_limit_dependency(request: Request, x_api_key: str = Depe
1120
 
1121
  xue_initialize(tailwind=True)
1122
 
1123
- API_YAML_PATH = "./api.yaml"
1124
-
1125
  data_table_columns = [
1126
  # {"label": "Status", "value": "status", "sortable": True},
1127
  {"label": "Provider", "value": "provider", "sortable": True},
@@ -1500,10 +1498,6 @@ def update_row_data(row_id, updated_data):
1500
  index = int(row_id)
1501
  app.state.config["providers"][index] = updated_data
1502
 
1503
- def save_api_yaml():
1504
- with open(API_YAML_PATH, "w", encoding="utf-8") as f:
1505
- yaml.dump(app.state.config, f)
1506
-
1507
  @frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1508
  async def submit_form(
1509
  row_id: str,
@@ -1551,7 +1545,8 @@ async def submit_form(
1551
  update_row_data(row_id, updated_data)
1552
 
1553
  # 保存更新后的配置
1554
- save_api_yaml()
 
1555
 
1556
  return await root()
1557
 
@@ -1564,7 +1559,8 @@ async def duplicate_row(row_id: str):
1564
  app.state.config["providers"].insert(index + 1, new_data)
1565
 
1566
  # 保存更新后的配置
1567
- save_api_yaml()
 
1568
 
1569
  return await root()
1570
 
@@ -1574,7 +1570,8 @@ async def delete_row(row_id: str):
1574
  del app.state.config["providers"][index]
1575
 
1576
  # 保存更新后的配置
1577
- save_api_yaml()
 
1578
 
1579
  return await root()
1580
 
 
18
  from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
19
  from request import get_payload
20
  from response import fetch_response, fetch_response_stream
21
+ from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
22
 
23
  from collections import defaultdict
24
  from typing import List, Dict, Union
 
1120
 
1121
  xue_initialize(tailwind=True)
1122
 
 
 
1123
  data_table_columns = [
1124
  # {"label": "Status", "value": "status", "sortable": True},
1125
  {"label": "Provider", "value": "provider", "sortable": True},
 
1498
  index = int(row_id)
1499
  app.state.config["providers"][index] = updated_data
1500
 
 
 
 
 
1501
  @frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
1502
  async def submit_form(
1503
  row_id: str,
 
1545
  update_row_data(row_id, updated_data)
1546
 
1547
  # 保存更新后的配置
1548
+ if not DISABLE_DATABASE:
1549
+ save_api_yaml(app.state.config)
1550
 
1551
  return await root()
1552
 
 
1559
  app.state.config["providers"].insert(index + 1, new_data)
1560
 
1561
  # 保存更新后的配置
1562
+ if not DISABLE_DATABASE:
1563
+ save_api_yaml(app.state.config)
1564
 
1565
  return await root()
1566
 
 
1570
  del app.state.config["providers"][index]
1571
 
1572
  # 保存更新后的配置
1573
+ if not DISABLE_DATABASE:
1574
+ save_api_yaml(app.state.config)
1575
 
1576
  return await root()
1577
 
utils.py CHANGED
@@ -63,7 +63,18 @@ def update_initial_model(api_url, api):
63
  traceback.print_exc()
64
  return []
65
 
66
- def update_config(config_data):
 
 
 
 
 
 
 
 
 
 
 
67
  for index, provider in enumerate(config_data['providers']):
68
  if provider.get('project_id'):
69
  provider['base_url'] = 'https://aiplatform.googleapis.com/'
@@ -78,7 +89,11 @@ def update_config(config_data):
78
  provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(provider_api)
79
 
80
  if not provider.get("model"):
81
- provider["model"] = update_initial_model(provider['base_url'], provider['api'])
 
 
 
 
82
 
83
  if provider.get("tools") == None:
84
  provider["tools"] = True
@@ -128,16 +143,12 @@ async def load_config(app=None):
128
  follow_redirects=True, # 自动跟随重定向
129
  )
130
 
131
- from ruamel.yaml import YAML, YAMLError
132
- yaml = YAML()
133
- yaml.preserve_quotes = True
134
- yaml.indent(mapping=2, sequence=4, offset=2)
135
  try:
136
- with open('api.yaml', 'r', encoding='utf-8') as file:
137
  conf = yaml.load(file)
138
 
139
  if conf:
140
- config, api_keys_db, api_list = update_config(conf)
141
  else:
142
  logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
143
  config, api_keys_db, api_list = {}, {}, []
@@ -166,7 +177,7 @@ async def load_config(app=None):
166
  # 更新配置
167
  # logger.info(config_data)
168
  if config_data:
169
- config, api_keys_db, api_list = update_config(config_data)
170
  else:
171
  logger.error(f"Error fetching or parsing config from {config_url}")
172
  config, api_keys_db, api_list = {}, {}, []
 
63
  traceback.print_exc()
64
  return []
65
 
66
+ from ruamel.yaml import YAML, YAMLError
67
+ yaml = YAML()
68
+ yaml.preserve_quotes = True
69
+ yaml.indent(mapping=2, sequence=4, offset=2)
70
+
71
+ API_YAML_PATH = "./api.yaml"
72
+
73
+ def save_api_yaml(config_data):
74
+ with open(API_YAML_PATH, "w", encoding="utf-8") as f:
75
+ yaml.dump(config_data, f)
76
+
77
+ def update_config(config_data, use_config_url=False):
78
  for index, provider in enumerate(config_data['providers']):
79
  if provider.get('project_id'):
80
  provider['base_url'] = 'https://aiplatform.googleapis.com/'
 
89
  provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(provider_api)
90
 
91
  if not provider.get("model"):
92
+ model_list = update_initial_model(provider['base_url'], provider['api'])
93
+ if model_list:
94
+ provider["model"] = model_list
95
+ if not use_config_url:
96
+ save_api_yaml(config_data)
97
 
98
  if provider.get("tools") == None:
99
  provider["tools"] = True
 
143
  follow_redirects=True, # 自动跟随重定向
144
  )
145
 
 
 
 
 
146
  try:
147
+ with open(API_YAML_PATH, 'r', encoding='utf-8') as file:
148
  conf = yaml.load(file)
149
 
150
  if conf:
151
+ config, api_keys_db, api_list = update_config(conf, use_config_url=False)
152
  else:
153
  logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
154
  config, api_keys_db, api_list = {}, {}, []
 
177
  # 更新配置
178
  # logger.info(config_data)
179
  if config_data:
180
+ config, api_keys_db, api_list = update_config(config_data, use_config_url=True)
181
  else:
182
  logger.error(f"Error fetching or parsing config from {config_url}")
183
  config, api_keys_db, api_list = {}, {}, []