|
from functools import lru_cache

from fastapi import FastAPI
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|
|
# The FastAPI application object; the route handlers in this module register
# themselves on it via its decorators (e.g. ``@app.get``).
app: FastAPI = FastAPI()
|
|
|
# Hugging Face checkpoint holding both the TrOCR processor and the model.
_MODEL_ID = 'tjoab/latex_finetuned'


@lru_cache(maxsize=1)
def _load_model():
    """Load the TrOCR processor and model once and reuse them afterwards.

    Loading a checkpoint is slow and memory-heavy. Caching the pair means
    only the first request pays that cost, instead of every request.

    Returns:
        A ``(processor, model)`` tuple loaded from ``_MODEL_ID``.
    """
    print("Loading model and processor...")
    processor = TrOCRProcessor.from_pretrained(_MODEL_ID)
    model = VisionEncoderDecoderModel.from_pretrained(_MODEL_ID)
    return processor, model


@app.get("/")
def greet_json():
    """Run LaTeX OCR on ``sample.png`` and return the decoded predictions.

    The image path is relative, so it is resolved against the process's
    current working directory.

    Returns:
        ``{"message": "Success", "latex_preds": latex_preds}``, where
        ``latex_preds`` is the list of strings produced by
        ``processor.batch_decode`` for the generated token ids.
    """
    processor, model = _load_model()

    sample_image = open_PIL_image("sample.png")
    # The processor takes a batch of images; this request sends a batch of one.
    pixel_values = processor.image_processor(
        images=[sample_image], return_tensors="pt"
    ).pixel_values

    # Cap the decode length so a degenerate generation cannot run unbounded.
    pred_ids = model.generate(pixel_values, max_length=128)
    latex_preds = processor.batch_decode(pred_ids, skip_special_tokens=True)

    return {"message": "Success", "latex_preds": latex_preds}
|
|
|
|
|
|
|
|
|
def open_PIL_image(image_path: str) -> Image.Image:
    """Open an image file and return it as a fully loaded RGB image.

    Any transparency is flattened onto a white background. This covers an
    alpha channel as well as palette/colour-key transparency, whatever the
    file extension is. Without it, fully transparent pixels would show
    whatever colour they happen to store, instead of the white page that
    OCR expects.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A new ``RGB`` image. The pixel data is copied out, so the file
        handle opened here is closed before returning.
    """
    with Image.open(image_path) as image:
        has_transparency = (
            image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info
        )
        if has_transparency:
            # Promote every alpha-carrying mode (and palette/colour-key
            # transparency) to RGBA, then alpha-blend it over opaque white.
            rgba = image.convert("RGBA")
            white = Image.new("RGBA", rgba.size, "white")
            return Image.alpha_composite(white, rgba).convert("RGB")
        # convert() returns a loaded copy, so every caller gets the same
        # mode and the source file can be closed by the ``with``.
        return image.convert("RGB")
|
|