# Last commit: "Update app.py" (b420ebd, verified) by viarias
import io
from PIL import Image
import base64
import os
import uuid
from typing import List
from fastapi import FastAPI, APIRouter, HTTPException
from inference import Inference
import uvicorn
import logging
from typing import Optional
from types_io import ClassificationRequest, ImageData
# Root logging configuration for the whole process: INFO and above, each record
# prefixed with timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every helper and endpoint in this file.
log = logging.getLogger(__name__)
def decode_base64_image(base64_str: str) -> Optional[Image.Image]:
    """
    Decode a base64 encoded string into a fully loaded PIL Image object.

    A ``data:<mime>;base64,`` header (the data-URL form browsers produce) is
    accepted and stripped before decoding; plain base64 works as before.

    Args:
        base64_str (str): Base64 encoded image string, optionally as a data URL

    Returns:
        Optional[Image.Image]: PIL Image with its pixel data already decoded if
        successful; None if the input is not valid base64 or is not a readable
        (or is a truncated/corrupt) image.

    Note:
        Errors are never propagated: every failure is logged and reported by
        returning None.
    """
    # Strip a data-URL header. ',' is not part of the base64 alphabet, so
    # splitting at the first comma is unambiguous. The isinstance check keeps
    # bytes input (which b64decode also accepts) working as before.
    if isinstance(base64_str, str) and base64_str.startswith("data:"):
        _, _, base64_str = base64_str.partition(",")
    try:
        image_data = base64.b64decode(base64_str)
        image = Image.open(io.BytesIO(image_data))
        # Image.open is lazy and only parses the header; decode the pixel data
        # now so corrupt payloads fail here (-> None) rather than later, inside
        # save or inference.
        image.load()
        return image
    except Exception as e:
        log.error(f"Error processing image: {str(e)}")
        return None
def save_images_to_disk(images: List[Image.Image], output_dir: str = "temp_images") -> List[str]:
    """
    Write each non-None image to *output_dir* as a uniquely named PNG.

    ``None`` entries are skipped with a warning. The returned list can
    therefore be shorter than *images*, and its positions do not necessarily
    line up with the input's.

    Args:
        images (List[Image.Image]): Images to write; None entries are tolerated
        output_dir (str): Target directory, created (with parents) if missing
            (default: "temp_images")

    Returns:
        List[str]: Paths of the written files, in input order

    Raises:
        Exception: Whatever directory creation or ``Image.save`` raised; it is
            logged first and then re-raised unchanged
    """
    written: List[str] = []
    try:
        os.makedirs(output_dir, exist_ok=True)
        for idx, img in enumerate(images):
            if img is not None:
                # uuid4 keeps names unique across images and requests.
                target = os.path.join(output_dir, "image_" + uuid.uuid4().hex + ".png")
                img.save(target, "PNG")
                written.append(target)
                log.info(f"Saved image to: {target}")
            else:
                log.warning(f"Skipping None image at index {idx}")
    except Exception as exc:
        log.error(f"Error saving images to disk: {str(exc)}")
        raise
    return written
# The FastAPI application served by uvicorn; title/version appear in the OpenAPI docs.
app = FastAPI(title="Kimi Service", version="1.5.0")
# Built once at import time and shared by every request handled by this module.
# NOTE(review): presumably loads the classification model - confirm its startup
# cost and whether it is safe to share across concurrent requests.
inference = Inference()
# Holds the /classify route; attached to `app` via app.include_router(router).
router = APIRouter()
@app.get("/")
async def home():
    """Root endpoint: returns a static welcome payload (handy as a liveness check)."""
    greeting = "Welcome to Kimi Service!"
    return {"message": greeting}
def _remove_temp_files(paths: List[str]) -> None:
    """Best-effort deletion of per-request temp images; failures are only logged."""
    for path in paths:
        try:
            os.remove(path)
        except OSError as exc:
            log.warning(f"Could not remove temp image {path}: {exc}")


@router.post("/classify", response_model=dict)
async def classify(request: ClassificationRequest):
    """
    Classify the base64-encoded images in ``request.images``.

    Each image is decoded and written to a temporary PNG. The directory comes
    from the ``IMAGE_OUTPUT_DIR`` env var (default ``/tmp/temp_images``). The
    decoded images and their file paths are then passed to
    ``inference.classify_building``. The temporary files are deleted when the
    request finishes, whether it succeeded or failed.

    Returns:
        Whatever ``inference.classify_building`` returns (route declares ``dict``).

    Raises:
        HTTPException: 400 when no images are sent, an image cannot be decoded,
            or a ValueError is raised while processing; 500 when inference
            returns None or any other error occurs.
    """
    saved_image_paths: List[str] = []
    try:
        if not request.images:
            raise ValueError("No images provided")
        log.info(f"Processing {len(request.images)} images")

        # Decode every base64 payload; decode_base64_image returns None on failure.
        images = [decode_base64_image(img_str) for img_str in request.images]
        failed = [idx for idx, img in enumerate(images) if img is None]
        if failed:
            # Fail fast. save_images_to_disk skips None entries, which would
            # leave `images` and `saved_image_paths` out of step with each other.
            raise ValueError(f"Could not decode image(s) at index: {failed}")
        log.info(f"Decoded {len(images)} images successfully")

        output_dir = os.environ.get("IMAGE_OUTPUT_DIR", "/tmp/temp_images")
        saved_image_paths = save_images_to_disk(images, output_dir)

        # NOTE(review): classify_building is called synchronously inside an
        # async endpoint, so it blocks the event loop while it runs. Moving it
        # to a threadpool would allow concurrent calls; confirm the Inference
        # object is thread-safe before doing so.
        res = inference.classify_building(images, saved_image_paths)
        if res is None:
            raise HTTPException(status_code=500, detail="Classification failed")
        return res
    except HTTPException:
        # Already carries the intended status and detail. Without this clause
        # the generic handler below would re-wrap it, replacing its detail
        # with str() of the exception.
        raise
    except ValueError as ve:
        log.error(f"Validation error: {str(ve)}")
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        # log.exception also records the traceback.
        log.exception(f"Error during classification: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Without this the temp PNGs accumulate on disk, one set per request.
        _remove_temp_files(saved_image_paths)
app.include_router(router)

if __name__ == "__main__":
    # Run as a script, this file is the module `__main__`. Handing uvicorn the
    # import string "app:app" makes it import the file again as module `app`,
    # which re-runs the module-level `inference = Inference()`. Passing the
    # already-built `app` object avoids that second instance. Reload mode
    # requires an import string, so it is now opt-in (UVICORN_RELOAD=1) and
    # meant for development only.
    reload_enabled = os.environ.get("UVICORN_RELOAD", "0") == "1"
    port = int(os.environ.get("PORT", "7860"))
    if reload_enabled:
        uvicorn.run("app:app", reload=True, port=port, host="0.0.0.0")
    else:
        uvicorn.run(app, port=port, host="0.0.0.0")