yym68686 committed on
Commit
5b1ad67
·
1 Parent(s): 7416300

🐛 Bug: Fix the bug where weight polling did not check if the weight channel conforms to the request model.

Browse files
Files changed (2) hide show
  1. main.py +98 -83
  2. utils.py +2 -2
main.py CHANGED
@@ -647,98 +647,105 @@ def lottery_scheduling(weights):
647
  break
648
  return selections
649
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
  import asyncio
651
  class ModelRequestHandler:
652
  def __init__(self):
653
  self.last_provider_indices = defaultdict(lambda: -1)
654
  self.locks = defaultdict(asyncio.Lock)
655
 
656
- def get_matching_providers(self, model_name, token):
657
  config = app.state.config
658
- # api_keys_db = app.state.api_keys_db
659
  api_list = app.state.api_list
660
  api_index = api_list.index(token)
 
661
  if not safe_get(config, 'api_keys', api_index, 'model'):
662
  raise HTTPException(status_code=404, detail="No matching model found")
663
- provider_rules = []
664
-
665
- for model in config['api_keys'][api_index]['model']:
666
- if model == "all":
667
- # 如果模型名为 all,则返回所有模型
668
- for provider in config["providers"]:
669
- model_dict = get_model_dict(provider)
670
- for model in model_dict.keys():
671
- provider_rules.append(provider["provider"] + "/" + model)
672
- break
673
- if "/" in model:
674
- if model.startswith("<") and model.endswith(">"):
675
- model = model[1:-1]
676
- # 处理带斜杠的模型名
677
- for provider in config['providers']:
678
- model_dict = get_model_dict(provider)
679
- if model in model_dict.keys():
680
- provider_rules.append(provider['provider'] + "/" + model)
681
- else:
682
- provider_name = model.split("/")[0]
683
- model_name_split = "/".join(model.split("/")[1:])
684
- models_list = []
685
- for provider in config['providers']:
686
- model_dict = get_model_dict(provider)
687
- if provider['provider'] == provider_name:
688
- models_list.extend(list(model_dict.keys()))
689
- # print("models_list", models_list)
690
- # print("model_name", model_name)
691
- # print("model_name_split", model_name_split)
692
- # print("model", model)
693
-
694
- # api_keys 中 model 为 provider_name/* 时,表示所有模型都匹配
695
- if model_name_split == "*":
696
- if model_name in models_list:
697
- provider_rules.append(provider_name + "/" + model_name)
698
-
699
- # 如果请求模型名: gpt-4* ,则匹配所有以模型名开头且不以 * 结尾的模型
700
- for models_list_model in models_list:
701
- if model_name.endswith("*") and models_list_model.startswith(model_name.rstrip("*")):
702
- provider_rules.append(provider_name + "/" + models_list_model)
703
-
704
- # api_keys 中 model 为 provider_name/model_name 时,表示模型名完全匹配
705
- elif model_name_split == model_name \
706
- or (model_name.endswith("*") and model_name_split.startswith(model_name.rstrip("*"))): # api_keys 中 model 为 provider_name/model_name 时,请求模型名: model_name*
707
- if model_name_split in models_list:
708
- provider_rules.append(provider_name + "/" + model_name_split)
709
-
710
- else:
711
- for provider in config["providers"]:
712
- model_dict = get_model_dict(provider)
713
- if model in model_dict.keys():
714
- provider_rules.append(provider["provider"] + "/" + model)
715
-
716
- provider_list = []
717
- # print("provider_rules", provider_rules)
718
- for item in provider_rules:
719
- for provider in config['providers']:
720
- if "/" in item and provider['provider'] == item.split("/")[0]:
721
- new_provider = copy.deepcopy(provider)
722
- model_dict = get_model_dict(provider)
723
- model_name_split = "/".join(item.split("/")[1:])
724
- # old: new
725
- new_provider["model"] = [{model_dict[model_name_split]: model_name}]
726
- if model_name in model_dict.keys() and model_name_split == model_name:
727
- provider_list.append(new_provider)
728
-
729
- elif model_name.endswith("*") and model_name_split.startswith(model_name.rstrip("*")):
730
- provider_list.append(new_provider)
731
-
732
- # print("provider_list", provider_list)
733
- return provider_list
734
-
735
- async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
736
- config = app.state.config
737
- api_list = app.state.api_list
738
- api_index = api_list.index(token)
739
 
740
- model_name = request.model
741
- matching_providers = self.get_matching_providers(model_name, token)
742
  num_matching_providers = len(matching_providers)
743
 
744
  if not matching_providers:
@@ -757,6 +764,13 @@ class ModelRequestHandler:
757
  intersection = None
758
  if weights and all_providers:
759
  weight_keys = set(weights.keys())
 
 
 
 
 
 
 
760
  # 步骤 3: 计算交集
761
  intersection = all_providers.intersection(weight_keys)
762
 
@@ -769,6 +783,7 @@ class ModelRequestHandler:
769
  weighted_provider_name_list = lottery_scheduling(weights)
770
  else:
771
  weighted_provider_name_list = list(weights.keys())
 
772
 
773
  new_matching_providers = []
774
  for provider_name in weighted_provider_name_list:
@@ -786,9 +801,9 @@ class ModelRequestHandler:
786
 
787
  start_index = 0
788
  if scheduling_algorithm != "fixed_priority":
789
- async with self.locks[model_name]:
790
- self.last_provider_indices[model_name] = (self.last_provider_indices[model_name] + 1) % num_matching_providers
791
- start_index = self.last_provider_indices[model_name]
792
 
793
  auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
794
 
 
647
  break
648
  return selections
649
 
650
+ def get_provider_rules(model_rule, config, request_model):
651
+ provider_rules = []
652
+ if model_rule == "all":
653
+ # 如果模型名为 all,则返回所有模型
654
+ for provider in config["providers"]:
655
+ model_dict = get_model_dict(provider)
656
+ for model in model_dict.keys():
657
+ provider_rules.append(provider["provider"] + "/" + model)
658
+
659
+ elif "/" in model_rule:
660
+ if model_rule.startswith("<") and model_rule.endswith(">"):
661
+ model_rule = model_rule[1:-1]
662
+ # 处理带斜杠的模型名
663
+ for provider in config['providers']:
664
+ model_dict = get_model_dict(provider)
665
+ if model_rule in model_dict.keys():
666
+ provider_rules.append(provider['provider'] + "/" + model_rule)
667
+ else:
668
+ provider_name = model_rule.split("/")[0]
669
+ model_name_split = "/".join(model_rule.split("/")[1:])
670
+ models_list = []
671
+ for provider in config['providers']:
672
+ model_dict = get_model_dict(provider)
673
+ if provider['provider'] == provider_name:
674
+ models_list.extend(list(model_dict.keys()))
675
+ # print("models_list", models_list)
676
+ # print("model_name", model_name)
677
+ # print("model_name_split", model_name_split)
678
+ # print("model", model)
679
+
680
+ # api_keys 中 model 为 provider_name/* 时,表示所有模型都匹配
681
+ if model_name_split == "*":
682
+ if request_model in models_list:
683
+ provider_rules.append(provider_name + "/" + request_model)
684
+
685
+ # 如果请求模型名: gpt-4* ,则匹配所有以模型名开头且不以 * 结尾的模型
686
+ for models_list_model in models_list:
687
+ if request_model.endswith("*") and models_list_model.startswith(request_model.rstrip("*")):
688
+ provider_rules.append(provider_name + "/" + models_list_model)
689
+
690
+ # api_keys 中 model 为 provider_name/model_name 时,表示模型名完全匹配
691
+ elif model_name_split == request_model \
692
+ or (request_model.endswith("*") and model_name_split.startswith(request_model.rstrip("*"))): # api_keys 中 model 为 provider_name/model_name 时,请求模型名: model_name*
693
+ if model_name_split in models_list:
694
+ provider_rules.append(provider_name + "/" + model_name_split)
695
+
696
+ else:
697
+ for provider in config["providers"]:
698
+ model_dict = get_model_dict(provider)
699
+ if model_rule in model_dict.keys():
700
+ provider_rules.append(provider["provider"] + "/" + model_rule)
701
+
702
+ return provider_rules
703
+
704
+ def get_provider_list(provider_rules, config, request_model):
705
+ provider_list = []
706
+ # print("provider_rules", provider_rules)
707
+ for item in provider_rules:
708
+ for provider in config['providers']:
709
+ if "/" in item and provider['provider'] == item.split("/")[0]:
710
+ new_provider = copy.deepcopy(provider)
711
+ model_dict = get_model_dict(provider)
712
+ model_name_split = "/".join(item.split("/")[1:])
713
+ # old: new
714
+ new_provider["model"] = [{model_dict[model_name_split]: request_model}]
715
+ if request_model in model_dict.keys() and model_name_split == request_model:
716
+ provider_list.append(new_provider)
717
+
718
+ elif request_model.endswith("*") and model_name_split.startswith(request_model.rstrip("*")):
719
+ provider_list.append(new_provider)
720
+ return provider_list
721
+
722
+ def get_matching_providers(request_model, config, api_index):
723
+ provider_rules = []
724
+
725
+ for model_rule in config['api_keys'][api_index]['model']:
726
+ provider_rules.extend(get_provider_rules(model_rule, config, request_model))
727
+
728
+ provider_list = get_provider_list(provider_rules, config, request_model)
729
+
730
+ # print("provider_list", provider_list)
731
+ return provider_list
732
+
733
  import asyncio
734
  class ModelRequestHandler:
735
  def __init__(self):
736
  self.last_provider_indices = defaultdict(lambda: -1)
737
  self.locks = defaultdict(asyncio.Lock)
738
 
739
+ async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
740
  config = app.state.config
 
741
  api_list = app.state.api_list
742
  api_index = api_list.index(token)
743
+
744
  if not safe_get(config, 'api_keys', api_index, 'model'):
745
  raise HTTPException(status_code=404, detail="No matching model found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
746
 
747
+ request_model = request.model
748
+ matching_providers = get_matching_providers(request_model, config, api_index)
749
  num_matching_providers = len(matching_providers)
750
 
751
  if not matching_providers:
 
764
  intersection = None
765
  if weights and all_providers:
766
  weight_keys = set(weights.keys())
767
+ provider_rules = []
768
+ for model_rule in weight_keys:
769
+ provider_rules.extend(get_provider_rules(model_rule, config, request_model))
770
+ provider_list = get_provider_list(provider_rules, config, request_model)
771
+ weight_keys = set([provider['provider'] for provider in provider_list])
772
+ # print("all_providers", all_providers)
773
+ # print("weights", weight_keys)
774
  # 步骤 3: 计算交集
775
  intersection = all_providers.intersection(weight_keys)
776
 
 
783
  weighted_provider_name_list = lottery_scheduling(weights)
784
  else:
785
  weighted_provider_name_list = list(weights.keys())
786
+ # print("weighted_provider_name_list", weighted_provider_name_list)
787
 
788
  new_matching_providers = []
789
  for provider_name in weighted_provider_name_list:
 
801
 
802
  start_index = 0
803
  if scheduling_algorithm != "fixed_priority":
804
+ async with self.locks[request_model]:
805
+ self.last_provider_indices[request_model] = (self.last_provider_indices[request_model] + 1) % num_matching_providers
806
+ start_index = self.last_provider_indices[request_model]
807
 
808
  auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
809
 
utils.py CHANGED
@@ -109,9 +109,9 @@ def update_config(config_data, use_config_url=False):
109
  for model in api_key.get('model'):
110
  if isinstance(model, dict):
111
  key, value = list(model.items())[0]
112
- provider_name = key.split("/")[0]
113
  if "/" in key:
114
- weights_dict.update({provider_name: int(value)})
115
  models.append(key)
116
  if isinstance(model, str):
117
  models.append(model)
 
109
  for model in api_key.get('model'):
110
  if isinstance(model, dict):
111
  key, value = list(model.items())[0]
112
+ # provider_name = key.split("/")[0]
113
  if "/" in key:
114
+ weights_dict.update({key: int(value)})
115
  models.append(key)
116
  if isinstance(model, str):
117
  models.append(model)