"""Factory helpers for Llama-backed ChatOpenAI clients and async chain invocation."""

import os

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

from log_util import logger
from time_it import time_it_async

load_dotenv()

# Connection settings for the OpenAI-compatible Llama endpoint.
LLAMA_API_BASE_URL = os.getenv('LLAMA_API_BASE_URL')
LLAMA_API_KEY = os.getenv('LLAMA_API_KEY')

# Model names and their token budgets. The *_MAX_TOKENS variables are
# required: int(None) raises a TypeError at import time if they are unset.
LLAMA_TEXT_MODEL = os.getenv('LLAMA_TEXT_MODEL')
LLAMA_TEXT_MAX_TOKENS = int(os.getenv('LLAMA_TEXT_MAX_TOKENS'))

LLAMA_VISION_MODEL = os.getenv('LLAMA_VISION_MODEL')
LLAMA_VISION_MAX_TOKENS = int(os.getenv('LLAMA_VISION_MAX_TOKENS'))

TEMPERATURE = float(os.getenv('TEMPERATURE', '0.2'))

# Only this fraction of the configured token budget is granted to the
# completion, keeping the remaining context window free for the prompt.
TOKEN_FACTOR_FOR_PROMPT = 0.75
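
# A .env file along these lines is assumed (all values below are
# illustrative placeholders, not taken from the original source):
#
#   LLAMA_API_BASE_URL=http://localhost:8000/v1
#   LLAMA_API_KEY=changeme
#   LLAMA_TEXT_MODEL=llama-3.1-8b-instruct
#   LLAMA_TEXT_MAX_TOKENS=8192
#   LLAMA_VISION_MODEL=llama-3.2-11b-vision-instruct
#   LLAMA_VISION_MAX_TOKENS=4096
#   TEMPERATURE=0.2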

def get_text_llm() -> ChatOpenAI:
    """Return a chat client configured for the text model."""
    return _get_llm(LLAMA_TEXT_MODEL, LLAMA_TEXT_MAX_TOKENS)

def get_vision_llm() -> ChatOpenAI:
    """Return a chat client configured for the vision model."""
    return _get_llm(LLAMA_VISION_MODEL, LLAMA_VISION_MAX_TOKENS)

def _get_llm(model: str, max_tokens: int) -> ChatOpenAI:
    return ChatOpenAI(
        base_url=LLAMA_API_BASE_URL,
        api_key=LLAMA_API_KEY,
        model=model,
        max_tokens=int(max_tokens * TOKEN_FACTOR_FOR_PROMPT),
        temperature=TEMPERATURE,
    )

@time_it_async
async def invoke_llm_async(chain, input: dict | None = None):
    """Invoke a LangChain runnable asynchronously and log its response."""
    # Default to None instead of a mutable {} default argument; fall back
    # to an empty dict at call time.
    response = await chain.ainvoke(input or {})
    logger.info(f'{response=}')
    return response
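
# Usage sketch (not part of the original module): wires get_text_llm() into
# a minimal prompt chain and runs it through invoke_llm_async(). The prompt
# template and question are illustrative assumptions.
if __name__ == '__main__':
    import asyncio

    from langchain_core.prompts import ChatPromptTemplate

    async def _demo() -> None:
        prompt = ChatPromptTemplate.from_template('Answer briefly: {question}')
        # Piping a prompt into the model yields a runnable chain whose
        # ainvoke() accepts the template variables as a dict.
        chain = prompt | get_text_llm()
        await invoke_llm_async(chain, {'question': 'What does this module do?'})

    asyncio.run(_demo())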