yym68686 commited on
Commit
60014c4
·
1 Parent(s): 0ebec22

✨ Feature: Add feature: support wildcard matching like gpt* to match models such as gpt-3.5 and gpt-4.

Browse files
Files changed (2) hide show
  1. main.py +32 -24
  2. test/test_dict.py +12 -0
main.py CHANGED
@@ -1,6 +1,7 @@
1
  from log_config import logger
2
 
3
  import re
 
4
  import httpx
5
  import secrets
6
  from time import time
@@ -652,44 +653,51 @@ class ModelRequestHandler:
652
  # print("model_name", model_name)
653
  # print("model_name_split", model_name_split)
654
  # print("model", model)
 
 
655
  if model_name_split == "*":
656
  if model_name in models_list:
657
  provider_rules.append(provider_name)
658
- elif model_name_split == model_name:
659
- if model_name in models_list:
660
- provider_rules.append(provider_name)
 
 
 
 
 
 
 
 
 
661
  else:
662
- for provider in config['providers']:
663
  model_dict = get_model_dict(provider)
664
  if model in model_dict.keys():
665
- provider_rules.append(provider['provider'] + "/" + model)
666
 
667
  provider_list = []
668
  # print("provider_rules", provider_rules)
669
  for item in provider_rules:
670
  for provider in config['providers']:
671
  # print("provider", provider, provider['provider'] == item, item)
672
- if "/" in item:
673
- if provider['provider'] == item.split("/")[0]:
674
- model_dict = get_model_dict(provider)
675
- if model_name in model_dict.keys() and "/".join(item.split("/")[1:]) == model_name:
676
- provider_list.append(provider)
677
- # 如果 item 不包含 /,则直接匹配 provider,说明整个渠道所有模型都能用
678
- elif provider['provider'] == item:
679
  model_dict = get_model_dict(provider)
 
 
680
  if model_name in model_dict.keys():
681
- provider_list.append(provider)
682
- else:
683
- pass
684
-
685
- # if provider['provider'] == item:
686
- # if "/" in item:
687
- # if item.split("/")[1] == model_name:
688
- # provider_list.append(provider)
689
- # else:
690
- # model_dict = get_model_dict(provider)
691
- # if model_name in model_dict.keys():
692
- # provider_list.append(provider)
693
  return provider_list
694
 
695
  async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
 
1
  from log_config import logger
2
 
3
  import re
4
+ import copy
5
  import httpx
6
  import secrets
7
  from time import time
 
653
  # print("model_name", model_name)
654
  # print("model_name_split", model_name_split)
655
  # print("model", model)
656
+
657
+ # api_keys 中 model 为 provider_name/* 时,表示所有模型都匹配
658
  if model_name_split == "*":
659
  if model_name in models_list:
660
  provider_rules.append(provider_name)
661
+
662
+ # 如果请求模型名: gpt-4* ,则匹配所有以模型名开头且不以 * 结尾的模型
663
+ for models_list_model in models_list:
664
+ if model_name.endswith("*") and models_list_model.startswith(model_name.rstrip("*")):
665
+ provider_rules.append(provider_name + "/" + models_list_model)
666
+
667
+ # api_keys 中 model 为 provider_name/model_name 时,表示模型名完全匹配
668
+ elif model_name_split == model_name \
669
+ or (model_name.endswith("*") and model_name_split.startswith(model_name.rstrip("*"))): # api_keys 中 model 为 provider_name/model_name 时,请求模型名: model_name*
670
+ if model_name_split in models_list:
671
+ provider_rules.append(provider_name + "/" + model_name_split)
672
+
673
  else:
674
+ for provider in config["providers"]:
675
  model_dict = get_model_dict(provider)
676
  if model in model_dict.keys():
677
+ provider_rules.append(provider["provider"] + "/" + model)
678
 
679
  provider_list = []
680
  # print("provider_rules", provider_rules)
681
  for item in provider_rules:
682
  for provider in config['providers']:
683
  # print("provider", provider, provider['provider'] == item, item)
684
+ if provider['provider'] == item.split("/")[0]:
685
+ new_provider = copy.deepcopy(provider)
 
 
 
 
 
686
  model_dict = get_model_dict(provider)
687
+ # print("model_dict", model_dict)
688
+ model_name_split = "/".join(item.split("/")[1:])
689
  if model_name in model_dict.keys():
690
+ if "/" in item and model_name_split == model_name:
691
+ new_provider["model"] = [{model_dict[model_name]: model_name}]
692
+ # 如果 item 不包含 /,则直接匹配 provider,说明整个渠道所有模型都能用
693
+ provider_list.append(new_provider)
694
+
695
+ elif model_name.endswith("*") and "/" in item and model_name_split.startswith(model_name.rstrip("*")):
696
+ # old: new
697
+ new_provider["model"] = [{model_dict[model_name_split]: model_name}]
698
+ provider_list.append(new_provider)
699
+
700
+ # print("provider_list", provider_list)
 
701
  return provider_list
702
 
703
  async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
test/test_dict.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a = [
2
+ {"a": 1, "b": 2, "c": 3},
3
+ {"a": 4, "b": 5, "c": 6},
4
+ {"a": 7, "b": 8, "c": 9}
5
+ ]
6
+ import copy
7
+ for item in a:
8
+ new_item = copy.deepcopy(item)
9
+ new_item["a"] = 10
10
+ del new_item["b"]
11
+ # print(item)
12
+ print(a)