from fastapi import FastAPI, File, UploadFile, Form, Query, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, List
from functools import lru_cache
import os
import shutil
from pathlib import Path
import uuid
import sys
import torch

# Fix for 'collections' has no attribute 'Sized' issue: some DeOldify
# dependencies still look these ABCs up on `collections`, where the aliases
# were removed in Python 3.10. Re-expose them from `collections.abc`.
import collections
import collections.abc
for typ in ['Sized', 'Iterable', 'Mapping', 'MutableMapping', 'Sequence', 'MutableSequence']:
    if not hasattr(collections, typ):
        setattr(collections, typ, getattr(collections.abc, typ))

# Add DeOldify directory to path so `import deoldify` resolves to the checkout.
sys.path.append('./DeOldify')

torch.backends.cudnn.benchmark = False

# Instead of adding the models directory to sys.path, expose the weight files
# under ./models (relative to the working directory), where model loading is
# expected to look for them.
_MODEL_WEIGHTS = (
    'ColorizeArtistic_gen.pth',
    'ColorizeStable_gen.pth',
    'ColorizeVideo_gen.pth',
)
os.makedirs('models', exist_ok=True)
for _weights in _MODEL_WEIGHTS:
    _link = os.path.join('models', _weights)
    # lexists() rather than exists(): exists() is False for a dangling link
    # (e.g. the target weights are missing), which would make os.symlink()
    # raise FileExistsError on every subsequent start-up.
    if not os.path.lexists(_link):
        os.symlink(os.path.abspath(os.path.join('./DeOldify/models', _weights)), _link)

# DeOldify imports. Nothing below can work without them, so report the cause
# and re-raise instead of continuing into a confusing NameError.
try:
    from deoldify import device
    from deoldify.device_id import DeviceId
except Exception as e:
    print(f"Error importing DeOldify: {e}")
    raise

# Set GPU device. DeOldify's own notebooks require this to happen before the
# rest of the library (deoldify.visualize) is imported.
device.set(device=DeviceId.GPU0)

try:
    from deoldify.visualize import get_image_colorizer
except Exception as e:
    print(f"Error importing DeOldify: {e}")
    raise

app = FastAPI(title="Image Colorization API",
              description="API for colorizing black and white images using DeOldify")

# Create directories if they don't exist
os.makedirs("input_images", exist_ok=True)
os.makedirs("output_images", exist_ok=True)
os.makedirs("multiple_renders", exist_ok=True)

# Only files under the directories the colorize endpoints write their results
# to may be downloaded through /image. Resolved once to absolute paths.
_SERVABLE_DIRS = (
    Path("output_images").resolve(),
    Path("multiple_renders").resolve(),
)


class ColorizationResult(BaseModel):
    """Response of /colorize: where the single colorized image was written."""
    output_path: str
    render_factor: int
    model_type: str  # "artistic" or "stable"


class MultipleColorizationResult(BaseModel):
    """Response of /colorize_multiple: one output path per render factor used."""
    output_paths: List[str]
    render_factors: List[int]  # parallel to output_paths
    model_type: str  # "artistic" or "stable"


def _file_extension(upload: UploadFile) -> str:
    """Return the uploaded file name's extension (including the dot), or ''."""
    # UploadFile.filename can be None; os.path.splitext(None) raises TypeError.
    return os.path.splitext(upload.filename or "")[1]


def _save_upload(upload: UploadFile, destination: str) -> None:
    """Stream the uploaded file's content to *destination*."""
    with open(destination, "wb") as buffer:
        shutil.copyfileobj(upload.file, buffer)


@lru_cache(maxsize=2)
def _get_colorizer(artistic: bool):
    """Return the DeOldify image colorizer for one model type, loading it once.

    Loading generator weights is expensive, so at most one colorizer per
    model type (artistic / stable) is kept for the life of the process.
    Every colorization call below passes ``render_factor`` explicitly, so the
    value used at construction is only DeOldify's fallback and one instance
    can serve requests with different render factors.
    """
    return get_image_colorizer(artistic=artistic)


def _colorize_to(colorizer, input_path: str, render_factor: int, destination: str) -> None:
    """Colorize *input_path* at *render_factor* and move the result to *destination*.

    The colorizer writes its own result file and returns that path; it is
    moved (not copied) so DeOldify's results folder does not accumulate a
    duplicate of every output.
    """
    result_path = colorizer.plot_transformed_image(
        path=input_path,
        render_factor=render_factor,
        compare=False,
        watermarked=False
    )
    shutil.move(str(result_path), destination)


def _is_within(path: Path, root: Path) -> bool:
    """Return True if the absolute *path* is *root* itself or lies underneath it."""
    try:
        path.relative_to(root)
    except ValueError:
        return False
    return True


@app.post("/colorize", response_model=ColorizationResult)
async def colorize_image(
    file: UploadFile = File(...),
    render_factor: int = Query(
        10, ge=5, le=50, description="Render factor (higher is better quality but slower)"),
    artistic: bool = Query(
        True, description="Use artistic model (True) or stable model (False)"),
):
    """
    Colorize a black and white image with the specified render factor and model type.

    Returns the server-side path of the colorized image (downloadable via
    /image/{path}). Responds with 500 if colorization fails.
    """
    # Generate a unique filename to avoid conflicts
    file_id = str(uuid.uuid4())
    file_extension = _file_extension(file)
    input_path = f"input_images/{file_id}{file_extension}"
    output_path = f"output_images/{file_id}_colorized{file_extension}"

    _save_upload(file, input_path)

    # NOTE(review): the colorization below is synchronous work executed
    # directly on the event loop. That serialises requests (so the cached,
    # shared colorizer is never used concurrently) but blocks the server
    # while an image is processed.
    try:
        _colorize_to(_get_colorizer(artistic), input_path, render_factor, output_path)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Colorization failed: {str(e)}") from e

    return ColorizationResult(
        output_path=output_path,
        render_factor=render_factor,
        model_type="artistic" if artistic else "stable"
    )


@app.post("/colorize_multiple", response_model=MultipleColorizationResult)
async def colorize_image_multiple(
    file: UploadFile = File(...),
    min_render_factor: int = Query(
        5, ge=5, le=45, description="Minimum render factor"),
    max_render_factor: int = Query(
        50, ge=10, le=50, description="Maximum render factor"),
    step: int = Query(
        1, ge=1, le=10, description="Step size between render factors"),
    artistic: bool = Query(
        True, description="Use artistic model (True) or stable model (False)"),
):
    """
    Colorize a black and white image with multiple render factors.

    Render factors run from min_render_factor to max_render_factor inclusive
    in increments of step. Responds with 400 if the range is inverted and
    with 500 if any colorization fails.
    """
    # An inverted range would otherwise silently produce no renders at all.
    if min_render_factor > max_render_factor:
        raise HTTPException(
            status_code=400,
            detail="min_render_factor must not be greater than max_render_factor")

    # Generate a unique folder for this batch of renderings
    batch_id = str(uuid.uuid4())
    batch_folder = f"multiple_renders/{batch_id}"
    os.makedirs(batch_folder, exist_ok=True)

    file_extension = _file_extension(file)
    input_path = f"{batch_folder}/input{file_extension}"
    _save_upload(file, input_path)

    try:
        colorizer = _get_colorizer(artistic)

        render_factors = list(range(min_render_factor, max_render_factor + 1, step))
        output_paths = []
        for render_factor in render_factors:
            output_file = f"{batch_folder}/colorized_{render_factor}{file_extension}"
            _colorize_to(colorizer, input_path, render_factor, output_file)
            output_paths.append(output_file)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Multiple colorization failed: {str(e)}") from e

    return MultipleColorizationResult(
        output_paths=output_paths,
        render_factors=render_factors,
        model_type="artistic" if artistic else "stable"
    )


@app.get("/image/{image_path:path}")
async def get_image(image_path: str):
    """
    Retrieve a colorized image by path.

    Only files inside the output directories are served; anything else
    (including paths escaping them via '..', absolute paths or symlinks)
    yields 404.
    """
    # resolve() makes the path absolute and collapses '..' and symlinks, so
    # the containment check cannot be bypassed.
    resolved = Path(image_path).resolve()
    allowed = any(_is_within(resolved, root) for root in _SERVABLE_DIRS)
    if not allowed or not resolved.is_file():
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(resolved)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)