File size: 1,457 Bytes
96f6720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os

from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
from openai import OpenAI

from log_util import logger
from time_it import time_it
from util import load_prompt

# Populate os.environ from a .env file (if one is found) before reading settings.
load_dotenv()

# Endpoint and credentials for the OpenAI-compatible image generation API.
IMAGE_GEN_API_BASE_URL = os.getenv('IMAGE_GEN_API_BASE_URL')
IMAGE_GEN_API_KEY = os.getenv('IMAGE_GEN_API_KEY')

IMAGE_GEN_MODEL = os.getenv('IMAGE_GEN_MODEL')
# Required: maximum prompt length in characters; longer prompts are truncated.
# os.environ[...] raises a KeyError that names the missing variable, instead of
# the opaque "int() argument must be ... not 'NoneType'" from int(os.getenv(...)).
IMAGE_GEN_MAX_PROMPT_LEN = int(os.environ['IMAGE_GEN_MAX_PROMPT_LEN'])

# Provider-specific request fields. generate_image() forwards them verbatim as
# `extra_body`, so their exact meaning is defined by the backend service.
IMAGE_GEN_OPTIONS = {
    'response_extension': 'png',
    'width': 1024,
    'height': 1024,
    'num_inference_steps': int(os.getenv('NUM_INFERENCE_STEPS', '16')),
    'negative_prompt': '',
    'seed': -1,  # NOTE(review): presumably means "pick a random seed" on the backend — confirm
}

@time_it
def generate_image(prompt_file: str, input: dict) -> str:
    """Render a prompt template and generate one image from it.

    Loads the template via ``load_prompt(prompt_file)``, fills it with
    ``input`` using LangChain's ``PromptTemplate``, truncates the rendered
    text to ``IMAGE_GEN_MAX_PROMPT_LEN`` characters, and requests an image
    from the configured OpenAI-compatible endpoint.

    Args:
        prompt_file: Identifier passed to ``load_prompt`` to fetch the template.
        input: Values for the template's placeholders. (The name shadows the
            built-in ``input`` but is kept so keyword-argument callers work.)

    Returns:
        The URL of the first image in the API response.
    """
    prompt_template = PromptTemplate.from_template(load_prompt(prompt_file))
    # Render first, then truncate: the limit must apply to the text the API
    # actually receives. Truncating the raw template could split a
    # `{placeholder}` (making from_template raise ValueError), and substituted
    # values could still push the final prompt past the limit.
    prompt = prompt_template.invoke(input).to_string()

    if len(prompt) > IMAGE_GEN_MAX_PROMPT_LEN:
        logger.info(f'Prompt length {len(prompt)} exceeds {IMAGE_GEN_MAX_PROMPT_LEN} characters, will be truncated.')
        prompt = prompt[:IMAGE_GEN_MAX_PROMPT_LEN]

    images_client = OpenAI(base_url=IMAGE_GEN_API_BASE_URL, api_key=IMAGE_GEN_API_KEY).images
    # IMAGE_GEN_OPTIONS carries non-standard fields, hence extra_body.
    response = images_client.generate(
        model=IMAGE_GEN_MODEL,
        prompt=prompt,
        response_format='url',
        extra_body=IMAGE_GEN_OPTIONS,
    )
    image_url = response.data[0].url
    logger.info(f'{image_url=}')
    return image_url