from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image, ImageDraw, ImageFont
import re
def draw_circle(draw, center, radius=10, width=2, outline_color=(0, 255, 0), is_fill=False, bg_color=(0, 255, 0), transparency=80):
    """Draw a circle of ``radius`` centred on ``center`` using ``draw.ellipse``.

    Args:
        draw: Object exposing ``ellipse(bbox, **style)`` (an ``ImageDraw.Draw``
            in this script).
        center: Sequence whose first two items are the x and y of the centre.
        radius: Distance from the centre to the bounding-box edges.
        width: Outline stroke width.
        outline_color: Outline colour.
        is_fill: When true, also fill the circle with ``bg_color``.
        bg_color: Fill colour channels; an alpha value is appended to them.
        transparency: Percentage (0 = opaque, 100 = fully transparent) that
            is converted to the appended alpha value ``int((1 - t/100) * 255)``.
    """
    cx, cy = center[0], center[1]
    bbox = (cx - radius, cy - radius, cx + radius, cy + radius)

    style = {"width": width, "outline": outline_color}
    if is_fill:
        # NOTE(review): per Pillow's ImageDraw docs the alpha only blends when
        # `draw` was created in "RGBA" mode; on a plain RGB draw the fill is
        # presumably opaque — confirm before relying on `transparency`.
        opacity = int((1 - transparency / 100) * 255)
        style["fill"] = tuple(bg_color) + (opacity,)
    draw.ellipse(bbox, **style)
def draw_point(draw, center, radius1=3, radius2=6, color=(0, 255, 0)):
    """Mark ``center`` with two concentric circle outlines.

    The inner circle (``radius1``) is drawn first, then the outer one
    (``radius2``), both in ``color`` via ``draw_circle``'s defaults otherwise.
    """
    for ring_radius in (radius1, radius2):
        draw_circle(draw, center, radius=ring_radius, outline_color=color)
def draw_rectangle(draw, box_coords, width=2, outline_color=(0, 255, 0), is_fill=False, bg_color=(0, 255, 0), transparency=80):
    """Draw a rectangle over ``box_coords`` using ``draw.rectangle``.

    Args:
        draw: Object exposing ``rectangle(box, width=..., outline=..., fill=...)``
            (an ``ImageDraw.Draw`` in this script).
        box_coords: Rectangle coordinates, passed through unchanged.
        width: Outline stroke width.
        outline_color: Outline colour.
        is_fill: When true, also fill the rectangle with ``bg_color``.
        bg_color: Fill colour channels; an alpha value is appended to them.
        transparency: Percentage (0 = opaque, 100 = fully transparent) that
            is converted to the appended alpha value ``int((1 - t/100) * 255)``.
    """
    if not is_fill:
        draw.rectangle(box_coords, width=width, outline=outline_color)
        return

    # NOTE(review): the alpha presumably only takes effect on an "RGBA"-mode
    # ImageDraw; on an RGB draw the fill is likely opaque — confirm.
    opacity = int((1 - transparency / 100) * 255)
    rgba = (*bg_color, opacity)
    draw.rectangle(box_coords, width=width, outline=outline_color, fill=rgba)
def _parse_coords(text, count):
    """Return ``count`` floats from comma-separated ``text``, ignoring parentheses.

    Raises:
        ValueError: if a field is not a number or there are not exactly
            ``count`` fields.
    """
    values = [float(v) for v in text.replace("(", "").replace(")", "").split(",")]
    if len(values) != count:
        raise ValueError(f"expected {count} values, got {len(values)}: {text!r}")
    return values


def draw(path, out_path, response):
    """Draw the boxes and points tagged in ``response`` onto an image and save it.

    ``<box>...</box>`` tags must contain four comma-separated numbers
    (parentheses are ignored), read as x1, y1, x2, y2. ``<point>...</point>``
    tags must contain two, read as x, y. Coordinates are treated as
    normalised to 0-1000 and scaled to the image's pixel width/height.
    Tags that cannot be parsed are reported on stdout and skipped.

    Args:
        path: Input image path; the image is converted to RGB.
        out_path: Where the annotated image is saved.
        response: Model output text to scan for tags.
    """
    # Copy the pixels out, then release the file handle immediately.
    with Image.open(path) as src:
        img = src.convert("RGB")
    # Named `canvas` so it does not shadow this module-level `draw` function.
    canvas = ImageDraw.Draw(img)

    for box in re.findall(r"<box>(.*?)</box>", response):
        try:
            x1, y1, x2, y2 = _parse_coords(box, 4)
        except ValueError:
            print("There were some errors while parsing the bounding box.")
            continue
        x1, x2 = x1 * img.width / 1000, x2 * img.width / 1000
        y1, y2 = y1 * img.height / 1000, y2 * img.height / 1000
        # Pillow rejects rectangles whose corners are not ordered
        # (x0 <= x1, y0 <= y1); normalise instead of dropping the box.
        left, right = sorted((x1, x2))
        top, bottom = sorted((y1, y2))
        draw_rectangle(canvas, (left, top, right, bottom))

    for point in re.findall(r"<point>(.*?)</point>", response):
        try:
            x, y = _parse_coords(point, 2)
        except ValueError:
            print("There were some errors while parsing the bounding point.")
            continue
        draw_point(canvas, (x * img.width / 1000, y * img.height / 1000))

    img.save(out_path)
def load_model_and_tokenizer(path, device):
    """Load the tokenizer and causal-LM model stored at ``path``.

    The tokenizer is loaded first, then the model is placed according to
    ``device`` (forwarded as ``device_map``) and switched to eval mode.

    Note: ``trust_remote_code=True`` lets the checkpoint run its own Python
    code on load, so only point ``path`` at trusted repositories.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    loaded = AutoModelForCausalLM.from_pretrained(
        path,
        device_map=device,
        trust_remote_code=True,
    )
    model = loaded.eval()
    return model, tokenizer
def infer(model, tokenizer, image_path, text):
    """Ask ``model`` one question about one image and return its reply text.

    The image and the text are packed into a single query with
    ``tokenizer.from_list_format``. The query is sent through
    ``model.chat`` with no prior history. Any history that ``chat`` returns
    is discarded.
    """
    message_parts = [{'image': image_path}, {'text': text}]
    query = tokenizer.from_list_format(message_parts)
    reply, _history = model.chat(tokenizer, query=query, history=None)
    return reply
if __name__ == "__main__":
    # Interactive loop: read an image path and a question, print the model's
    # reply, and save the image annotated with any boxes/points to `out_path`.
    # Typing "stop" at either prompt (or closing stdin) ends the session.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"  # CPU fallback instead of crashing
    model_path = "<your_model_path>"
    out_path = "1.jpg"
    model, tokenizer = load_model_and_tokenizer(model_path, device)
    while True:
        try:
            image_path = input("image path >>>>> ").strip()
            if image_path == "stop":
                break
            query = input("Human:")
            if query.strip() == "stop":
                break
        except EOFError:
            # Ctrl-D / closed stdin: end the session just like "stop".
            break
        response = infer(model, tokenizer, image_path, query)
        # The reply used to be drawn on the image but never shown to the user.
        print(f"Assistant: {response}")
        try:
            draw(image_path, out_path, response)
        except OSError as err:
            # Missing/unreadable image: report it and keep the session alive.
            print(f"Could not annotate {image_path}: {err}")
        else:
            print(f"Annotated image saved to {out_path}")