from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import torch


class ImagePromptModel:
    def __init__(self):
        # Load the pretrained BLIP captioning processor and model once at construction time.
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

    def generate_prompt(self, image_path):
        # Load the image and normalize it to RGB so the processor receives three channels.
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")
        # Generate the caption without tracking gradients (inference only).
        with torch.no_grad():
            out = self.model.generate(**inputs)
        # Decode the generated token ids into a plain-text caption to use as the prompt.
        return self.processor.decode(out[0], skip_special_tokens=True)
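
# A minimal usage sketch. The image path "example.jpg" is a hypothetical
# placeholder for illustration, not part of the original code.
if __name__ == "__main__":
    captioner = ImagePromptModel()
    prompt = captioner.generate_prompt("example.jpg")
    print(prompt)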