zrguo committed
Commit a83e6f8 · Parents: 5c26642 02fb41e

Merge pull request #87 from Soumil32/main

Files changed (1): lightrag/llm.py +69 -0
lightrag/llm.py CHANGED
@@ -19,6 +19,8 @@ from tenacity import (
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from pydantic import BaseModel, Field
+from typing import List, Dict, Callable, Any
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs

@@ -554,6 +556,73 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
 
     return embed_text
 
+class Model(BaseModel):
+    """
+    A Pydantic model used to define a custom language model.
+
+    Attributes:
+        gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
+            The function should take any argument and return a string.
+        kwargs (Dict[str, Any]): A dictionary of arguments to pass to the callable function.
+            This could include parameters such as the model name, API key, etc.
+
+    Example usage:
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
+
+    In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
+    The 'kwargs' dictionary contains the model name and API key to be passed to the function.
+    """
+
+    gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the LLM. The response must be a string.")
+    kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function, e.g. the API key, model name, etc.")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class MultiModel:
+    """
+    Distributes the load across multiple language models. Useful for working around low rate limits
+    with certain API providers, especially on free tiers.
+    Can also be used to split load across different models or providers.
+
+    Attributes:
+        models (List[Model]): A list of language models to be used.
+
+    Usage example:
+    ```python
+    models = [
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
+        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
+    ]
+    multi_model = MultiModel(models)
+    rag = LightRAG(
+        llm_model_func=multi_model.llm_model_func,
+        # ...other args
+    )
+    ```
+    """
+
+    def __init__(self, models: List[Model]):
+        self._models = models
+        self._current_model = 0
+
+    def _next_model(self):
+        # Round-robin: advance the index, then return the model it points to.
+        self._current_model = (self._current_model + 1) % len(self._models)
+        return self._models[self._current_model]
+
+    async def llm_model_func(
+        self, prompt, system_prompt=None, history_messages=[], **kwargs
+    ) -> str:
+        kwargs.pop("model", None)  # stop callers from overwriting the custom model name
+        next_model = self._next_model()
+        args = dict(
+            prompt=prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs,
+            **next_model.kwargs,
+        )
+
+        return await next_model.gen_func(**args)
 
 if __name__ == "__main__":
     import asyncio
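
For reference, here is a minimal sketch of the round-robin dispatch that `MultiModel.llm_model_func` performs. The `fake_complete` stub is hypothetical and merely stands in for a real completion function such as `openai_complete_if_cache`; the import assumes a `lightrag` checkout that includes this merge.

```python
import asyncio

from lightrag.llm import Model, MultiModel  # classes added by this merge


async def fake_complete(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    # Hypothetical stub: a real gen_func would call an LLM API here.
    # It echoes which "api_key" handled the request so the rotation is visible.
    return f"[{kwargs['api_key']}] {prompt}"


async def main():
    # Three stub models, as if each had its own rate-limited API key.
    models = [
        Model(gen_func=fake_complete, kwargs={"api_key": f"KEY_{i}"})
        for i in range(3)
    ]
    multi_model = MultiModel(models)

    # _next_model advances before returning, so successive calls rotate
    # through KEY_1, KEY_2, KEY_0, KEY_1, ...
    for i in range(4):
        print(await multi_model.llm_model_func(f"question {i}"))


asyncio.run(main())
```

Note that the rotation state lives in `_current_model`, so a single `MultiModel` instance must be shared across requests; constructing a fresh instance per call would send every request to the same model.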