🐛 Bug: Fix the bug where weight polling did not check if the weight channel conforms to the request model.
Browse files
main.py
CHANGED
@@ -647,98 +647,105 @@ def lottery_scheduling(weights):
|
|
647 |
break
|
648 |
return selections
|
649 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
650 |
import asyncio
|
651 |
class ModelRequestHandler:
|
652 |
def __init__(self):
|
653 |
self.last_provider_indices = defaultdict(lambda: -1)
|
654 |
self.locks = defaultdict(asyncio.Lock)
|
655 |
|
656 |
-
def
|
657 |
config = app.state.config
|
658 |
-
# api_keys_db = app.state.api_keys_db
|
659 |
api_list = app.state.api_list
|
660 |
api_index = api_list.index(token)
|
|
|
661 |
if not safe_get(config, 'api_keys', api_index, 'model'):
|
662 |
raise HTTPException(status_code=404, detail="No matching model found")
|
663 |
-
provider_rules = []
|
664 |
-
|
665 |
-
for model in config['api_keys'][api_index]['model']:
|
666 |
-
if model == "all":
|
667 |
-
# 如果模型名为 all,则返回所有模型
|
668 |
-
for provider in config["providers"]:
|
669 |
-
model_dict = get_model_dict(provider)
|
670 |
-
for model in model_dict.keys():
|
671 |
-
provider_rules.append(provider["provider"] + "/" + model)
|
672 |
-
break
|
673 |
-
if "/" in model:
|
674 |
-
if model.startswith("<") and model.endswith(">"):
|
675 |
-
model = model[1:-1]
|
676 |
-
# 处理带斜杠的模型名
|
677 |
-
for provider in config['providers']:
|
678 |
-
model_dict = get_model_dict(provider)
|
679 |
-
if model in model_dict.keys():
|
680 |
-
provider_rules.append(provider['provider'] + "/" + model)
|
681 |
-
else:
|
682 |
-
provider_name = model.split("/")[0]
|
683 |
-
model_name_split = "/".join(model.split("/")[1:])
|
684 |
-
models_list = []
|
685 |
-
for provider in config['providers']:
|
686 |
-
model_dict = get_model_dict(provider)
|
687 |
-
if provider['provider'] == provider_name:
|
688 |
-
models_list.extend(list(model_dict.keys()))
|
689 |
-
# print("models_list", models_list)
|
690 |
-
# print("model_name", model_name)
|
691 |
-
# print("model_name_split", model_name_split)
|
692 |
-
# print("model", model)
|
693 |
-
|
694 |
-
# api_keys 中 model 为 provider_name/* 时,表示所有模型都匹配
|
695 |
-
if model_name_split == "*":
|
696 |
-
if model_name in models_list:
|
697 |
-
provider_rules.append(provider_name + "/" + model_name)
|
698 |
-
|
699 |
-
# 如果请求模型名: gpt-4* ,则匹配所有以模型名开头且不以 * 结尾的模型
|
700 |
-
for models_list_model in models_list:
|
701 |
-
if model_name.endswith("*") and models_list_model.startswith(model_name.rstrip("*")):
|
702 |
-
provider_rules.append(provider_name + "/" + models_list_model)
|
703 |
-
|
704 |
-
# api_keys 中 model 为 provider_name/model_name 时,表示模型名完全匹配
|
705 |
-
elif model_name_split == model_name \
|
706 |
-
or (model_name.endswith("*") and model_name_split.startswith(model_name.rstrip("*"))): # api_keys 中 model 为 provider_name/model_name 时,请求模型名: model_name*
|
707 |
-
if model_name_split in models_list:
|
708 |
-
provider_rules.append(provider_name + "/" + model_name_split)
|
709 |
-
|
710 |
-
else:
|
711 |
-
for provider in config["providers"]:
|
712 |
-
model_dict = get_model_dict(provider)
|
713 |
-
if model in model_dict.keys():
|
714 |
-
provider_rules.append(provider["provider"] + "/" + model)
|
715 |
-
|
716 |
-
provider_list = []
|
717 |
-
# print("provider_rules", provider_rules)
|
718 |
-
for item in provider_rules:
|
719 |
-
for provider in config['providers']:
|
720 |
-
if "/" in item and provider['provider'] == item.split("/")[0]:
|
721 |
-
new_provider = copy.deepcopy(provider)
|
722 |
-
model_dict = get_model_dict(provider)
|
723 |
-
model_name_split = "/".join(item.split("/")[1:])
|
724 |
-
# old: new
|
725 |
-
new_provider["model"] = [{model_dict[model_name_split]: model_name}]
|
726 |
-
if model_name in model_dict.keys() and model_name_split == model_name:
|
727 |
-
provider_list.append(new_provider)
|
728 |
-
|
729 |
-
elif model_name.endswith("*") and model_name_split.startswith(model_name.rstrip("*")):
|
730 |
-
provider_list.append(new_provider)
|
731 |
-
|
732 |
-
# print("provider_list", provider_list)
|
733 |
-
return provider_list
|
734 |
-
|
735 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
736 |
-
config = app.state.config
|
737 |
-
api_list = app.state.api_list
|
738 |
-
api_index = api_list.index(token)
|
739 |
|
740 |
-
|
741 |
-
matching_providers =
|
742 |
num_matching_providers = len(matching_providers)
|
743 |
|
744 |
if not matching_providers:
|
@@ -757,6 +764,13 @@ class ModelRequestHandler:
|
|
757 |
intersection = None
|
758 |
if weights and all_providers:
|
759 |
weight_keys = set(weights.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
760 |
# 步骤 3: 计算交集
|
761 |
intersection = all_providers.intersection(weight_keys)
|
762 |
|
@@ -769,6 +783,7 @@ class ModelRequestHandler:
|
|
769 |
weighted_provider_name_list = lottery_scheduling(weights)
|
770 |
else:
|
771 |
weighted_provider_name_list = list(weights.keys())
|
|
|
772 |
|
773 |
new_matching_providers = []
|
774 |
for provider_name in weighted_provider_name_list:
|
@@ -786,9 +801,9 @@ class ModelRequestHandler:
|
|
786 |
|
787 |
start_index = 0
|
788 |
if scheduling_algorithm != "fixed_priority":
|
789 |
-
async with self.locks[
|
790 |
-
self.last_provider_indices[
|
791 |
-
start_index = self.last_provider_indices[
|
792 |
|
793 |
auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
|
794 |
|
|
|
647 |
break
|
648 |
return selections
|
649 |
|
650 |
+
def get_provider_rules(model_rule, config, request_model):
    """Expand one model rule into a list of "provider/model" strings.

    Args:
        model_rule: One rule string. The forms handled are ``"all"``,
            ``"<name/with/slash>"``, ``"provider_name/*"``,
            ``"provider_name/model_name"`` and a bare model name.
        config: Configuration mapping; ``config["providers"]`` is iterated and
            each entry is read through ``entry["provider"]`` and
            ``get_model_dict(entry)``.
        request_model: Model name taken from the incoming request. A trailing
            ``*`` makes it a prefix pattern.

    Returns:
        A list of ``"<provider>/<model>"`` strings, possibly empty.

    NOTE(review): ``get_model_dict`` is defined elsewhere; it is used here as a
    mapping whose keys are model names -- confirm against its definition.
    """
    # Rule "all": every model of every provider qualifies.
    if model_rule == "all":
        return [
            entry["provider"] + "/" + exposed_name
            for entry in config["providers"]
            for exposed_name in get_model_dict(entry)
        ]

    # Bare model name: keep each provider whose model mapping contains it.
    if "/" not in model_rule:
        return [
            entry["provider"] + "/" + model_rule
            for entry in config["providers"]
            if model_rule in get_model_dict(entry)
        ]

    # "<...>" wraps a model name that itself contains a slash; strip the
    # brackets and treat the remainder as a literal model name.
    if model_rule.startswith("<") and model_rule.endswith(">"):
        literal_name = model_rule[1:-1]
        return [
            entry['provider'] + "/" + literal_name
            for entry in config['providers']
            if literal_name in get_model_dict(entry)
        ]

    # Remaining form: "provider_name/<model part>".
    provider_name, _, model_name_split = model_rule.partition("/")

    # Collect the model names of every provider entry carrying that name.
    models_list = []
    for entry in config['providers']:
        exposed = get_model_dict(entry)
        if entry['provider'] == provider_name:
            models_list.extend(exposed)

    rules = []

    # "provider_name/*": any model of that provider may match.
    if model_name_split == "*":
        if request_model in models_list:
            rules.append(provider_name + "/" + request_model)

        # A request such as "gpt-4*" selects every model of the provider whose
        # name starts with the part before the trailing asterisks.
        rules.extend(
            provider_name + "/" + candidate
            for candidate in models_list
            if request_model.endswith("*")
            and candidate.startswith(request_model.rstrip("*"))
        )
        return rules

    # "provider_name/model_name": the rule applies when the request names that
    # model exactly, or when the request is a "prefix*" pattern covering it.
    applies = model_name_split == request_model or (
        request_model.endswith("*")
        and model_name_split.startswith(request_model.rstrip("*"))
    )
    if applies and model_name_split in models_list:
        rules.append(provider_name + "/" + model_name_split)

    return rules
|
703 |
+
|
704 |
+
def get_provider_list(provider_rules, config, request_model):
    """Build provider entries for the rules that match ``request_model``.

    For each ``"provider/model"`` rule, every entry of ``config['providers']``
    whose ``'provider'`` value equals the rule's provider part is considered.
    The entry is selected when the rule's model part equals ``request_model``,
    or when ``request_model`` ends with ``*`` and the model part starts with
    the text before the trailing asterisks.

    Each selected entry is a deep copy of the configured provider whose
    ``"model"`` key is replaced by a one-item list
    ``[{model_dict[<rule model>]: request_model}]``.

    Args:
        provider_rules: Iterable of ``"provider/model"`` strings; items
            without a ``/`` are ignored.
        config: Configuration mapping with a ``'providers'`` list.
        request_model: Model name from the incoming request.

    Returns:
        A list of copied provider dicts, in rule order then provider order.

    NOTE(review): ``get_model_dict`` is defined elsewhere; it is used here as a
    mapping from model name to the value placed as the key of the new
    ``"model"`` entry -- confirm its exact semantics there.
    """
    provider_list = []
    for item in provider_rules:
        if "/" not in item:
            continue
        provider_name, _, model_name_split = item.partition("/")

        for provider in config['providers']:
            if provider['provider'] != provider_name:
                continue

            model_dict = get_model_dict(provider)
            # A same-named provider entry may not expose this model; skip it
            # instead of failing on the mapping lookup below.
            if model_name_split not in model_dict:
                continue

            matches = model_name_split == request_model or (
                request_model.endswith("*")
                and model_name_split.startswith(request_model.rstrip("*"))
            )
            if not matches:
                continue

            # Copy only once the match is established: deep-copying the
            # provider for every non-matching rule is wasted work.
            new_provider = copy.deepcopy(provider)
            # old: new
            new_provider["model"] = [{model_dict[model_name_split]: request_model}]
            provider_list.append(new_provider)
    return provider_list
|
721 |
+
|
722 |
+
def get_matching_providers(request_model, config, api_index):
    """Return the provider entries usable for ``request_model`` by one API key.

    Every model rule listed under ``config['api_keys'][api_index]['model']``
    is expanded with ``get_provider_rules``; the combined rules are then
    turned into provider entries by ``get_provider_list``.

    Args:
        request_model: Model name from the incoming request.
        config: Configuration mapping with ``'api_keys'`` and ``'providers'``.
        api_index: Index of the API key inside ``config['api_keys']``.

    Returns:
        Whatever ``get_provider_list`` returns for the combined rules.
    """
    key_rules = config['api_keys'][api_index]['model']
    expanded_rules = [
        rule
        for model_rule in key_rules
        for rule in get_provider_rules(model_rule, config, request_model)
    ]
    return get_provider_list(expanded_rules, config, request_model)
|
732 |
+
|
733 |
import asyncio
|
734 |
class ModelRequestHandler:
|
735 |
def __init__(self):
|
736 |
self.last_provider_indices = defaultdict(lambda: -1)
|
737 |
self.locks = defaultdict(asyncio.Lock)
|
738 |
|
739 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
|
740 |
config = app.state.config
|
|
|
741 |
api_list = app.state.api_list
|
742 |
api_index = api_list.index(token)
|
743 |
+
|
744 |
if not safe_get(config, 'api_keys', api_index, 'model'):
|
745 |
raise HTTPException(status_code=404, detail="No matching model found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
746 |
|
747 |
+
request_model = request.model
|
748 |
+
matching_providers = get_matching_providers(request_model, config, api_index)
|
749 |
num_matching_providers = len(matching_providers)
|
750 |
|
751 |
if not matching_providers:
|
|
|
764 |
intersection = None
|
765 |
if weights and all_providers:
|
766 |
weight_keys = set(weights.keys())
|
767 |
+
provider_rules = []
|
768 |
+
for model_rule in weight_keys:
|
769 |
+
provider_rules.extend(get_provider_rules(model_rule, config, request_model))
|
770 |
+
provider_list = get_provider_list(provider_rules, config, request_model)
|
771 |
+
weight_keys = set([provider['provider'] for provider in provider_list])
|
772 |
+
# print("all_providers", all_providers)
|
773 |
+
# print("weights", weight_keys)
|
774 |
# 步骤 3: 计算交集
|
775 |
intersection = all_providers.intersection(weight_keys)
|
776 |
|
|
|
783 |
weighted_provider_name_list = lottery_scheduling(weights)
|
784 |
else:
|
785 |
weighted_provider_name_list = list(weights.keys())
|
786 |
+
# print("weighted_provider_name_list", weighted_provider_name_list)
|
787 |
|
788 |
new_matching_providers = []
|
789 |
for provider_name in weighted_provider_name_list:
|
|
|
801 |
|
802 |
start_index = 0
|
803 |
if scheduling_algorithm != "fixed_priority":
|
804 |
+
async with self.locks[request_model]:
|
805 |
+
self.last_provider_indices[request_model] = (self.last_provider_indices[request_model] + 1) % num_matching_providers
|
806 |
+
start_index = self.last_provider_indices[request_model]
|
807 |
|
808 |
auto_retry = safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY", default=True)
|
809 |
|
utils.py
CHANGED
@@ -109,9 +109,9 @@ def update_config(config_data, use_config_url=False):
|
|
109 |
for model in api_key.get('model'):
|
110 |
if isinstance(model, dict):
|
111 |
key, value = list(model.items())[0]
|
112 |
-
provider_name = key.split("/")[0]
|
113 |
if "/" in key:
|
114 |
-
weights_dict.update({
|
115 |
models.append(key)
|
116 |
if isinstance(model, str):
|
117 |
models.append(model)
|
|
|
109 |
for model in api_key.get('model'):
|
110 |
if isinstance(model, dict):
|
111 |
key, value = list(model.items())[0]
|
112 |
+
# provider_name = key.split("/")[0]
|
113 |
if "/" in key:
|
114 |
+
weights_dict.update({key: int(value)})
|
115 |
models.append(key)
|
116 |
if isinstance(model, str):
|
117 |
models.append(model)
|