import modal
from fastapi import HTTPException
from pydantic import BaseModel, Field
from typing import Optional, Union, List, Dict, Any

# Define the image with all required dependencies
image = (
    modal.Image.debian_slim()
    .pip_install([
        "torch",
        "transformers>=4.51.0",
        "fastapi[standard]",
        "accelerate",
        "tokenizers"
    ])
)

app = modal.App("qwen-api", image=image)
# Request model for the API - Maximizing token output
class ChatRequest(BaseModel):
    message: str
    max_tokens: Optional[int] = 16384  # Greatly increased token limit
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    strip_thinking: Optional[bool] = False  # Option to strip <think> tags to save tokens

class ChatResponse(BaseModel):
    response: str
    tokens_used: Optional[int] = None  # Make this optional
    input_tokens: Optional[int] = None  # Track input tokens
    model_name: str = "Qwen/Qwen3-4B"  # Include model info
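
# Example JSON payloads for the two models above (illustrative only; the token
# counts are made-up values, not real measurements):
#   request:  {"message": "Hello, how are you?", "max_tokens": 1024,
#              "temperature": 0.7, "top_p": 0.9, "strip_thinking": true}
#   response: {"response": "...", "tokens_used": 842, "input_tokens": 12,
#              "model_name": "Qwen/Qwen3-4B"}
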
# Modal class to handle model loading and inference - updated for new Modal syntax
# NOTE: the GPU type below is an assumption (the original listing omitted the
# decorator arguments); any GPU with enough memory for the 4B model in fp16 works.
@app.cls(gpu="A10G")
class QwenModel:
    # Use modal.enter() instead of __init__ for setup, so the model loads once per container
    @modal.enter()
    def setup(self):
| print("Loading Qwen/Qwen3-4B model...") | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| model_name = "Qwen/Qwen3-4B" | |
| # Load tokenizer | |
| self.tokenizer = AutoTokenizer.from_pretrained( | |
| model_name, | |
| trust_remote_code=True | |
| ) | |
| # Load model with GPU support - use float16 for more efficient memory usage | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| torch_dtype=torch.float16, | |
| device_map="auto", | |
| trust_remote_code=True | |
| ) | |
| print("Model loaded successfully!") | |

    def _strip_thinking_tags(self, text: str) -> str:
        """Strip <think> sections from the response to save tokens."""
        import re
        # Remove content between <think> and </think> (or the end of the string)
        return re.sub(r'<think>.*?(?:</think>|$)', '', text, flags=re.DOTALL)
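
    # Illustrative example of the stripping above (made-up text, not real model output):
    #   "<think>Let me work this out...</think>The answer is 4."
    # becomes, with strip_thinking=True:
    #   "The answer is 4."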

    @modal.method()
    def generate_response(self, message: str, max_tokens: int = 16384,
                          temperature: float = 0.7, top_p: float = 0.9,
                          strip_thinking: bool = False):
| """Generate a response using the Qwen model""" | |
| try: | |
| import torch | |
| # Format the message for chat | |
| messages = [ | |
| {"role": "user", "content": message} | |
| ] | |
| # Apply chat template | |
| text = self.tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True | |
| ) | |
| # Tokenize input | |
| model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) | |
| input_token_count = len(model_inputs.input_ids[0]) | |
| # Set parameters with very high token limits for 4B model | |
| generation_kwargs = { | |
| **model_inputs, | |
| "temperature": temperature, | |
| "top_p": top_p, | |
| "do_sample": True, | |
| "pad_token_id": self.tokenizer.eos_token_id, | |
| "max_new_tokens": max_tokens if max_tokens is not None else 16384, | |
| "repetition_penalty": 1.0, | |
| } | |
| print(f"Generating with settings: max_new_tokens={generation_kwargs.get('max_new_tokens')}") | |
| print(f"Input token count: {input_token_count}") | |
| # Generate response | |
| with torch.no_grad(): | |
| generated_ids = self.model.generate(**generation_kwargs) | |
| # Decode the response (excluding the input tokens) | |
| input_length = model_inputs.input_ids.shape[1] | |
| response_ids = generated_ids[0][input_length:] | |
| response = self.tokenizer.decode(response_ids, skip_special_tokens=True) | |
| # Optionally strip thinking tags | |
| if strip_thinking: | |
| response = self._strip_thinking_tags(response) | |
| output_token_count = len(response_ids) | |
| print(f"Generated response with {output_token_count} tokens") | |
| return { | |
| "response": response.strip(), | |
| "tokens_used": output_token_count, | |
| "input_tokens": input_token_count, | |
| "model_name": "Qwen/Qwen3-4B" | |
| } | |
| except Exception as e: | |
| print(f"Error during generation: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") | |

# Web endpoint - single LLM interaction endpoint (function timeout set to 180 seconds)
@app.function(timeout=180)
@modal.fastapi_endpoint(method="POST")  # on older Modal releases this decorator is modal.web_endpoint
def chat(request: ChatRequest):
| """ | |
| Chat endpoint for Qwen3-4B model | |
| Example usage: | |
| curl -X POST "https://your-modal-url/" \ | |
| -H "Content-Type: application/json" \ | |
| -d '{"message": "Hello, how are you?"}' | |
| """ | |
    try:
        print(f"Received request: message length={len(request.message)}, max_tokens={request.max_tokens}, strip_thinking={request.strip_thinking}")

        # Initialize the model class (Modal reuses a warm container when one is available)
        model = QwenModel()

        # Generate the response; the longer timeout is set at the app.function level
        result = model.generate_response.remote(
            message=request.message,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            strip_thinking=request.strip_thinking
        )

        print(f"Returning response: length={len(result['response'])}, output_tokens={result.get('tokens_used')}, input_tokens={result.get('input_tokens')}")

        return ChatResponse(
            response=result["response"],
            tokens_used=result["tokens_used"],
            input_tokens=result["input_tokens"],
            model_name=result["model_name"]
        )
    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# Local entry point - prints usage instructions (deploying/serving is done via the Modal CLI)
if __name__ == "__main__":
    print("To deploy this app, run:")
    print("modal deploy qwen.py")
    print("\nTo run in development mode, run:")
    print("modal serve qwen.py")