grademe / app.py
vverma
fixed requirements.txt
7ad0578
raw
history blame
1.26 kB
import functools

from fastapi import FastAPI
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# FastAPI application instance; route handlers in this file register on it
# via decorators such as @app.get("/").
app = FastAPI()
@functools.lru_cache(maxsize=1)
def _load_latex_ocr(checkpoint: str = "tjoab/latex_finetuned"):
    """Load the TrOCR processor and model for ``checkpoint`` once and reuse them.

    Cached so that only the first request pays the cost of downloading /
    deserializing the weights; later calls return the same objects.

    Returns:
        A ``(processor, model)`` tuple.
    """
    print("Loading model and processor...")
    processor = TrOCRProcessor.from_pretrained(checkpoint)
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    return processor, model


@app.get("/")
def greet_json():
    """Run LaTeX OCR on the bundled ``sample.png`` and return the prediction.

    Returns:
        A dict with ``"message"`` set to ``"Success"`` and ``"latex_preds"``,
        the list of strings produced by ``processor.batch_decode``.
    """
    # Previously the model was reloaded on every request; the cached loader
    # makes this a one-time cost per process.
    processor, model = _load_latex_ocr()
    sample_image = open_PIL_image("sample.png")
    # The image processor expects a list of images; this batch has one entry.
    pixel_values = processor.image_processor(
        images=[sample_image], return_tensors="pt"
    ).pixel_values
    # NOTE: max_length default value is very small, which often results in truncated inference if not set
    pred_ids = model.generate(pixel_values, max_length=128)
    latex_preds = processor.batch_decode(pred_ids, skip_special_tokens=True)
    return {"message": "Success", "latex_preds": latex_preds}
# Helper function: open an image file (e.g. JPEG or PNG) as an RGB image.
def open_PIL_image(image_path: str) -> Image.Image:
    """Open ``image_path`` and return a fully loaded RGB copy of it.

    Transparency is detected from the image itself, not the file extension:
    an alpha band (RGBA/LA/PA) or a ``transparency`` entry in ``image.info``.
    Transparent pixels are flattened onto a white background so they do not
    turn black when converted to RGB.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        A new ``Image.Image`` in mode ``"RGB"``. The source file is closed
        before returning.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        OSError: If Pillow cannot identify or decode the file
            (``PIL.UnidentifiedImageError`` is a subclass).
    """
    with Image.open(image_path) as image:
        has_alpha = image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info
        if not has_alpha:
            # convert() loads the pixels into a new image, so it stays valid
            # after the file is closed.
            return image.convert("RGB")
        # Normalizes every transparent mode (LA, PA, P/L/RGB + tRNS) to RGBA.
        rgba = image.convert("RGBA")
    background = Image.new("RGB", rgba.size, "white")
    # Use the alpha band as the paste mask so only opaque pixels cover white.
    background.paste(rgba, mask=rgba.getchannel("A"))
    return background