Merge pull request #87 from Soumil32/main
lightrag/llm.py  (+69 -0)
@@ -19,6 +19,8 @@ from tenacity import (
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from pydantic import BaseModel, Field
+from typing import List, Dict, Callable, Any
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs
 
@@ -554,6 +556,73 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
 
     return embed_text
 
+class Model(BaseModel):
+    """
+    A Pydantic model used to define a custom language model.
+
+    Attributes:
+        gen_func (Callable[[Any], str]): A callable that generates the response from the
+            language model. It should accept keyword arguments and return a string.
+        kwargs (Dict[str, Any]): A dictionary of arguments to pass to the callable,
+            e.g. the model name, API key, etc.
+
+    Example usage:
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
+
+    In this example, 'openai_complete_if_cache' is the callable that generates the response
+    from the OpenAI model, and 'kwargs' carries the model name and API key passed to it.
+    """
+
+    gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the LLM. The response must be a string.")
+    kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function, e.g. the API key, model name, etc.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class MultiModel:
+    """
+    Distributes load across multiple language models. Useful for working around low rate
+    limits with certain API providers, especially on free tiers. It can also be used to
+    split requests across different models or providers.
+
+    Attributes:
+        models (List[Model]): A list of language models to rotate through.
+
+    Usage example:
+    ```python
+    models = [
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
+    ]
+    multi_model = MultiModel(models)
+    rag = LightRAG(
+        llm_model_func=multi_model.llm_model_func,
+        # ...other args
+    )
+    ```
+    """
+
+    def __init__(self, models: List[Model]):
+        self._models = models
+        self._current_model = 0
+
+    def _next_model(self):
+        # Round-robin: advance to the next model on every call.
+        self._current_model = (self._current_model + 1) % len(self._models)
+        return self._models[self._current_model]
+
+    async def llm_model_func(
+        self, prompt, system_prompt=None, history_messages=[], **kwargs
+    ) -> str:
+        kwargs.pop("model", None)  # stop callers from overwriting the custom model name
+        next_model = self._next_model()
+        args = dict(
+            prompt=prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs,
+            **next_model.kwargs,
+        )
+
+        return await next_model.gen_func(**args)
 
 if __name__ == "__main__":
     import asyncio
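
For reference, a minimal sketch (not part of the diff) of how the round-robin rotation behaves once this change lands. The stub `fake_llm` and its `tag` kwarg are hypothetical stand-ins for a real completion function such as `openai_complete_if_cache`; each `Model`'s `kwargs` are forwarded to its `gen_func` on every call:

```python
import asyncio

from lightrag.llm import Model, MultiModel


async def fake_llm(prompt, system_prompt=None, history_messages=None, **kwargs) -> str:
    # Echo which "backend" handled the call so the rotation is visible.
    return f"[{kwargs.get('tag')}] answered: {prompt}"


async def main():
    multi_model = MultiModel([
        Model(gen_func=fake_llm, kwargs={"tag": "key-1"}),
        Model(gen_func=fake_llm, kwargs={"tag": "key-2"}),
        Model(gen_func=fake_llm, kwargs={"tag": "key-3"}),
    ])
    # _next_model() increments before returning, so successive calls
    # cycle key-2, key-3, key-1, key-2, key-3.
    for i in range(5):
        print(await multi_model.llm_model_func(f"question {i}"))


if __name__ == "__main__":
    asyncio.run(main())
```

Because the rotation state lives on the `MultiModel` instance, one instance should be shared across all calls (as in the docstring's `LightRAG(llm_model_func=multi_model.llm_model_func, ...)` example) for the load to actually spread across the configured keys.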