✨ Feature: Add feature: support wildcard matching like gpt* to match models such as gpt-3.5 and gpt-4.
Browse files- main.py +32 -24
- test/test_dict.py +12 -0
main.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from log_config import logger
|
2 |
|
3 |
import re
|
|
|
4 |
import httpx
|
5 |
import secrets
|
6 |
from time import time
|
@@ -652,44 +653,51 @@ class ModelRequestHandler:
|
|
652 |
# print("model_name", model_name)
|
653 |
# print("model_name_split", model_name_split)
|
654 |
# print("model", model)
|
|
|
|
|
655 |
if model_name_split == "*":
|
656 |
if model_name in models_list:
|
657 |
provider_rules.append(provider_name)
|
658 |
-
|
659 |
-
|
660 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
661 |
else:
|
662 |
-
for provider in config[
|
663 |
model_dict = get_model_dict(provider)
|
664 |
if model in model_dict.keys():
|
665 |
-
provider_rules.append(provider[
|
666 |
|
667 |
provider_list = []
|
668 |
# print("provider_rules", provider_rules)
|
669 |
for item in provider_rules:
|
670 |
for provider in config['providers']:
|
671 |
# print("provider", provider, provider['provider'] == item, item)
|
672 |
-
if "/"
|
673 |
-
|
674 |
-
model_dict = get_model_dict(provider)
|
675 |
-
if model_name in model_dict.keys() and "/".join(item.split("/")[1:]) == model_name:
|
676 |
-
provider_list.append(provider)
|
677 |
-
# 如果 item 不包含 /,则直接匹配 provider,说明整个渠道所有模型都能用
|
678 |
-
elif provider['provider'] == item:
|
679 |
model_dict = get_model_dict(provider)
|
|
|
|
|
680 |
if model_name in model_dict.keys():
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
# provider_list.append(provider)
|
693 |
return provider_list
|
694 |
|
695 |
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
|
|
|
1 |
from log_config import logger
|
2 |
|
3 |
import re
|
4 |
+
import copy
|
5 |
import httpx
|
6 |
import secrets
|
7 |
from time import time
|
|
|
653 |
# print("model_name", model_name)
|
654 |
# print("model_name_split", model_name_split)
|
655 |
# print("model", model)
|
656 |
+
|
657 |
+
# api_keys 中 model 为 provider_name/* 时,表示所有模型都匹配
|
658 |
if model_name_split == "*":
|
659 |
if model_name in models_list:
|
660 |
provider_rules.append(provider_name)
|
661 |
+
|
662 |
+
# 如果请求模型名: gpt-4* ,则匹配所有以模型名开头且不以 * 结尾的模型
|
663 |
+
for models_list_model in models_list:
|
664 |
+
if model_name.endswith("*") and models_list_model.startswith(model_name.rstrip("*")):
|
665 |
+
provider_rules.append(provider_name + "/" + models_list_model)
|
666 |
+
|
667 |
+
# api_keys 中 model 为 provider_name/model_name 时,表示模型名完全匹配
|
668 |
+
elif model_name_split == model_name \
|
669 |
+
or (model_name.endswith("*") and model_name_split.startswith(model_name.rstrip("*"))): # api_keys 中 model 为 provider_name/model_name 时,请求模型名: model_name*
|
670 |
+
if model_name_split in models_list:
|
671 |
+
provider_rules.append(provider_name + "/" + model_name_split)
|
672 |
+
|
673 |
else:
|
674 |
+
for provider in config["providers"]:
|
675 |
model_dict = get_model_dict(provider)
|
676 |
if model in model_dict.keys():
|
677 |
+
provider_rules.append(provider["provider"] + "/" + model)
|
678 |
|
679 |
provider_list = []
|
680 |
# print("provider_rules", provider_rules)
|
681 |
for item in provider_rules:
|
682 |
for provider in config['providers']:
|
683 |
# print("provider", provider, provider['provider'] == item, item)
|
684 |
+
if provider['provider'] == item.split("/")[0]:
|
685 |
+
new_provider = copy.deepcopy(provider)
|
|
|
|
|
|
|
|
|
|
|
686 |
model_dict = get_model_dict(provider)
|
687 |
+
# print("model_dict", model_dict)
|
688 |
+
model_name_split = "/".join(item.split("/")[1:])
|
689 |
if model_name in model_dict.keys():
|
690 |
+
if "/" in item and model_name_split == model_name:
|
691 |
+
new_provider["model"] = [{model_dict[model_name]: model_name}]
|
692 |
+
# 如果 item 不包含 /,则直接匹配 provider,说明整个渠道所有模型都能用
|
693 |
+
provider_list.append(new_provider)
|
694 |
+
|
695 |
+
elif model_name.endswith("*") and "/" in item and model_name_split.startswith(model_name.rstrip("*")):
|
696 |
+
# old: new
|
697 |
+
new_provider["model"] = [{model_dict[model_name_split]: model_name}]
|
698 |
+
provider_list.append(new_provider)
|
699 |
+
|
700 |
+
# print("provider_list", provider_list)
|
|
|
701 |
return provider_list
|
702 |
|
703 |
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
|
test/test_dict.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
a = [
|
2 |
+
{"a": 1, "b": 2, "c": 3},
|
3 |
+
{"a": 4, "b": 5, "c": 6},
|
4 |
+
{"a": 7, "b": 8, "c": 9}
|
5 |
+
]
|
6 |
+
import copy
|
7 |
+
for item in a:
|
8 |
+
new_item = copy.deepcopy(item)
|
9 |
+
new_item["a"] = 10
|
10 |
+
del new_item["b"]
|
11 |
+
# print(item)
|
12 |
+
print(a)
|