import os

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

from log_util import logger
from time_it import time_it_async

# Load configuration from a local .env file into the process environment.
load_dotenv()
LLAMA_API_BASE_URL = os.getenv('LLAMA_API_BASE_URL')
LLAMA_API_KEY = os.getenv('LLAMA_API_KEY')
LLAMA_TEXT_MODEL = os.getenv('LLAMA_TEXT_MODEL')
# int()/float() raise if the variable is unset, so these are effectively required.
LLAMA_TEXT_MAX_TOKENS = int(os.getenv('LLAMA_TEXT_MAX_TOKENS'))
LLAMA_VISION_MODEL = os.getenv('LLAMA_VISION_MODEL')
LLAMA_VISION_MAX_TOKENS = int(os.getenv('LLAMA_VISION_MAX_TOKENS'))
TEMPERATURE = float(os.getenv('TEMPERATURE', '0.2'))

# Cap completions at 75% of the configured token limit, leaving the
# remaining 25% of the budget as headroom for the prompt.
TOKEN_FACTOR_FOR_PROMPT = 0.75
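
# Example .env for reference (all values below are illustrative placeholders,
# not the project's real configuration):
#   LLAMA_API_BASE_URL=http://localhost:8000/v1
#   LLAMA_API_KEY=changeme
#   LLAMA_TEXT_MODEL=llama-3.1-8b-instruct
#   LLAMA_TEXT_MAX_TOKENS=4096
#   LLAMA_VISION_MODEL=llama-3.2-11b-vision-instruct
#   LLAMA_VISION_MAX_TOKENS=4096
#   TEMPERATURE=0.2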
def get_text_llm() -> ChatOpenAI:
    return _get_llm(LLAMA_TEXT_MODEL, LLAMA_TEXT_MAX_TOKENS)


def get_vision_llm() -> ChatOpenAI:
    return _get_llm(LLAMA_VISION_MODEL, LLAMA_VISION_MAX_TOKENS)


def _get_llm(model: str, max_tokens: int) -> ChatOpenAI:
    return ChatOpenAI(base_url=LLAMA_API_BASE_URL, api_key=LLAMA_API_KEY,
                      model=model, max_tokens=int(max_tokens * TOKEN_FACTOR_FOR_PROMPT),
                      temperature=TEMPERATURE)
@time_it_async
async def invoke_llm_async(chain, input: dict | None = None):
    # Default to None instead of a mutable {} default; substitute per call.
    response = await chain.ainvoke(input or {})
    logger.info(f'{response=}')
    return response
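
# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal LCEL pipeline showing how these helpers fit together; the prompt
# text and input value below are assumptions for demonstration only.
if __name__ == '__main__':
    import asyncio

    from langchain_core.prompts import ChatPromptTemplate

    async def main():
        prompt = ChatPromptTemplate.from_template('Summarize in one sentence: {text}')
        chain = prompt | get_text_llm()  # prompt piped into the chat model
        response = await invoke_llm_async(chain, {'text': 'LangChain composes LLM calls into chains.'})
        print(response.content)

    asyncio.run(main())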