# bogota_land_space / classifier.py
# -- Hugging Face Space page header preserved as a comment so the module parses:
#    viarias's picture | Update classifier.py | 4b948e1 verified
#    raw | history blame | 7.61 kB
import logging
from typing import List, Optional

from PIL import Image

from model import Model
from types_io import ImageData
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
LAND_USE_PROMPT = f"""
Task: You are a structured image analysis agent. Given an image of a building front, generate a comprehensive tag list and provide the thinking process for an image classification systems.
Requirement: The output should generate a minimum of 3 categories for each input.
Confidence: Confidence score for each category, ranging from 0 (very confident) to 1 (little/none confident).
Categories :
- Residenciales: Buildings intended for housing - Houses, PH Buildings, Condominiums.
- Comerciales1: Refers to the storage, distribution, or exchange of products, goods, or services with a commercial interest.
- Comerciales2: Buildings where activities aimed at providing services are carried out.
- Comerciales3: Buildings used for artisanal activities where raw materials are transformed on a local scale.
- Comerciales4: Hotels, Motels, and Restaurants.
- Comerciales5: Operational offices and warehouses.
- Centros_Comerciales: Commercial premises located on properties of one or several buildings.
- Bodegas: Buildings in warehouse-type constructions dedicated to commercial, industrial, or storage activities.
- Parqueaderos: Buildings designed for vehicle parking.
- Dotacionales1: Buildings where activities aimed at the welfare or service of a community are carried out.
- Dotacionales2: Buildings designed to carry out educational or training activities.
- Dotacionales3: Buildings with the necessary infrastructure to provide surgical and/or hospitalization services.
- Dotacionales4: Buildings for religious worship owned by communities or religious congregations.
- Dotacionales5: Theaters, cinemas, swimming pools, museums, sports, events, or shows.
- Especiales: Military administrative areas, cemeteries, airport runways.
- Moles: Large buildings in height (>4 floors) or area (>10,000 m²), usually under construction.
- Rurales: Sheds, kiosks, shelters, barns, stables, silos, etc.
- Mixto1: (Residencial + Comercial1) Housing and commercial premises.
- Mixto2: (Residencial + Comercial2) Housing and offices.
- Mixto3: (Comercial1 + Comercial2) Commercial premises and offices.
Return the information in the following JSON schema:
{ImageData.model_json_schema()}
"""
class Classifier:
    """Land-use classifier for building-front images.

    Wraps a vision-language model and its processor (loaded via ``Model``),
    sends resized images together with ``LAND_USE_PROMPT`` through the chat
    template, and returns the raw decoded text response (expected to follow
    the ``ImageData`` JSON schema).
    """

    def __init__(self, MAX_NEW_TOKENS: int = 1024):
        """Load the model and processor.

        Args:
            MAX_NEW_TOKENS: Upper bound on tokens the model may generate
                per request (used by :meth:`generate_model_response`).
        """
        self.max_new_tokens = MAX_NEW_TOKENS
        logger.info("Initializing Classifier")
        logger.info("Loading model...")
        self.model = Model.load_model()
        logger.info("Loading processor...")
        self.processor = Model.load_processor()
        logger.info("Classifier initialization complete")

    def get_response(self, images: List[Image.Image],
                     saved_image_paths: Optional[List[str]] = None) -> dict:
        """Classify a batch of images.

        Args:
            images: PIL images to classify.
            saved_image_paths: On-disk paths for the same images; the chat
                message references images by path, so this is required in
                practice (the ``None`` default is kept for API compatibility).

        Returns:
            dict: ``{"output": <decoded model response string>}``.

        Raises:
            ValueError: If ``images`` is empty, or ``saved_image_paths`` is
                ``None``/empty (previously this crashed with ``TypeError``
                when ``prepare_messages`` iterated ``None``).
        """
        logger.info("Processing classification request for %d images", len(images))
        if not saved_image_paths:
            raise ValueError("No saved image paths provided for classification.")
        logger.info("Loading and preprocessing images...")
        images = self.get_input_tensor(images)
        logger.debug("Successfully preprocessed images")
        logger.info("Preparing input messages...")
        messages = self.prepare_messages(saved_image_paths)
        response = self.generate_model_response(images, messages)
        return {"output": response}

    def get_input_tensor(self, images: List[Image.Image]) -> List[Image.Image]:
        """Preprocess a list of PIL images (resize to fit 224 px).

        Args:
            images: PIL images to be processed.

        Returns:
            List[Image.Image]: Preprocessed images, in input order.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if not images:
            raise ValueError("No images provided for classification.")
        logger.info("Preprocessing %d images...", len(images))
        processed_images = []
        for idx, img in enumerate(images):
            logger.debug("Processing image at index: %d", idx)
            try:
                processed_images.append(self.resize_image(img))
                logger.debug("Successfully processed image at index: %d", idx)
            except Exception as e:
                # Log which image failed, then let the caller see the error.
                logger.error("Error processing image at index %d: %s", idx, e)
                raise
        return processed_images

    def generate_model_response(self, images: List[Image.Image], messages: List[dict]) -> str:
        """Generate a response from the model.

        Args:
            images: Preprocessed PIL images.
            messages: Chat messages produced by :meth:`prepare_messages`.

        Returns:
            str: Decoded response for the first (only) batch element.
        """
        logger.info("Applying chat template...")
        try:
            # Render the chat messages to a prompt string first; the
            # processor call below handles tokenization with the images.
            text = self.processor.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            )
            logger.info("Text length: %d characters", len(text))
            inputs = self.processor(
                images=images, text=text, return_tensors="pt",
                padding=True, truncation=True,
            ).to(self.model.device)
        except Exception as e:
            logger.error("Error applying chat template: %s", e)
            raise
        logger.info("Generating response...")
        # Use the configured budget instead of a hard-coded 1024 so the
        # MAX_NEW_TOKENS constructor argument actually takes effect.
        # NOTE(review): temperature without do_sample=True is ignored by
        # HF generate() in greedy mode — confirm sampling is intended.
        generated_ids = self.model.generate(
            **inputs, max_new_tokens=self.max_new_tokens, temperature=0.1
        )
        # Strip the prompt tokens so only newly generated tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        logger.debug("Successfully generated response")
        return response

    @staticmethod
    def resize_image(image: Image.Image, max_size: int = 224) -> Image.Image:
        """Resize an image while maintaining aspect ratio.

        Args:
            image: PIL Image object to resize.
            max_size: Maximum dimension (width or height) of the output.

        Returns:
            Image.Image: Resized image; returned unchanged if it already
            fits within ``max_size`` (never upscaled).
        """
        width, height = image.size
        # Scaling factor that fits the longer side inside max_size.
        scale = min(max_size / width, max_size / height)
        if scale < 1:  # only shrink, never enlarge
            image = image.resize(
                (int(width * scale), int(height * scale)),
                Image.Resampling.LANCZOS,
            )
        return image

    @staticmethod
    def prepare_messages(saved_image_paths: List[str]) -> List[dict]:
        """Build the single-turn chat message: all images, then the prompt.

        Args:
            saved_image_paths: Paths to saved images referenced by the message.

        Returns:
            List[dict]: One user-role message suitable for
            ``apply_chat_template``.
        """
        content = [
            {"type": "image", "image": image_path}
            for image_path in saved_image_paths
        ]
        content.append({"type": "text", "text": LAND_USE_PROMPT})
        return [{"role": "user", "content": content}]