vverma committed on
Commit
2661513
·
1 Parent(s): 7ad0578

created api

Browse files
Files changed (3) hide show
  1. __pycache__/app.cpython-39.pyc +0 -0
  2. app.py +21 -22
  3. requirements.txt +2 -1
__pycache__/app.cpython-39.pyc ADDED
Binary file (1.36 kB). View file
 
app.py CHANGED
@@ -1,33 +1,32 @@
1
- from fastapi import FastAPI
 
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
 
4
 
5
  app = FastAPI()
6
 
7
- @app.get("/")
8
- def greet_json():
9
- # Load model and processor from Hugging Face
10
- print("Loading model and processor...")
11
- processor = TrOCRProcessor.from_pretrained('tjoab/latex_finetuned')
12
- model = VisionEncoderDecoderModel.from_pretrained('tjoab/latex_finetuned')
13
 
14
- # Load all images as a batch
15
- sample_image = open_PIL_image("sample.png")
 
 
16
 
17
- # Preprocess the images
18
- preproc_image = processor.image_processor(images=[sample_image], return_tensors="pt").pixel_values
 
 
19
 
20
- # Generate and decode the tokens
21
- # NOTE: max_length default value is very small, which often results in truncated inference if not set
22
- pred_ids = model.generate(preproc_image, max_length=128)
23
  latex_preds = processor.batch_decode(pred_ids, skip_special_tokens=True)
24
- return {"message": "Success", "latex_preds": latex_preds}
25
 
 
26
 
27
-
28
- # Helper function (path to either JPEG or PNG)
29
- def open_PIL_image(image_path: str) -> Image.Image:
30
- image = Image.open(image_path)
31
- if image_path.split('.')[-1].lower() == 'png':
32
- image = Image.composite(image, PIL.Image.new('RGB', image.size, 'white'), image)
33
- return image
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
  from PIL import Image
5
+ import io
6
 
7
  app = FastAPI()
8
 
9
# Load the processor and model a single time at import, so that every
# request handled by this module reuses the same instances.
_MODEL_ID = 'tjoab/latex_finetuned'
processor = TrOCRProcessor.from_pretrained(_MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(_MODEL_ID)
 
 
 
12
 
13
@app.post("/predict")
async def predict_latex(file: UploadFile = File(...)):
    """Run the LaTeX OCR model on an uploaded PNG or JPEG image.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        ``{"latex_preds": [...]}`` with the decoded model output on success,
        or a 400 ``JSONResponse`` with an ``"error"`` message when the upload
        is not a supported, decodable image.
    """
    if file.content_type not in ("image/png", "image/jpeg"):
        return JSONResponse(status_code=400, content={"error": "Only PNG and JPEG files are supported."})

    # Read image contents
    contents = await file.read()

    # The content type is declared by the client, so the bytes may still not
    # be a decodable image. Pillow reports that as OSError (which includes
    # UnidentifiedImageError); answer with a 400 instead of letting it
    # surface as an unhandled server error.
    try:
        image = Image.open(io.BytesIO(contents))
        image = prepare_image(image)
    except OSError:
        return JSONResponse(status_code=400, content={"error": "Uploaded file is not a valid image."})

    # Preprocess and run inference
    inputs = processor(images=image, return_tensors="pt").pixel_values
    # max_length is set explicitly; the earlier revision of this file noted
    # that the default is small enough to truncate the prediction.
    pred_ids = model.generate(inputs, max_length=128)
    latex_preds = processor.batch_decode(pred_ids, skip_special_tokens=True)

    return {"latex_preds": latex_preds}
29
 
30
def prepare_image(image: Image.Image) -> Image.Image:
    """Return an RGB version of *image* with any transparency flattened onto white.

    Images that carry no transparency (for example an ordinary JPEG) are
    simply converted to RGB. Using such an image as its own composite mask
    is rejected by Pillow, so the alpha path is taken only when there is
    transparency to flatten.

    Args:
        image: The decoded input image, in any Pillow mode.

    Returns:
        A new RGB image of the same size.
    """
    has_alpha = image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info
    if not has_alpha:
        return image.convert("RGB")

    # Normalise every transparent variant to RGBA so the alpha band can be
    # used as the paste mask over a white background.
    rgba = image.convert("RGBA")
    background = Image.new("RGB", rgba.size, "white")
    background.paste(rgba, mask=rgba.getchannel("A"))
    return background
 
 
 
 
requirements.txt CHANGED
@@ -2,4 +2,5 @@ fastapi
2
  uvicorn[standard]
3
  transformers
4
  pillow
5
- torch
 
 
2
  uvicorn[standard]
3
  transformers
4
  pillow
5
+ torch
6
+ python-multipart