FastAPI + LLM Streaming Integration
Build high-performance, async LLM APIs with real-time streaming capabilities
Overview
Why FastAPI is ideal for LLM integrations
FastAPI is a modern Python web framework that excels at building high-performance APIs with async support, automatic validation, and built-in documentation. It's particularly well-suited for LLM integrations due to its streaming capabilities and async-first design.
Key Features
- Native async/await support
- Automatic request validation
- Built-in OpenAPI docs
- WebSocket & SSE support
LLM Benefits
- Stream tokens in real-time (see the sketch below)
- Handle concurrent requests
- Async LLM API calls
- Efficient resource usage
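To make this concrete, here is a minimal, self-contained sketch of an async streaming endpoint. The fake_llm generator is a stand-in for a real LLM client and is not part of the project built in the rest of this guide.

# minimal_stream.py - minimal sketch; fake_llm stands in for a real LLM client
import asyncio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_llm(prompt: str):
    # Placeholder async generator that yields "tokens" with a small delay
    for word in f"Echoing: {prompt}".split():
        await asyncio.sleep(0.1)
        yield word + " "

@app.get("/stream")
async def stream(prompt: str = "hello"):
    # StreamingResponse sends each yielded chunk to the client as soon as it is available
    return StreamingResponse(fake_llm(prompt), media_type="text/plain")

# Run with: uvicorn minimal_stream:app --reload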
Setup & Configuration
Installing FastAPI and required dependencies
Installation
# Core dependencies
pip install fastapi uvicorn[standard] httpx python-multipart

# Additional dependencies for production
pip install gunicorn python-jose[cryptography] passlib[bcrypt]
pip install redis celery flower                    # For background tasks
pip install prometheus-fastapi-instrumentator      # Monitoring

# For development
pip install pytest pytest-asyncio black isort
Project Structure
llm-api/
├── app/
│   ├── __init__.py
│   ├── main.py              # FastAPI app
│   ├── config.py            # Configuration
│   ├── models.py            # Pydantic models
│   ├── routers/
│   │   ├── __init__.py
│   │   ├── chat.py          # Chat endpoints
│   │   ├── streaming.py     # Streaming endpoints
│   │   └── websocket.py     # WebSocket endpoints
│   ├── services/
│   │   ├── __init__.py
│   │   ├── llm.py           # LLM service
│   │   └── auth.py          # Authentication
│   ├── middleware/
│   │   ├── __init__.py
│   │   ├── rate_limit.py
│   │   └── logging.py
│   └── utils/
│       ├── __init__.py
│       └── helpers.py
├── tests/
├── docker-compose.yml
├── Dockerfile
└── requirements.txt
Basic FastAPI App
# app/main.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import httpx

from app.config import settings
from app.routers import chat, streaming, websocket


# Lifespan context manager for startup/shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    app.state.http_client = httpx.AsyncClient(timeout=30.0)
    yield
    # Shutdown
    await app.state.http_client.aclose()


# Create FastAPI app
app = FastAPI(
    title="LLM Streaming API",
    description="High-performance LLM API with streaming support",
    version="1.0.0",
    lifespan=lifespan
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers
app.include_router(chat.router, prefix="/api/v1/chat", tags=["chat"])
app.include_router(streaming.router, prefix="/api/v1/stream", tags=["streaming"])
app.include_router(websocket.router, prefix="/ws", tags=["websocket"])


@app.get("/")
async def root():
    return {
        "message": "LLM Streaming API",
        "docs": "/docs",
        "health": "/health"
    }


@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "version": "1.0.0"
    }
Configuration Management
# app/config.py
from pydantic_settings import BaseSettings
from typing import List


class Settings(BaseSettings):
    # API Settings
    API_V1_STR: str = "/api/v1"
    PROJECT_NAME: str = "LLM Streaming API"

    # CORS
    ALLOWED_ORIGINS: List[str] = ["http://localhost:3000", "https://yourdomain.com"]

    # LLM Provider Settings
    LLM_API_BASE_URL: str = "https://api.parrotrouter.com/v1"
    LLM_API_KEY: str
    LLM_MODEL: str = "gpt-3.5-turbo"
    LLM_MAX_TOKENS: int = 2000
    LLM_TEMPERATURE: float = 0.7

    # Redis
    REDIS_URL: str = "redis://localhost:6379"

    # Authentication
    SECRET_KEY: str
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30

    # Rate Limiting
    RATE_LIMIT_PER_MINUTE: int = 60
    RATE_LIMIT_PER_HOUR: int = 1000

    # Monitoring
    ENABLE_METRICS: bool = True
    LOG_LEVEL: str = "INFO"

    class Config:
        env_file = ".env"
        case_sensitive = True


settings = Settings()
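The Settings class reads its values from a .env file. A sample .env matching the fields above might look like the following; all values are placeholders, and list fields such as ALLOWED_ORIGINS are parsed as JSON by pydantic-settings.

# .env (example values only)
LLM_API_KEY=your-llm-api-key
LLM_API_BASE_URL=https://api.parrotrouter.com/v1
LLM_MODEL=gpt-3.5-turbo
SECRET_KEY=change-me-to-a-long-random-string
REDIS_URL=redis://localhost:6379
ALLOWED_ORIGINS=["http://localhost:3000","https://yourdomain.com"]
LOG_LEVEL=INFO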
Basic LLM Integration
Implementing standard LLM API calls with FastAPI
LLM Service
# app/services/llm.py
import httpx
from typing import Dict, Any, Optional, AsyncGenerator, Union
import json
from app.config import settings
import logging

logger = logging.getLogger(__name__)


class LLMService:
    """Service for interacting with LLM APIs"""

    def __init__(self, http_client: httpx.AsyncClient):
        self.http_client = http_client
        self.base_url = settings.LLM_API_BASE_URL
        self.headers = {
            "Authorization": f"Bearer {settings.LLM_API_KEY}",
            "Content-Type": "application/json"
        }

    async def complete(self,
                       prompt: str,
                       model: Optional[str] = None,
                       temperature: Optional[float] = None,
                       max_tokens: Optional[int] = None,
                       stream: bool = False) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """
        Make a completion request to the LLM API.

        Returns the parsed JSON response, or an async generator of tokens
        when stream=True (await this method, then iterate the result with `async for`).
        """
        payload = {
            "model": model or settings.LLM_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature or settings.LLM_TEMPERATURE,
            "max_tokens": max_tokens or settings.LLM_MAX_TOKENS,
            "stream": stream
        }

        if stream:
            # Return the async generator itself; errors surface while iterating it
            return self._stream_completion(payload)

        try:
            response = await self.http_client.post(
                f"{self.base_url}/chat/completions",
                headers=self.headers,
                json=payload
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            logger.error(f"LLM API error: {e.response.status_code} - {e.response.text}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error calling LLM: {e}")
            raise

    async def _stream_completion(self, payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
        """
        Stream completion tokens from the LLM API
        """
        async with self.http_client.stream(
            "POST",
            f"{self.base_url}/chat/completions",
            headers=self.headers,
            json=payload
        ) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                        if chunk["choices"][0]["delta"].get("content"):
                            yield chunk["choices"][0]["delta"]["content"]
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse streaming chunk: {data}")
                        continue
Pydantic Models
# app/models.py
from pydantic import BaseModel, Field, validator
from typing import Optional, List, Literal
from datetime import datetime


class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: Optional[str] = Field(None, description="Model to use")
    temperature: Optional[float] = Field(None, ge=0, le=2)
    max_tokens: Optional[int] = Field(None, ge=1, le=4000)
    stream: bool = Field(False, description="Enable streaming response")

    @validator('messages')
    def validate_messages(cls, v):
        if not v:
            raise ValueError("Messages cannot be empty")
        return v


class ChatResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[dict]
    usage: dict


class StreamingChatRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=10000)
    system_prompt: Optional[str] = None
    temperature: Optional[float] = Field(0.7, ge=0, le=2)
    max_tokens: Optional[int] = Field(1000, ge=1, le=4000)


class ErrorResponse(BaseModel):
    error: str
    detail: Optional[str] = None
    timestamp: datetime = Field(default_factory=datetime.utcnow)
Chat Router
# app/routers/chat.py
from fastapi import APIRouter, HTTPException, Request, Depends
from typing import Annotated
import httpx

from app.models import ChatRequest, ChatResponse, ErrorResponse
from app.services.llm import LLMService

router = APIRouter()


async def get_llm_service(request: Request) -> LLMService:
    """Dependency to get LLM service"""
    return LLMService(request.app.state.http_client)


@router.post("/completions",
             response_model=ChatResponse,
             responses={
                 400: {"model": ErrorResponse},
                 500: {"model": ErrorResponse}
             })
async def create_chat_completion(
    request: ChatRequest,
    llm_service: Annotated[LLMService, Depends(get_llm_service)]
):
    """
    Create a chat completion

    - **messages**: List of chat messages
    - **model**: Optional model override
    - **temperature**: Sampling temperature (0-2)
    - **max_tokens**: Maximum tokens to generate
    """
    try:
        # Convert messages to the format expected by the LLM
        prompt = "\n".join([f"{msg.role}: {msg.content}" for msg in request.messages])

        result = await llm_service.complete(
            prompt=prompt,
            model=request.model,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            stream=False
        )
        return result
    except httpx.HTTPStatusError as e:
        raise HTTPException(
            status_code=e.response.status_code,
            detail=f"LLM API error: {e.response.text}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )
Streaming with Server-Sent Events (SSE)
Implementing real-time token streaming using SSE
SSE Streaming Endpoint
# app/routers/streaming.py
from fastapi import APIRouter, Request, Depends
from fastapi.responses import StreamingResponse
from typing import Annotated, AsyncGenerator
from datetime import datetime
import json
import asyncio

from app.models import StreamingChatRequest
from app.services.llm import LLMService

router = APIRouter()


async def get_llm_service(request: Request) -> LLMService:
    return LLMService(request.app.state.http_client)


@router.post("/chat")
async def stream_chat(
    request: StreamingChatRequest,
    llm_service: Annotated[LLMService, Depends(get_llm_service)]
):
    """
    Stream chat responses using Server-Sent Events
    """
    async def generate_sse() -> AsyncGenerator[str, None]:
        try:
            # Send initial event
            yield f"data: {json.dumps({'type': 'start', 'message': 'Starting generation'})}\n\n"

            # Prepare the prompt
            full_prompt = request.prompt
            if request.system_prompt:
                full_prompt = f"System: {request.system_prompt}\n\nUser: {request.prompt}"

            # Stream tokens from the LLM (complete() returns an async generator when stream=True)
            token_count = 0
            token_stream = await llm_service.complete(
                prompt=full_prompt,
                temperature=request.temperature,
                max_tokens=request.max_tokens,
                stream=True
            )
            async for token in token_stream:
                token_count += 1
                yield f"data: {json.dumps({'type': 'token', 'content': token, 'index': token_count})}\n\n"
                # Small delay to prevent overwhelming the client
                await asyncio.sleep(0.01)

            # Send completion event
            yield f"data: {json.dumps({'type': 'complete', 'token_count': token_count})}\n\n"

        except Exception as e:
            # Send error event
            yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"
        finally:
            # Send done event
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_sse(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # Disable proxy buffering
        }
    )


@router.get("/test")
async def test_stream():
    """
    Test SSE endpoint that streams numbers
    """
    async def generate():
        for i in range(10):
            yield f"data: {json.dumps({'number': i, 'timestamp': str(datetime.utcnow())})}\n\n"
            await asyncio.sleep(1)
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream"
    )
Frontend SSE Consumer
// Frontend code to consume SSE stream class SSEClient { private eventSource: EventSource | null = null; async streamChat(prompt: string, onToken: (token: string) => void) { const response = await fetch('/api/v1/stream/chat', { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify({ prompt }) }); if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`); } const reader = response.body?.getReader(); const decoder = new TextDecoder(); if (!reader) return; while (true) { const { done, value } = await reader.read(); if (done) break; const chunk = decoder.decode(value); const lines = chunk.split('\n'); for (const line of lines) { if (line.startsWith('data: ')) { const data = line.slice(6); if (data === '[DONE]') { return; } try { const parsed = JSON.parse(data); if (parsed.type === 'token') { onToken(parsed.content); } } catch (e) { console.error('Failed to parse SSE data:', e); } } } } } } // Usage const client = new SSEClient(); let fullResponse = ''; await client.streamChat('Tell me a story', (token) => { fullResponse += token; console.log('Token:', token); });
Advanced SSE Features
# Advanced SSE implementation with retry and heartbeat from datetime import datetime import uuid class SSEResponse: """Helper class for SSE responses""" def __init__(self, retry: int = 3000): self.retry = retry def format_sse(self, data: str, event: Optional[str] = None, id: Optional[str] = None) -> str: """Format data as SSE""" lines = [] if id: lines.append(f"id: {id}") if event: lines.append(f"event: {event}") if self.retry: lines.append(f"retry: {self.retry}") # Split data by newlines and prefix each line for line in data.splitlines(): lines.append(f"data: {line}") return "\n".join(lines) + "\n\n" @router.post("/chat/advanced") async def advanced_stream_chat(request: StreamingChatRequest): """Advanced SSE streaming with heartbeat and metadata""" sse = SSEResponse(retry=5000) request_id = str(uuid.uuid4()) async def generate(): try: # Send metadata metadata = { "request_id": request_id, "model": settings.LLM_MODEL, "timestamp": datetime.utcnow().isoformat() } yield sse.format_sse( json.dumps(metadata), event="metadata", id=request_id ) # Start heartbeat task heartbeat_task = asyncio.create_task(heartbeat_generator()) # Stream LLM response token_buffer = [] async for token in llm_service.complete( prompt=request.prompt, temperature=request.temperature, max_tokens=request.max_tokens, stream=True ): token_buffer.append(token) # Send tokens in batches for efficiency if len(token_buffer) >= 5: yield sse.format_sse( json.dumps({ "tokens": token_buffer, "partial": "".join(token_buffer) }), event="tokens" ) token_buffer = [] # Send remaining tokens if token_buffer: yield sse.format_sse( json.dumps({ "tokens": token_buffer, "partial": "".join(token_buffer) }), event="tokens" ) # Cancel heartbeat heartbeat_task.cancel() # Send completion yield sse.format_sse( json.dumps({"status": "complete"}), event="done" ) except Exception as e: yield sse.format_sse( json.dumps({ "error": str(e), "type": type(e).__name__ }), event="error" ) async def heartbeat_generator(): """Send periodic heartbeat to keep connection alive""" while True: await asyncio.sleep(30) yield sse.format_sse( json.dumps({"type": "heartbeat"}), event="heartbeat" ) return StreamingResponse( generate(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "X-Request-ID": request_id } )
WebSocket Implementation
Building real-time bidirectional chat with WebSockets
WebSocket Chat Handler
# app/routers/websocket.py
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from app.services.llm import LLMService
from datetime import datetime
from typing import Dict, Set
import httpx
import json
import asyncio
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


# Connection manager for handling multiple WebSocket connections
class ConnectionManager:
    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        self.user_sessions: Dict[str, list] = {}  # Store chat history

    async def connect(self, websocket: WebSocket, client_id: str):
        await websocket.accept()
        self.active_connections[client_id] = websocket
        if client_id not in self.user_sessions:
            self.user_sessions[client_id] = []
        logger.info(f"Client {client_id} connected")

    def disconnect(self, client_id: str):
        if client_id in self.active_connections:
            del self.active_connections[client_id]
        logger.info(f"Client {client_id} disconnected")

    async def send_message(self, client_id: str, message: dict):
        if client_id in self.active_connections:
            await self.active_connections[client_id].send_json(message)

    async def broadcast(self, message: dict, exclude: Set[str] = None):
        exclude = exclude or set()
        for client_id, connection in self.active_connections.items():
            if client_id not in exclude:
                await connection.send_json(message)

    def add_to_history(self, client_id: str, role: str, content: str):
        if client_id in self.user_sessions:
            self.user_sessions[client_id].append({
                "role": role,
                "content": content,
                "timestamp": datetime.utcnow().isoformat()
            })
            # Keep only last 50 messages
            self.user_sessions[client_id] = self.user_sessions[client_id][-50:]


manager = ConnectionManager()


@router.websocket("/chat/{client_id}")
async def websocket_chat(
    websocket: WebSocket,
    client_id: str,
    # NOTE: this creates a new HTTP client per connection; reusing app.state.http_client is preferable
    llm_service: LLMService = Depends(lambda: LLMService(httpx.AsyncClient()))
):
    """
    WebSocket endpoint for real-time chat
    """
    await manager.connect(websocket, client_id)

    try:
        # Send welcome message
        await manager.send_message(client_id, {
            "type": "system",
            "message": "Connected to LLM chat",
            "client_id": client_id
        })

        while True:
            # Receive message from client
            data = await websocket.receive_json()
            message_type = data.get("type", "message")

            if message_type == "message":
                await handle_chat_message(client_id, data, llm_service)
            elif message_type == "typing":
                # Broadcast typing indicator
                await manager.broadcast({
                    "type": "typing",
                    "client_id": client_id
                }, exclude={client_id})
            elif message_type == "history":
                # Send chat history
                await manager.send_message(client_id, {
                    "type": "history",
                    "messages": manager.user_sessions.get(client_id, [])
                })

    except WebSocketDisconnect:
        manager.disconnect(client_id)
        await manager.broadcast({
            "type": "user_left",
            "client_id": client_id
        })
    except Exception as e:
        logger.error(f"WebSocket error for client {client_id}: {e}")
        manager.disconnect(client_id)


async def handle_chat_message(client_id: str, data: dict, llm_service: LLMService):
    """Handle incoming chat message"""
    content = data.get("content", "")

    # Add to history
    manager.add_to_history(client_id, "user", content)

    # Echo the user message
    await manager.send_message(client_id, {
        "type": "message",
        "role": "user",
        "content": content,
        "timestamp": datetime.utcnow().isoformat()
    })

    # Send typing indicator
    await manager.send_message(client_id, {
        "type": "assistant_typing",
        "status": "start"
    })

    try:
        # Get conversation context
        history = manager.user_sessions.get(client_id, [])[-10:]  # Last 10 messages
        context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history])

        # Stream LLM response (complete() returns an async generator when stream=True)
        full_response = ""
        token_stream = await llm_service.complete(
            prompt=f"{context}\nassistant:",
            stream=True
        )
        async for token in token_stream:
            full_response += token
            # Send each token
            await manager.send_message(client_id, {
                "type": "token",
                "content": token
            })
            # Small delay to prevent overwhelming
            await asyncio.sleep(0.01)

        # Add to history
        manager.add_to_history(client_id, "assistant", full_response)

        # Send completion
        await manager.send_message(client_id, {
            "type": "assistant_typing",
            "status": "complete"
        })

    except Exception as e:
        await manager.send_message(client_id, {
            "type": "error",
            "message": f"Error generating response: {str(e)}"
        })
WebSocket Client Implementation
// TypeScript WebSocket client interface Message { type: 'message' | 'token' | 'error' | 'system' | 'typing' | 'history'; content?: string; role?: 'user' | 'assistant'; timestamp?: string; messages?: any[]; } class WebSocketChat { private ws: WebSocket | null = null; private clientId: string; private reconnectAttempts = 0; private maxReconnectAttempts = 5; private reconnectDelay = 1000; constructor(clientId: string) { this.clientId = clientId; } connect(): Promise<void> { return new Promise((resolve, reject) => { const wsUrl = `ws://localhost:8000/ws/chat/${this.clientId}`; this.ws = new WebSocket(wsUrl); this.ws.onopen = () => { console.log('WebSocket connected'); this.reconnectAttempts = 0; resolve(); }; this.ws.onmessage = (event) => { const message: Message = JSON.parse(event.data); this.handleMessage(message); }; this.ws.onerror = (error) => { console.error('WebSocket error:', error); reject(error); }; this.ws.onclose = () => { console.log('WebSocket disconnected'); this.handleReconnect(); }; }); } private handleMessage(message: Message) { switch (message.type) { case 'message': console.log(`${message.role}: ${message.content}`); break; case 'token': // Handle streaming token process.stdout.write(message.content || ''); break; case 'error': console.error('Error:', message.content); break; case 'system': console.log('System:', message.content); break; case 'typing': console.log('User is typing...'); break; case 'history': console.log('Chat history:', message.messages); break; } } private async handleReconnect() { if (this.reconnectAttempts < this.maxReconnectAttempts) { this.reconnectAttempts++; console.log(`Reconnecting... Attempt ${this.reconnectAttempts}`); await new Promise(resolve => setTimeout(resolve, this.reconnectDelay * this.reconnectAttempts) ); try { await this.connect(); } catch (error) { console.error('Reconnection failed:', error); } } } sendMessage(content: string) { if (this.ws && this.ws.readyState === WebSocket.OPEN) { this.ws.send(JSON.stringify({ type: 'message', content })); } } sendTypingIndicator() { if (this.ws && this.ws.readyState === WebSocket.OPEN) { this.ws.send(JSON.stringify({ type: 'typing' })); } } requestHistory() { if (this.ws && this.ws.readyState === WebSocket.OPEN) { this.ws.send(JSON.stringify({ type: 'history' })); } } disconnect() { if (this.ws) { this.ws.close(); } } } // Usage const chat = new WebSocketChat('user123'); await chat.connect(); // Send a message chat.sendMessage('Hello, how are you?'); // Request chat history chat.requestHistory();
Request Validation with Pydantic
Robust input validation and error handling
Advanced Validation Models
# app/models/validation.py from pydantic import BaseModel, Field, validator, root_validator from typing import Optional, List, Literal, Union import re from datetime import datetime class ChatMessageStrict(BaseModel): """Strict validation for chat messages""" role: Literal["system", "user", "assistant"] content: str = Field(..., min_length=1, max_length=10000) name: Optional[str] = Field(None, regex="^[a-zA-Z0-9_-]+$", max_length=64) @validator('content') def validate_content(cls, v): # Remove potentially harmful content if re.search(r'<script|javascript:|onerror=', v, re.IGNORECASE): raise ValueError("Content contains potentially harmful code") return v.strip() class AdvancedChatRequest(BaseModel): """Advanced chat request with comprehensive validation""" messages: List[ChatMessageStrict] = Field(..., min_items=1, max_items=50) model: Optional[str] = Field(None, regex="^[a-zA-Z0-9-_.]+$") temperature: float = Field(0.7, ge=0, le=2) max_tokens: int = Field(1000, ge=1, le=4000) top_p: Optional[float] = Field(None, ge=0, le=1) frequency_penalty: Optional[float] = Field(None, ge=-2, le=2) presence_penalty: Optional[float] = Field(None, ge=-2, le=2) stop: Optional[Union[str, List[str]]] = Field(None) user_id: Optional[str] = Field(None, regex="^[a-zA-Z0-9-_]+$") session_id: Optional[str] = Field(None, regex="^[a-zA-Z0-9-_]+$") @root_validator def validate_token_limits(cls, values): messages = values.get('messages', []) max_tokens = values.get('max_tokens', 1000) # Estimate prompt tokens (rough calculation) prompt_tokens = sum(len(msg.content.split()) * 1.3 for msg in messages) if prompt_tokens + max_tokens > 4000: raise ValueError( f"Combined prompt ({int(prompt_tokens)}) and max_tokens ({max_tokens}) " f"exceeds model limit" ) return values @validator('stop') def validate_stop_sequences(cls, v): if isinstance(v, list) and len(v) > 4: raise ValueError("Maximum 4 stop sequences allowed") return v class Config: schema_extra = { "example": { "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"} ], "temperature": 0.7, "max_tokens": 150 } } class ValidationError(BaseModel): """Structured validation error response""" loc: List[Union[str, int]] msg: str type: str ctx: Optional[dict] = None class ValidationErrorResponse(BaseModel): """Response for validation errors""" detail: List[ValidationError] error_type: str = "validation_error" timestamp: datetime = Field(default_factory=datetime.utcnow) # Custom validators def validate_api_key(api_key: str) -> str: """Validate API key format""" if not re.match(r'^pk_[a-zA-Z0-9]{32,}$', api_key): raise ValueError("Invalid API key format") return api_key def validate_model_name(model: str) -> str: """Validate model name against allowed list""" allowed_models = { "gpt-3.5-turbo", "gpt-4", "claude-3-sonnet", "claude-3-opus", "llama-2-70b" } if model not in allowed_models: raise ValueError(f"Model '{model}' not supported. Allowed: {allowed_models}") return model
Validation Middleware
# app/middleware/validation.py from fastapi import Request, status from fastapi.responses import JSONResponse from fastapi.exceptions import RequestValidationError from pydantic import ValidationError import logging logger = logging.getLogger(__name__) async def validation_exception_handler(request: Request, exc: RequestValidationError): """Custom handler for validation errors""" # Log validation error logger.warning( f"Validation error on {request.method} {request.url.path}: " f"{exc.errors()}" ) # Format errors for response formatted_errors = [] for error in exc.errors(): formatted_errors.append({ "loc": error["loc"], "msg": error["msg"], "type": error["type"], "ctx": error.get("ctx") }) return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content={ "detail": formatted_errors, "error_type": "validation_error", "path": str(request.url.path), "method": request.method } ) # Input sanitization from html import escape import bleach class InputSanitizer: """Sanitize user inputs""" ALLOWED_TAGS = ['b', 'i', 'u', 'code', 'pre'] ALLOWED_ATTRIBUTES = {} @staticmethod def sanitize_text(text: str) -> str: """Sanitize text input""" # Basic HTML escaping text = escape(text) # Allow some safe tags text = bleach.clean( text, tags=InputSanitizer.ALLOWED_TAGS, attributes=InputSanitizer.ALLOWED_ATTRIBUTES, strip=True ) # Remove any remaining suspicious patterns suspicious_patterns = [ r'javascript:', r'vbscript:', r'onload=', r'onerror=', r'<iframe', r'<object', r'<embed' ] for pattern in suspicious_patterns: text = re.sub(pattern, '', text, flags=re.IGNORECASE) return text.strip() # Request size limiting async def limit_request_size(request: Request, call_next): """Limit request body size""" MAX_REQUEST_SIZE = 1_000_000 # 1MB if request.headers.get("content-length"): content_length = int(request.headers["content-length"]) if content_length > MAX_REQUEST_SIZE: return JSONResponse( status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, content={ "detail": f"Request body too large. Maximum size: {MAX_REQUEST_SIZE} bytes" } ) response = await call_next(request) return response
Authentication & Rate Limiting
Securing your API with auth and rate limits
JWT Authentication
# app/services/auth.py from datetime import datetime, timedelta from typing import Optional from jose import JWTError, jwt from passlib.context import CryptContext from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from app.config import settings import redis pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") security = HTTPBearer() class AuthService: """Handle authentication and authorization""" def __init__(self, redis_client: redis.Redis): self.redis_client = redis_client def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None): """Create JWT access token""" to_encode = data.copy() if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) return encoded_jwt async def verify_token(self, credentials: HTTPAuthorizationCredentials = Depends(security)): """Verify JWT token""" token = credentials.credentials try: payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) user_id: str = payload.get("sub") if user_id is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"}, ) # Check if token is blacklisted if self.redis_client.get(f"blacklist:{token}"): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has been revoked" ) return {"user_id": user_id, "token": token} except JWTError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"}, ) def revoke_token(self, token: str): """Add token to blacklist""" # Token will be blacklisted until it expires ttl = settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 self.redis_client.setex(f"blacklist:{token}", ttl, "1") # API Key authentication class APIKeyAuth: """API Key authentication""" def __init__(self, redis_client: redis.Redis): self.redis_client = redis_client async def verify_api_key(self, api_key: str = Depends(security)): """Verify API key""" key = api_key.credentials # Check if key exists and is valid key_data = self.redis_client.hgetall(f"api_key:{key}") if not key_data: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key" ) # Check if key is active if key_data.get(b"status") != b"active": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="API key is not active" ) # Update last used timestamp self.redis_client.hset(f"api_key:{key}", "last_used", datetime.utcnow().isoformat()) return { "api_key": key, "user_id": key_data.get(b"user_id", b"").decode(), "rate_limit": int(key_data.get(b"rate_limit", b"60")) }
Rate Limiting Implementation
# app/middleware/rate_limit.py
from fastapi import Request, HTTPException, status
from typing import Callable
import time
import redis
from datetime import datetime


class RateLimiter:
    """Sliding-window rate limiter using Redis sorted sets"""

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client

    async def check_rate_limit(
        self,
        key: str,
        max_requests: int,
        window_seconds: int
    ) -> tuple[bool, dict]:
        """
        Check if request is within rate limit
        Returns (allowed, info)
        """
        now = time.time()
        window_start = now - window_seconds

        pipe = self.redis.pipeline()
        pipe.zremrangebyscore(key, 0, window_start)
        pipe.zadd(key, {str(now): now})
        pipe.zcount(key, window_start, now)
        pipe.expire(key, window_seconds + 1)
        results = pipe.execute()

        request_count = results[2]

        info = {
            "limit": max_requests,
            "remaining": max(0, max_requests - request_count),
            "reset": int(now + window_seconds)
        }

        return request_count <= max_requests, info


class RateLimitMiddleware:
    """Rate limiting middleware"""

    def __init__(self, redis_client: redis.Redis):
        self.limiter = RateLimiter(redis_client)
        # Different rate limits for different endpoints
        self.limits = {
            "/api/v1/chat/completions": (60, 60),   # 60 requests per minute
            "/api/v1/stream/chat": (30, 60),        # 30 requests per minute
            "/ws/chat": (10, 60),                   # 10 connections per minute
            "default": (100, 60)                    # Default: 100 per minute
        }

    async def __call__(self, request: Request, call_next: Callable):
        # Skip rate limiting for health checks
        if request.url.path in ["/health", "/metrics"]:
            return await call_next(request)

        # Get identifier (API key or IP)
        identifier = self.get_identifier(request)

        # Get rate limit for endpoint
        limit, window = self.limits.get(request.url.path, self.limits["default"])

        # Check rate limit
        key = f"rate_limit:{identifier}:{request.url.path}"
        allowed, info = await self.limiter.check_rate_limit(key, limit, window)

        if not allowed:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Rate limit exceeded",
                headers={
                    "X-RateLimit-Limit": str(info["limit"]),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(info["reset"]),
                    "Retry-After": str(info["reset"] - int(time.time()))
                }
            )

        # Process request
        response = await call_next(request)

        # Add rate limit headers to the response
        response.headers["X-RateLimit-Limit"] = str(info["limit"])
        response.headers["X-RateLimit-Remaining"] = str(info["remaining"])
        response.headers["X-RateLimit-Reset"] = str(info["reset"])

        return response

    def get_identifier(self, request: Request) -> str:
        """Get identifier for rate limiting"""
        # Try to get API key
        auth_header = request.headers.get("Authorization", "")
        if auth_header.startswith("Bearer "):
            return f"key:{auth_header[7:]}"
        # Fall back to IP address
        return f"ip:{request.client.host}"


# Usage-based rate limiting
class UsageBasedLimiter:
    """Rate limit based on token usage"""

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client

    async def check_token_limit(
        self,
        user_id: str,
        tokens_requested: int,
        daily_limit: int = 100000
    ) -> tuple[bool, dict]:
        """Check if user has tokens remaining"""
        today = datetime.utcnow().strftime("%Y-%m-%d")
        key = f"usage:{user_id}:{today}"

        # Get current usage
        current_usage = int(self.redis.get(key) or 0)

        if current_usage + tokens_requested > daily_limit:
            return False, {
                "daily_limit": daily_limit,
                "used": current_usage,
                "remaining": daily_limit - current_usage,
                "requested": tokens_requested
            }

        # Increment usage
        self.redis.incrby(key, tokens_requested)
        self.redis.expire(key, 86400)  # Expire after 24 hours

        return True, {
            "daily_limit": daily_limit,
            "used": current_usage + tokens_requested,
            "remaining": daily_limit - (current_usage + tokens_requested)
        }
Always use HTTPS in production and store secrets securely. Consider implementing OAuth2 for more complex authentication scenarios.
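For the OAuth2 route, FastAPI ships with helpers for the password flow. The sketch below is illustrative only: the in-memory FAKE_USERS store and the hard-coded SECRET_KEY are placeholders you would replace with a real user database and settings.SECRET_KEY.

# Minimal OAuth2 password-flow sketch (illustrative; swap the fake user store for a real one)
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt

router = APIRouter()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

SECRET_KEY = "change-me"              # use settings.SECRET_KEY in the real app
ALGORITHM = "HS256"
FAKE_USERS = {"alice": "wonderland"}  # placeholder credential store


@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    # Validate credentials against the placeholder store
    if FAKE_USERS.get(form_data.username) != form_data.password:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Incorrect username or password")
    token = jwt.encode({"sub": form_data.username}, SECRET_KEY, algorithm=ALGORITHM)
    return {"access_token": token, "token_type": "bearer"}


async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
    # Decode the bearer token supplied in the Authorization header
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        return payload["sub"]
    except (JWTError, KeyError):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")


@router.get("/me")
async def read_me(user: str = Depends(get_current_user)):
    return {"user": user}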
Background Tasks for Async Processing
Handling long-running tasks without blocking responses
FastAPI Background Tasks
# app/routers/async_processing.py from fastapi import APIRouter, BackgroundTasks, Depends from app.models import ChatRequest import uuid from typing import Dict import asyncio router = APIRouter() # In-memory task storage (use Redis in production) task_storage: Dict[str, dict] = {} async def process_llm_task(task_id: str, request: ChatRequest): """Background task for processing LLM request""" try: # Update task status task_storage[task_id]["status"] = "processing" task_storage[task_id]["started_at"] = datetime.utcnow() # Simulate LLM processing await asyncio.sleep(2) # Replace with actual LLM call # Get LLM response llm_service = LLMService(httpx.AsyncClient()) result = await llm_service.complete( prompt=request.messages[-1].content, model=request.model, temperature=request.temperature, max_tokens=request.max_tokens ) # Update task with result task_storage[task_id].update({ "status": "completed", "result": result, "completed_at": datetime.utcnow() }) except Exception as e: task_storage[task_id].update({ "status": "failed", "error": str(e), "failed_at": datetime.utcnow() }) @router.post("/async-chat") async def create_async_chat( request: ChatRequest, background_tasks: BackgroundTasks ): """Submit chat request for async processing""" # Generate task ID task_id = str(uuid.uuid4()) # Store initial task data task_storage[task_id] = { "id": task_id, "status": "pending", "created_at": datetime.utcnow(), "request": request.dict() } # Add to background tasks background_tasks.add_task(process_llm_task, task_id, request) return { "task_id": task_id, "status": "pending", "check_url": f"/api/v1/tasks/{task_id}" } @router.get("/tasks/{task_id}") async def get_task_status(task_id: str): """Check status of async task""" if task_id not in task_storage: raise HTTPException( status_code=404, detail="Task not found" ) task = task_storage[task_id] # Calculate duration if completed if task["status"] == "completed" and "started_at" in task: duration = (task["completed_at"] - task["started_at"]).total_seconds() task["duration_seconds"] = duration return task
Celery Integration
# app/celery_app.py from celery import Celery from app.config import settings import httpx import asyncio # Create Celery app celery_app = Celery( "llm_tasks", broker=settings.REDIS_URL, backend=settings.REDIS_URL ) # Configure Celery celery_app.conf.update( task_serializer='json', accept_content=['json'], result_serializer='json', timezone='UTC', enable_utc=True, result_expires=3600, # Results expire after 1 hour task_soft_time_limit=300, # 5 minutes soft limit task_time_limit=600, # 10 minutes hard limit ) @celery_app.task(bind=True, name="process_llm_request") def process_llm_request(self, prompt: str, **kwargs): """Celery task for processing LLM requests""" try: # Update task state self.update_state( state='PROCESSING', meta={'current': 'Calling LLM API'} ) # Make synchronous call (Celery doesn't support async) with httpx.Client(timeout=30.0) as client: response = client.post( f"{settings.LLM_API_BASE_URL}/chat/completions", headers={ "Authorization": f"Bearer {settings.LLM_API_KEY}", "Content-Type": "application/json" }, json={ "model": kwargs.get("model", settings.LLM_MODEL), "messages": [{"role": "user", "content": prompt}], "temperature": kwargs.get("temperature", 0.7), "max_tokens": kwargs.get("max_tokens", 1000) } ) response.raise_for_status() result = response.json() return { 'success': True, 'response': result['choices'][0]['message']['content'], 'usage': result.get('usage', {}), 'model': result.get('model') } except Exception as e: # Log error and re-raise self.update_state( state='FAILURE', meta={'error': str(e)} ) raise # Batch processing task @celery_app.task(name="batch_process_prompts") def batch_process_prompts(prompts: List[str], **kwargs): """Process multiple prompts in batch""" results = [] for i, prompt in enumerate(prompts): # Update progress celery_app.current_task.update_state( state='PROGRESS', meta={ 'current': i + 1, 'total': len(prompts), 'percentage': ((i + 1) / len(prompts)) * 100 } ) # Process each prompt result = process_llm_request.apply_async( args=[prompt], kwargs=kwargs ).get() results.append(result) return results # FastAPI endpoint for Celery tasks @router.post("/celery/submit") async def submit_celery_task(request: ChatRequest): """Submit task to Celery""" # Submit to Celery task = process_llm_request.apply_async( args=[request.messages[-1].content], kwargs={ "model": request.model, "temperature": request.temperature, "max_tokens": request.max_tokens } ) return { "task_id": task.id, "status": "submitted", "check_url": f"/api/v1/celery/status/{task.id}" } @router.get("/celery/status/{task_id}") async def get_celery_task_status(task_id: str): """Get Celery task status""" task = process_llm_request.AsyncResult(task_id) if task.state == 'PENDING': response = { 'state': task.state, 'status': 'Task is waiting to be processed' } elif task.state == 'PROCESSING': response = { 'state': task.state, 'current': task.info.get('current', ''), 'status': 'Task is being processed' } elif task.state == 'SUCCESS': response = { 'state': task.state, 'result': task.result, 'status': 'Task completed successfully' } else: # FAILURE response = { 'state': task.state, 'error': str(task.info), 'status': 'Task failed' } return response
Error Handling and Retries
Robust error handling for production systems
Comprehensive Error Handling
# app/exceptions.py
from fastapi import Request, status
from fastapi.responses import JSONResponse
import traceback
import logging
import asyncio

import httpx
from app.config import settings  # NOTE: assumes a DEBUG flag is added to Settings

logger = logging.getLogger(__name__)


class LLMException(Exception):
    """Base exception for LLM-related errors"""

    def __init__(self, message: str, status_code: int = 500, details: dict = None):
        self.message = message
        self.status_code = status_code
        self.details = details or {}
        super().__init__(self.message)


class RateLimitException(LLMException):
    """Rate limit exceeded"""

    def __init__(self, message: str = "Rate limit exceeded", retry_after: int = 60):
        super().__init__(message, status_code=429)
        self.retry_after = retry_after


class ModelNotFoundException(LLMException):
    """Model not found"""

    def __init__(self, model: str):
        super().__init__(
            f"Model '{model}' not found",
            status_code=404,
            details={"model": model}
        )


class TokenLimitException(LLMException):
    """Token limit exceeded"""

    def __init__(self, requested: int, limit: int):
        super().__init__(
            f"Token limit exceeded. Requested: {requested}, Limit: {limit}",
            status_code=400,
            details={"requested": requested, "limit": limit}
        )


# Exception handlers
async def llm_exception_handler(request: Request, exc: LLMException):
    """Handle LLM-specific exceptions"""
    logger.error(
        f"LLM error on {request.method} {request.url.path}: "
        f"{exc.message} - Details: {exc.details}"
    )

    response = JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.message,
            "details": exc.details,
            "type": type(exc).__name__,
            "path": str(request.url.path)
        }
    )

    if isinstance(exc, RateLimitException):
        response.headers["Retry-After"] = str(exc.retry_after)

    return response


async def general_exception_handler(request: Request, exc: Exception):
    """Handle unexpected exceptions"""
    logger.error(
        f"Unexpected error on {request.method} {request.url.path}: "
        f"{str(exc)}\n{traceback.format_exc()}"
    )

    # Don't expose internal errors in production
    if settings.DEBUG:
        content = {
            "error": "Internal server error",
            "detail": str(exc),
            "type": type(exc).__name__,
            "traceback": traceback.format_exc().split("\n")
        }
    else:
        content = {
            "error": "Internal server error",
            "detail": "An unexpected error occurred",
            "request_id": request.state.request_id if hasattr(request.state, 'request_id') else None
        }

    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=content
    )


# Retry mechanism
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
    before_sleep_log
)


def create_retry_decorator(max_attempts: int = 3):
    """Create a retry decorator with exponential backoff"""
    return retry(
        stop=stop_after_attempt(max_attempts),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((httpx.HTTPError, asyncio.TimeoutError)),
        before_sleep=before_sleep_log(logger, logging.WARNING)
    )


# Circuit breaker pattern
from typing import Callable
import time


class CircuitBreaker:
    """Circuit breaker for external service calls"""

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        expected_exception: type = Exception
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.failure_count = 0
        self.last_failure_time = None
        self.state = "closed"  # closed, open, half-open

    async def call(self, func: Callable, *args, **kwargs):
        """Execute function with circuit breaker"""
        if self.state == "open":
            if time.time() - self.last_failure_time > self.recovery_timeout:
                self.state = "half-open"
            else:
                raise LLMException(
                    "Circuit breaker is open - service temporarily unavailable",
                    status_code=503
                )

        try:
            result = await func(*args, **kwargs)
            if self.state == "half-open":
                self.state = "closed"
                self.failure_count = 0
            return result
        except self.expected_exception as e:
            self.failure_count += 1
            self.last_failure_time = time.time()

            if self.failure_count >= self.failure_threshold:
                self.state = "open"
                logger.error(f"Circuit breaker opened after {self.failure_count} failures")

            raise


# Usage in service
class ResilientLLMService:
    """LLM service with retry and circuit breaker"""

    def __init__(self):
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=60,
            expected_exception=httpx.HTTPError
        )
        self.retry_decorator = create_retry_decorator(max_attempts=3)

    @create_retry_decorator(max_attempts=3)
    async def complete_with_retry(self, prompt: str, **kwargs):
        """Complete with automatic retry"""
        async def _make_request():
            # Your LLM API call here
            pass

        return await self.circuit_breaker.call(_make_request)
Production Deployment
Deploying FastAPI with Uvicorn and Gunicorn
Production Server Configuration
# gunicorn.conf.py import multiprocessing import os # Server socket bind = f"0.0.0.0:{os.environ.get('PORT', '8000')}" backlog = 2048 # Worker processes workers = multiprocessing.cpu_count() * 2 + 1 worker_class = "uvicorn.workers.UvicornWorker" worker_connections = 1000 keepalive = 5 max_requests = 1000 max_requests_jitter = 50 # Timeouts timeout = 30 graceful_timeout = 30 # Logging accesslog = "-" errorlog = "-" loglevel = "info" access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" %(D)s' # Process naming proc_name = 'llm-api' # Server mechanics daemon = False pidfile = None user = None group = None tmp_upload_dir = None # SSL (if needed) # keyfile = "/path/to/keyfile" # certfile = "/path/to/certfile" # Server hooks def when_ready(server): server.log.info("Server is ready. Spawning workers") def worker_int(worker): worker.log.info("Worker received INT or QUIT signal") def pre_fork(server, worker): server.log.info(f"Worker spawned (pid: {worker.pid})") def pre_exec(server): server.log.info("Forked child, re-executing.") def worker_abort(worker): worker.log.info("Worker received SIGABRT signal")
Docker Configuration
# Dockerfile FROM python:3.11-slim as builder # Install system dependencies RUN apt-get update && apt-get install -y \ gcc \ g++ \ && rm -rf /var/lib/apt/lists/* # Set working directory WORKDIR /app # Copy requirements COPY requirements.txt . # Install Python dependencies RUN pip install --no-cache-dir --upgrade pip && \ pip install --no-cache-dir -r requirements.txt # Production stage FROM python:3.11-slim # Install runtime dependencies RUN apt-get update && apt-get install -y \ curl \ && rm -rf /var/lib/apt/lists/* # Create non-root user RUN useradd -m -u 1000 appuser WORKDIR /app # Copy Python dependencies from builder COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages COPY --from=builder /usr/local/bin /usr/local/bin # Copy application code COPY --chown=appuser:appuser . . # Switch to non-root user USER appuser # Environment variables ENV PYTHONUNBUFFERED=1 \ PYTHONDONTWRITEBYTECODE=1 \ PORT=8000 # Health check HEALTHCHECK --interval=30s --timeout=3s --start-period=40s --retries=3 \ CMD curl -f http://localhost:$PORT/health || exit 1 # Expose port EXPOSE $PORT # Run with gunicorn CMD ["gunicorn", "app.main:app", "-c", "gunicorn.conf.py"]
Docker Compose for Development
# docker-compose.yml version: '3.8' services: api: build: . ports: - "8000:8000" environment: - REDIS_URL=redis://redis:6379 - DATABASE_URL=postgresql://user:pass@postgres:5432/llm_db - LLM_API_KEY=${LLM_API_KEY} - SECRET_KEY=${SECRET_KEY} depends_on: - redis - postgres volumes: - ./app:/app/app # For development hot reload command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload redis: image: redis:7-alpine ports: - "6379:6379" volumes: - redis_data:/data postgres: image: postgres:15-alpine environment: POSTGRES_USER: user POSTGRES_PASSWORD: pass POSTGRES_DB: llm_db ports: - "5432:5432" volumes: - postgres_data:/var/lib/postgresql/data nginx: image: nginx:alpine ports: - "80:80" - "443:443" volumes: - ./nginx.conf:/etc/nginx/nginx.conf - ./ssl:/etc/nginx/ssl depends_on: - api prometheus: image: prom/prometheus ports: - "9090:9090" volumes: - ./prometheus.yml:/etc/prometheus/prometheus.yml - prometheus_data:/prometheus grafana: image: grafana/grafana ports: - "3000:3000" environment: - GF_SECURITY_ADMIN_PASSWORD=admin volumes: - grafana_data:/var/lib/grafana volumes: redis_data: postgres_data: prometheus_data: grafana_data:
Nginx Configuration
# nginx.conf events { worker_connections 1024; } http { upstream fastapi_backend { least_conn; server api:8000 max_fails=3 fail_timeout=30s; keepalive 32; } # Rate limiting limit_req_zone $binary_remote_addr zone=general:10m rate=10r/s; limit_req_zone $binary_remote_addr zone=api:10m rate=100r/m; # Connection limiting limit_conn_zone $binary_remote_addr zone=addr:10m; server { listen 80; server_name api.yourdomain.com; return 301 https://$server_name$request_uri; } server { listen 443 ssl http2; server_name api.yourdomain.com; # SSL configuration ssl_certificate /etc/nginx/ssl/cert.pem; ssl_certificate_key /etc/nginx/ssl/key.pem; ssl_protocols TLSv1.2 TLSv1.3; ssl_ciphers ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384; ssl_prefer_server_ciphers off; # Security headers add_header X-Frame-Options "SAMEORIGIN" always; add_header X-Content-Type-Options "nosniff" always; add_header X-XSS-Protection "1; mode=block" always; add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always; # Timeouts proxy_connect_timeout 60s; proxy_send_timeout 60s; proxy_read_timeout 60s; # Buffer sizes client_body_buffer_size 1M; client_max_body_size 10M; # Rate limiting limit_req zone=general burst=20 nodelay; limit_conn addr 10; location / { proxy_pass http://fastapi_backend; proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade"; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; # Disable buffering for SSE proxy_buffering off; proxy_cache off; # Timeouts for streaming proxy_read_timeout 3600s; keepalive_timeout 3600s; } location /api/ { limit_req zone=api burst=5 nodelay; proxy_pass http://fastapi_backend; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; } location /ws/ { proxy_pass http://fastapi_backend; proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade"; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; # WebSocket specific proxy_read_timeout 3600s; proxy_send_timeout 3600s; } location /health { proxy_pass http://fastapi_backend; access_log off; } } }
Scaling Strategies
Horizontal and vertical scaling for high traffic
Kubernetes Deployment
# k8s-deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: fastapi-llm labels: app: fastapi-llm spec: replicas: 3 selector: matchLabels: app: fastapi-llm template: metadata: labels: app: fastapi-llm spec: containers: - name: fastapi image: your-registry/fastapi-llm:latest ports: - containerPort: 8000 env: - name: REDIS_URL value: redis://redis-service:6379 - name: LLM_API_KEY valueFrom: secretKeyRef: name: llm-secrets key: api-key resources: requests: memory: "512Mi" cpu: "500m" limits: memory: "1Gi" cpu: "1000m" livenessProbe: httpGet: path: /health port: 8000 initialDelaySeconds: 30 periodSeconds: 10 readinessProbe: httpGet: path: /health port: 8000 initialDelaySeconds: 5 periodSeconds: 5 --- apiVersion: v1 kind: Service metadata: name: fastapi-service spec: selector: app: fastapi-llm ports: - port: 80 targetPort: 8000 type: LoadBalancer --- apiVersion: autoscaling/v2 kind: HorizontalPodAutoscaler metadata: name: fastapi-hpa spec: scaleTargetRef: apiVersion: apps/v1 kind: Deployment name: fastapi-llm minReplicas: 3 maxReplicas: 10 metrics: - type: Resource resource: name: cpu target: type: Utilization averageUtilization: 70 - type: Resource resource: name: memory target: type: Utilization averageUtilization: 80
Horizontal Scaling
Load balancer distribution
Auto-scaling based on metrics
Session affinity for WebSockets
Shared Redis for state (see the sketch below)
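When WebSocket traffic is spread across several instances, the in-memory ConnectionManager shown earlier only sees its own clients. A common pattern is to fan messages out through Redis pub/sub so every instance can deliver them to its local connections. The sketch below is illustrative; the channel name and wiring are assumptions, not part of the project above.

# Illustrative cross-instance broadcast via Redis pub/sub (redis-py asyncio API)
import json
import redis.asyncio as aioredis

redis_client = aioredis.from_url("redis://localhost:6379", decode_responses=True)
CHANNEL = "ws_broadcast"  # illustrative channel name


async def publish_broadcast(message: dict):
    # Called by any instance that wants to reach all connected clients
    await redis_client.publish(CHANNEL, json.dumps(message))


async def broadcast_listener(manager):
    # Run as a background task on each instance (e.g. started in the lifespan handler);
    # forwards published messages to this instance's local WebSocket connections
    pubsub = redis_client.pubsub()
    await pubsub.subscribe(CHANNEL)
    async for item in pubsub.listen():
        if item["type"] == "message":
            await manager.broadcast(json.loads(item["data"]))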
Vertical Scaling
Increase worker processes
Optimize memory usage
Use connection pooling
Enable response caching (see the sketch below)
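Response caching pays off for LLM APIs because identical prompts recur and completions are expensive. Below is a minimal sketch of a Redis-backed cache around non-streaming completions; the helper names and TTL are illustrative assumptions.

# Illustrative Redis response cache for non-streaming completions
import hashlib
import json
import redis.asyncio as aioredis

redis_client = aioredis.from_url("redis://localhost:6379", decode_responses=True)


def cache_key(prompt: str, model: str, temperature: float) -> str:
    # Hash the request parameters so the key stays short and opaque
    raw = json.dumps({"p": prompt, "m": model, "t": temperature}, sort_keys=True)
    return "llm_cache:" + hashlib.sha256(raw.encode()).hexdigest()


async def cached_complete(llm_service, prompt: str, model: str, temperature: float, ttl: int = 3600):
    key = cache_key(prompt, model, temperature)

    # Return the cached completion if we have one
    cached = await redis_client.get(key)
    if cached is not None:
        return json.loads(cached)

    # Otherwise call the LLM and cache the JSON response for `ttl` seconds
    result = await llm_service.complete(prompt=prompt, model=model, temperature=temperature)
    await redis_client.setex(key, ttl, json.dumps(result))
    return result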
Monitoring & Logging
Observability for production FastAPI applications
Prometheus Metrics
# app/monitoring.py from prometheus_fastapi_instrumentator import Instrumentator from prometheus_client import Counter, Histogram, Gauge, Info import time # Custom metrics llm_requests_total = Counter( 'llm_requests_total', 'Total LLM API requests', ['model', 'endpoint', 'status'] ) llm_request_duration = Histogram( 'llm_request_duration_seconds', 'LLM request duration', ['model', 'endpoint'], buckets=[0.1, 0.5, 1, 2, 5, 10, 30] ) llm_tokens_used = Counter( 'llm_tokens_used_total', 'Total tokens used', ['model', 'type'] # type: prompt, completion ) active_websocket_connections = Gauge( 'active_websocket_connections', 'Number of active WebSocket connections' ) app_info = Info('fastapi_app', 'Application info') app_info.info({ 'version': '1.0.0', 'environment': settings.ENVIRONMENT }) # Instrumentator setup instrumentator = Instrumentator( should_group_status_codes=False, should_ignore_untemplated=True, should_respect_env_var=True, should_instrument_requests_inprogress=True, excluded_handlers=[".*health.*", ".*metrics.*"], inprogress_name="fastapi_inprogress", inprogress_labels=True ) # Custom metric middleware async def record_llm_metrics(request: Request, call_next): """Record custom LLM metrics""" start_time = time.time() # Add request ID for tracing request.state.request_id = str(uuid.uuid4()) response = await call_next(request) # Record metrics if request.url.path.startswith("/api/v1/"): duration = time.time() - start_time endpoint = request.url.path model = "unknown" # Extract from request/response llm_requests_total.labels( model=model, endpoint=endpoint, status=response.status_code ).inc() llm_request_duration.labels( model=model, endpoint=endpoint ).observe(duration) return response # Structured logging import structlog from pythonjsonlogger import jsonlogger def setup_logging(): """Configure structured logging""" # Configure structlog structlog.configure( processors=[ structlog.stdlib.filter_by_level, structlog.stdlib.add_logger_name, structlog.stdlib.add_log_level, structlog.stdlib.PositionalArgumentsFormatter(), structlog.processors.TimeStamper(fmt="iso"), structlog.processors.StackInfoRenderer(), structlog.processors.format_exc_info, structlog.processors.UnicodeDecoder(), structlog.processors.JSONRenderer() ], context_class=dict, logger_factory=structlog.stdlib.LoggerFactory(), cache_logger_on_first_use=True, ) # Configure standard library logging logHandler = logging.StreamHandler() formatter = jsonlogger.JsonFormatter( fmt="%(timestamp)s %(level)s %(name)s %(message)s", timestamp=True ) logHandler.setFormatter(formatter) logging.basicConfig( level=settings.LOG_LEVEL, handlers=[logHandler] ) # Request ID middleware from asgi_correlation_id import CorrelationIdMiddleware from asgi_correlation_id.context import correlation_id app.add_middleware( CorrelationIdMiddleware, header_name='X-Request-ID', generator=lambda: str(uuid.uuid4()) ) # Logging middleware @app.middleware("http") async def log_requests(request: Request, call_next): """Log all requests with context""" logger = structlog.get_logger() # Log request await logger.ainfo( "request_started", request_id=correlation_id.get(), method=request.method, path=request.url.path, client=request.client.host if request.client else None ) start_time = time.time() response = await call_next(request) duration = time.time() - start_time # Log response await logger.ainfo( "request_completed", request_id=correlation_id.get(), method=request.method, path=request.url.path, status_code=response.status_code, duration=duration ) return response
Use centralized logging (ELK stack, Datadog, etc.) and distributed tracing (Jaeger, Zipkin) for production environments.
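As a starting point for distributed tracing, OpenTelemetry's FastAPI instrumentation can export spans to any OTLP-compatible collector (Jaeger, for example). A minimal sketch, assuming the opentelemetry-sdk, opentelemetry-instrumentation-fastapi, and opentelemetry-exporter-otlp packages; the collector endpoint is a placeholder.

# Illustrative OpenTelemetry setup for the FastAPI app
from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

from app.main import app

# Describe this service and send spans to an OTLP collector (endpoint is a placeholder)
provider = TracerProvider(resource=Resource.create({"service.name": "llm-api"}))
provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint="http://localhost:4317")))
trace.set_tracer_provider(provider)

# Automatically create a span for every incoming HTTP request
FastAPIInstrumentor.instrument_app(app)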
FastAPI's async capabilities make it ideal for LLM streaming applications. Always test your streaming implementations across different deployment environments.
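For example, the /api/v1/stream/test endpoint shown earlier can be exercised in-process with httpx's ASGI transport and pytest-asyncio (both already in the development dependencies). This is a minimal sketch; the expected event count mirrors that endpoint, and the test is slow because the endpoint sleeps between events.

# tests/test_streaming.py - minimal sketch for testing the SSE test endpoint in-process
import httpx
import pytest

from app.main import app


@pytest.mark.asyncio
async def test_sse_stream_yields_events():
    # ASGITransport calls the app directly, so no running server is needed
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        events = []
        async with client.stream("GET", "/api/v1/stream/test") as response:
            assert response.status_code == 200
            assert response.headers["content-type"].startswith("text/event-stream")
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    events.append(line[6:])

    # The endpoint emits ten numbered events followed by a [DONE] marker
    assert events[-1] == "[DONE]"
    assert len(events) == 11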