yym68686 commited on
Commit
83b5b1b
·
1 Parent(s): e13d491

🐛 Bug: Fix the bug of infinite loop error when there is only one weighted channel.

Browse files
Files changed (1) hide show
  1. main.py +9 -1
main.py CHANGED
@@ -1036,6 +1036,8 @@ async def get_right_order_providers(request_model, config, api_index, scheduling
1036
  # 步骤 3: 计算交集
1037
  intersection = all_providers.intersection(weight_keys)
1038
  # print("intersection", intersection)
 
 
1039
 
1040
  if intersection:
1041
  filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
@@ -1097,6 +1099,10 @@ class ModelRequestHandler:
1097
  retry_count = 0
1098
 
1099
  while True:
 
 
 
 
1100
  if index >= num_matching_providers + retry_count:
1101
  break
1102
  current_index = (start_index + index) % num_matching_providers
@@ -1136,8 +1142,10 @@ class ModelRequestHandler:
1136
  # source_model = list(provider['model'][0].keys())[0]
1137
  await app.state.channel_manager.exclude_model(channel_id, request_model)
1138
  matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
 
1139
  num_matching_providers = len(matching_providers)
1140
- index = 0
 
1141
 
1142
  cooling_time = safe_get(provider, "preferences", "api_key_cooldown_period", default=0)
1143
  api_key_count = provider_api_circular_list[channel_id].get_items_count()
 
1036
  # 步骤 3: 计算交集
1037
  intersection = all_providers.intersection(weight_keys)
1038
  # print("intersection", intersection)
1039
+ if len(intersection) == 1:
1040
+ intersection = None
1041
 
1042
  if intersection:
1043
  filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
 
1099
  retry_count = 0
1100
 
1101
  while True:
1102
+ # print("start_index", start_index)
1103
+ # print("index", index)
1104
+ # print("num_matching_providers", num_matching_providers)
1105
+ # print("retry_count", retry_count)
1106
  if index >= num_matching_providers + retry_count:
1107
  break
1108
  current_index = (start_index + index) % num_matching_providers
 
1142
  # source_model = list(provider['model'][0].keys())[0]
1143
  await app.state.channel_manager.exclude_model(channel_id, request_model)
1144
  matching_providers = await get_right_order_providers(request_model, config, api_index, scheduling_algorithm)
1145
+ last_num_matching_providers = num_matching_providers
1146
  num_matching_providers = len(matching_providers)
1147
+ if num_matching_providers != last_num_matching_providers:
1148
+ index = 0
1149
 
1150
  cooling_time = safe_get(provider, "preferences", "api_key_cooldown_period", default=0)
1151
  api_key_count = provider_api_circular_list[channel_id].get_items_count()