FastAPI + LLM Streaming Integration
Build high-performance, async LLM APIs with real-time streaming capabilities
Overview
Why FastAPI is ideal for LLM integrations
FastAPI is a modern Python web framework that excels at building high-performance APIs with async support, automatic validation, and built-in documentation. It's particularly well-suited for LLM integrations due to its streaming capabilities and async-first design.
Key Features
- Native async/await support
- Automatic request validation
- Built-in OpenAPI docs
- WebSocket & SSE support
LLM Benefits
- Stream tokens in real-time (see the minimal sketch below)
- Handle concurrent requests
- Async LLM API calls
- Efficient resource usage
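To make the streaming benefit concrete, here is a minimal, hedged sketch (not part of the project built below) of an async endpoint that streams chunks with StreamingResponse; the fake_llm_tokens generator simply stands in for a real LLM client.
# minimal_stream.py -- illustrative sketch only
import asyncio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_llm_tokens():
    # Stand-in for an LLM client: yields tokens as they "arrive"
    for token in ["Hello", ", ", "world", "!"]:
        await asyncio.sleep(0.1)
        yield token

@app.get("/demo/stream")
async def demo_stream():
    # Each yielded chunk is flushed to the client as soon as it is produced
    return StreamingResponse(fake_llm_tokens(), media_type="text/plain")
Run it with uvicorn minimal_stream:app --reload and the client receives the words one by one instead of waiting for the full response.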
Setup & Configuration
Installing FastAPI and required dependencies
Installation
# Core dependencies
pip install fastapi uvicorn[standard] httpx python-multipart

# Additional dependencies for production
pip install gunicorn python-jose[cryptography] passlib[bcrypt]
pip install redis celery flower  # For background tasks
pip install prometheus-fastapi-instrumentator  # Monitoring

# For development
pip install pytest pytest-asyncio black isort
Project Structure
llm-api/
├── app/
│   ├── __init__.py
│   ├── main.py              # FastAPI app
│   ├── config.py            # Configuration
│   ├── models.py            # Pydantic models
│   ├── routers/
│   │   ├── __init__.py
│   │   ├── chat.py          # Chat endpoints
│   │   ├── streaming.py     # Streaming endpoints
│   │   └── websocket.py     # WebSocket endpoints
│   ├── services/
│   │   ├── __init__.py
│   │   ├── llm.py           # LLM service
│   │   └── auth.py          # Authentication
│   ├── middleware/
│   │   ├── __init__.py
│   │   ├── rate_limit.py
│   │   └── logging.py
│   └── utils/
│       ├── __init__.py
│       └── helpers.py
├── tests/
├── docker-compose.yml
├── Dockerfile
└── requirements.txt
Basic FastAPI App
# app/main.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import httpx
from app.config import settings
from app.routers import chat, streaming, websocket
# Lifespan context manager for startup/shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
app.state.http_client = httpx.AsyncClient(timeout=30.0)
yield
# Shutdown
await app.state.http_client.aclose()
# Create FastAPI app
app = FastAPI(
title="LLM Streaming API",
description="High-performance LLM API with streaming support",
version="1.0.0",
lifespan=lifespan
)
# Configure CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
app.include_router(chat.router, prefix="/api/v1/chat", tags=["chat"])
app.include_router(streaming.router, prefix="/api/v1/stream", tags=["streaming"])
app.include_router(websocket.router, prefix="/ws", tags=["websocket"])
@app.get("/")
async def root():
return {
"message": "LLM Streaming API",
"docs": "/docs",
"health": "/health"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"version": "1.0.0"
}
Configuration Management
# app/config.py
from pydantic_settings import BaseSettings
from typing import List, Optional
import os
class Settings(BaseSettings):
# API Settings
API_V1_STR: str = "/api/v1"
PROJECT_NAME: str = "LLM Streaming API"
# CORS
ALLOWED_ORIGINS: List[str] = ["http://localhost:3000", "https://yourdomain.com"]
# LLM Provider Settings
LLM_API_BASE_URL: str = "https://api.parrotrouter.com/v1"
LLM_API_KEY: str
LLM_MODEL: str = "gpt-3.5-turbo"
LLM_MAX_TOKENS: int = 2000
LLM_TEMPERATURE: float = 0.7
# Redis
REDIS_URL: str = "redis://localhost:6379"
# Authentication
SECRET_KEY: str
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
# Rate Limiting
RATE_LIMIT_PER_MINUTE: int = 60
RATE_LIMIT_PER_HOUR: int = 1000
# Monitoring
ENABLE_METRICS: bool = True
LOG_LEVEL: str = "INFO"
class Config:
env_file = ".env"
case_sensitive = True
settings = Settings()
Basic LLM Integration
Implementing standard LLM API calls with FastAPI
LLM Service
# app/services/llm.py
import httpx
from typing import Dict, Any, Optional, AsyncGenerator, Union
import json
from app.config import settings
import logging
logger = logging.getLogger(__name__)
class LLMService:
"""Service for interacting with LLM APIs"""
def __init__(self, http_client: httpx.AsyncClient):
self.http_client = http_client
self.base_url = settings.LLM_API_BASE_URL
self.headers = {
"Authorization": f"Bearer {settings.LLM_API_KEY}",
"Content-Type": "application/json"
}
async def complete(self,
prompt: str,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
stream: bool = False) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
"""
Make a completion request to the LLM API
"""
payload = {
"model": model or settings.LLM_MODEL,
"messages": [{"role": "user", "content": prompt}],
"temperature": temperature or settings.LLM_TEMPERATURE,
"max_tokens": max_tokens or settings.LLM_MAX_TOKENS,
"stream": stream
}
try:
if stream:
# _stream_completion is an async generator; return it without awaiting
# so callers can iterate the tokens with `async for`
return self._stream_completion(payload)
else:
response = await self.http_client.post(
f"{self.base_url}/chat/completions",
headers=self.headers,
json=payload
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"LLM API error: {e.response.status_code} - {e.response.text}")
raise
except Exception as e:
logger.error(f"Unexpected error calling LLM: {e}")
raise
async def _stream_completion(self, payload: Dict[str, Any]) -> AsyncGenerator[str, None]:
"""
Stream completion from LLM API
"""
async with self.http_client.stream(
"POST",
f"{self.base_url}/chat/completions",
headers=self.headers,
json=payload
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
break
try:
chunk = json.loads(data)
if chunk["choices"][0]["delta"].get("content"):
yield chunk["choices"][0]["delta"]["content"]
except json.JSONDecodeError:
logger.warning(f"Failed to parse streaming chunk: {data}")
continue
Pydantic Models
# app/models.py
from pydantic import BaseModel, Field, validator  # Pydantic v1-style; use field_validator on Pydantic v2
from typing import Optional, List, Literal
from datetime import datetime
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant"]
content: str
class ChatRequest(BaseModel):
messages: List[ChatMessage]
model: Optional[str] = Field(None, description="Model to use")
temperature: Optional[float] = Field(None, ge=0, le=2)
max_tokens: Optional[int] = Field(None, ge=1, le=4000)
stream: bool = Field(False, description="Enable streaming response")
@validator('messages')
def validate_messages(cls, v):
if not v:
raise ValueError("Messages cannot be empty")
return v
class ChatResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[dict]
usage: dict
class StreamingChatRequest(BaseModel):
prompt: str = Field(..., min_length=1, max_length=10000)
system_prompt: Optional[str] = None
temperature: Optional[float] = Field(0.7, ge=0, le=2)
max_tokens: Optional[int] = Field(1000, ge=1, le=4000)
class ErrorResponse(BaseModel):
error: str
detail: Optional[str] = None
timestamp: datetime = Field(default_factory=datetime.utcnow)
Chat Router
# app/routers/chat.py
from fastapi import APIRouter, HTTPException, Request, Depends
from app.models import ChatRequest, ChatResponse, ErrorResponse
from app.services.llm import LLMService
from typing import Annotated
import httpx
router = APIRouter()
async def get_llm_service(request: Request) -> LLMService:
"""Dependency to get LLM service"""
return LLMService(request.app.state.http_client)
@router.post("/completions",
response_model=ChatResponse,
responses={
400: {"model": ErrorResponse},
500: {"model": ErrorResponse}
})
async def create_chat_completion(
request: ChatRequest,
llm_service: Annotated[LLMService, Depends(get_llm_service)]
):
"""
Create a chat completion
- **messages**: List of chat messages
- **model**: Optional model override
- **temperature**: Sampling temperature (0-2)
- **max_tokens**: Maximum tokens to generate
"""
try:
# Flatten the chat history into a single prompt string, since LLMService.complete() accepts one prompt
prompt = "\n".join([f"{msg.role}: {msg.content}" for msg in request.messages])
result = await llm_service.complete(
prompt=prompt,
model=request.model,
temperature=request.temperature,
max_tokens=request.max_tokens,
stream=False
)
return result
except httpx.HTTPStatusError as e:
raise HTTPException(
status_code=e.response.status_code,
detail=f"LLM API error: {e.response.text}"
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Internal server error: {str(e)}"
)
Streaming with Server-Sent Events (SSE)
Implementing real-time token streaming using SSE
SSE Streaming Endpoint
# app/routers/streaming.py
from fastapi import APIRouter, Request, Depends
from fastapi.responses import StreamingResponse
from app.models import StreamingChatRequest
from app.services.llm import LLMService
from typing import AsyncGenerator, Annotated
from datetime import datetime
import json
import asyncio
router = APIRouter()
async def get_llm_service(request: Request) -> LLMService:
return LLMService(request.app.state.http_client)
@router.post("/chat")
async def stream_chat(
request: StreamingChatRequest,
llm_service: Annotated[LLMService, Depends(get_llm_service)]
):
"""
Stream chat responses using Server-Sent Events
"""
async def generate_sse() -> AsyncGenerator[str, None]:
try:
# Send initial event
yield f"data: {json.dumps({'type': 'start', 'message': 'Starting generation'})}\n\n"
# Prepare the prompt
full_prompt = request.prompt
if request.system_prompt:
full_prompt = f"System: {request.system_prompt}\n\nUser: {request.prompt}"
# Stream tokens from LLM
token_count = 0
async for token in await llm_service.complete(
prompt=full_prompt,
temperature=request.temperature,
max_tokens=request.max_tokens,
stream=True
):
token_count += 1
yield f"data: {json.dumps({'type': 'token', 'content': token, 'index': token_count})}\n\n"
# Small delay to prevent overwhelming the client
await asyncio.sleep(0.01)
# Send completion event
yield f"data: {json.dumps({'type': 'complete', 'token_count': token_count})}\n\n"
except Exception as e:
# Send error event
yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n"
finally:
# Send done event
yield f"data: [DONE]\n\n"
return StreamingResponse(
generate_sse(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # Disable proxy buffering
}
)
@router.get("/test")
async def test_stream():
"""
Test SSE endpoint that streams numbers
"""
async def generate():
for i in range(10):
yield f"data: {json.dumps({'number': i, 'timestamp': str(datetime.utcnow())})}\n\n"
await asyncio.sleep(1)
yield "data: [DONE]\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream"
)
Frontend SSE Consumer
// Frontend code to consume SSE stream
class SSEClient {
// Uses fetch + ReadableStream instead of EventSource, since EventSource only supports GET requests
async streamChat(prompt: string, onToken: (token: string) => void) {
const response = await fetch('/api/v1/stream/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ prompt })
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const reader = response.body?.getReader();
const decoder = new TextDecoder();
if (!reader) return;
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
const data = line.slice(6);
if (data === '[DONE]') {
return;
}
try {
const parsed = JSON.parse(data);
if (parsed.type === 'token') {
onToken(parsed.content);
}
} catch (e) {
console.error('Failed to parse SSE data:', e);
}
}
}
}
}
}
// Usage
const client = new SSEClient();
let fullResponse = '';
await client.streamChat('Tell me a story', (token) => {
fullResponse += token;
console.log('Token:', token);
});
Advanced SSE Features
# Advanced SSE implementation with retry and heartbeat
from datetime import datetime
from typing import Optional
from app.config import settings
import uuid
class SSEResponse:
"""Helper class for SSE responses"""
def __init__(self, retry: int = 3000):
self.retry = retry
def format_sse(self, data: str, event: Optional[str] = None, id: Optional[str] = None) -> str:
"""Format data as SSE"""
lines = []
if id:
lines.append(f"id: {id}")
if event:
lines.append(f"event: {event}")
if self.retry:
lines.append(f"retry: {self.retry}")
# Split data by newlines and prefix each line
for line in data.splitlines():
lines.append(f"data: {line}")
return "\n".join(lines) + "\n\n"
@router.post("/chat/advanced")
async def advanced_stream_chat(
request: StreamingChatRequest,
llm_service: Annotated[LLMService, Depends(get_llm_service)]
):
"""Advanced SSE streaming with heartbeat and metadata"""
sse = SSEResponse(retry=5000)
request_id = str(uuid.uuid4())
async def generate():
try:
# Send metadata
metadata = {
"request_id": request_id,
"model": settings.LLM_MODEL,
"timestamp": datetime.utcnow().isoformat()
}
yield sse.format_sse(
json.dumps(metadata),
event="metadata",
id=request_id
)
# Note: heartbeats cannot be pushed from a separate task into this generator;
# interleave them here instead (e.g. via an asyncio.Queue fed by both the
# LLM stream and the heartbeat helper below)
# Stream LLM response
token_buffer = []
async for token in await llm_service.complete(
prompt=request.prompt,
temperature=request.temperature,
max_tokens=request.max_tokens,
stream=True
):
token_buffer.append(token)
# Send tokens in batches for efficiency
if len(token_buffer) >= 5:
yield sse.format_sse(
json.dumps({
"tokens": token_buffer,
"partial": "".join(token_buffer)
}),
event="tokens"
)
token_buffer = []
# Send remaining tokens
if token_buffer:
yield sse.format_sse(
json.dumps({
"tokens": token_buffer,
"partial": "".join(token_buffer)
}),
event="tokens"
)
# Send completion
yield sse.format_sse(
json.dumps({"status": "complete"}),
event="done"
)
except Exception as e:
yield sse.format_sse(
json.dumps({
"error": str(e),
"type": type(e).__name__
}),
event="error"
)
async def heartbeat_generator():
"""Yield periodic heartbeat events; merge its output into the main SSE stream
(e.g. via a shared queue) to keep idle connections alive"""
while True:
await asyncio.sleep(30)
yield sse.format_sse(
json.dumps({"type": "heartbeat"}),
event="heartbeat"
)
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Request-ID": request_id
}
)
WebSocket Implementation
Building real-time bidirectional chat with WebSockets
WebSocket Chat Handler
# app/routers/websocket.py
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from app.services.llm import LLMService
import json
import asyncio
from datetime import datetime
from typing import Dict, Set
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
# Connection manager for handling multiple WebSocket connections
class ConnectionManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, list] = {} # Store chat history
async def connect(self, websocket: WebSocket, client_id: str):
await websocket.accept()
self.active_connections[client_id] = websocket
if client_id not in self.user_sessions:
self.user_sessions[client_id] = []
logger.info(f"Client {client_id} connected")
def disconnect(self, client_id: str):
if client_id in self.active_connections:
del self.active_connections[client_id]
logger.info(f"Client {client_id} disconnected")
async def send_message(self, client_id: str, message: dict):
if client_id in self.active_connections:
await self.active_connections[client_id].send_json(message)
async def broadcast(self, message: dict, exclude: Set[str] = None):
exclude = exclude or set()
for client_id, connection in self.active_connections.items():
if client_id not in exclude:
await connection.send_json(message)
def add_to_history(self, client_id: str, role: str, content: str):
if client_id in self.user_sessions:
self.user_sessions[client_id].append({
"role": role,
"content": content,
"timestamp": datetime.utcnow().isoformat()
})
# Keep only last 50 messages
self.user_sessions[client_id] = self.user_sessions[client_id][-50:]
manager = ConnectionManager()
@router.websocket("/chat/{client_id}")
async def websocket_chat(
websocket: WebSocket,
client_id: str,
llm_service: LLMService = Depends(lambda: LLMService(httpx.AsyncClient()))
):
"""
WebSocket endpoint for real-time chat
"""
await manager.connect(websocket, client_id)
try:
# Send welcome message
await manager.send_message(client_id, {
"type": "system",
"message": "Connected to LLM chat",
"client_id": client_id
})
while True:
# Receive message from client
data = await websocket.receive_json()
message_type = data.get("type", "message")
if message_type == "message":
await handle_chat_message(client_id, data, llm_service)
elif message_type == "typing":
# Broadcast typing indicator
await manager.broadcast({
"type": "typing",
"client_id": client_id
}, exclude={client_id})
elif message_type == "history":
# Send chat history
await manager.send_message(client_id, {
"type": "history",
"messages": manager.user_sessions.get(client_id, [])
})
except WebSocketDisconnect:
manager.disconnect(client_id)
await manager.broadcast({
"type": "user_left",
"client_id": client_id
})
except Exception as e:
logger.error(f"WebSocket error for client {client_id}: {e}")
manager.disconnect(client_id)
async def handle_chat_message(client_id: str, data: dict, llm_service: LLMService):
"""Handle incoming chat message"""
content = data.get("content", "")
# Add to history
manager.add_to_history(client_id, "user", content)
# Echo the user message
await manager.send_message(client_id, {
"type": "message",
"role": "user",
"content": content,
"timestamp": datetime.utcnow().isoformat()
})
# Send typing indicator
await manager.send_message(client_id, {
"type": "assistant_typing",
"status": "start"
})
try:
# Get conversation context
history = manager.user_sessions.get(client_id, [])[-10:] # Last 10 messages
context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history])
# Stream LLM response
full_response = ""
async for token in await llm_service.complete(
prompt=f"{context}\nassistant:",
stream=True
):
full_response += token
# Send each token
await manager.send_message(client_id, {
"type": "token",
"content": token
})
# Small delay to prevent overwhelming
await asyncio.sleep(0.01)
# Add to history
manager.add_to_history(client_id, "assistant", full_response)
# Send completion
await manager.send_message(client_id, {
"type": "assistant_typing",
"status": "complete"
})
except Exception as e:
await manager.send_message(client_id, {
"type": "error",
"message": f"Error generating response: {str(e)}"
})
WebSocket Client Implementation
// TypeScript WebSocket client
interface Message {
type: 'message' | 'token' | 'error' | 'system' | 'typing' | 'history';
content?: string;
role?: 'user' | 'assistant';
timestamp?: string;
messages?: any[];
}
class WebSocketChat {
private ws: WebSocket | null = null;
private clientId: string;
private reconnectAttempts = 0;
private maxReconnectAttempts = 5;
private reconnectDelay = 1000;
constructor(clientId: string) {
this.clientId = clientId;
}
connect(): Promise<void> {
return new Promise((resolve, reject) => {
const wsUrl = `ws://localhost:8000/ws/chat/${this.clientId}`;
this.ws = new WebSocket(wsUrl);
this.ws.onopen = () => {
console.log('WebSocket connected');
this.reconnectAttempts = 0;
resolve();
};
this.ws.onmessage = (event) => {
const message: Message = JSON.parse(event.data);
this.handleMessage(message);
};
this.ws.onerror = (error) => {
console.error('WebSocket error:', error);
reject(error);
};
this.ws.onclose = () => {
console.log('WebSocket disconnected');
this.handleReconnect();
};
});
}
private handleMessage(message: Message) {
switch (message.type) {
case 'message':
console.log(`${message.role}: ${message.content}`);
break;
case 'token':
// Handle streaming token
process.stdout.write(message.content || '');
break;
case 'error':
console.error('Error:', message.content);
break;
case 'system':
console.log('System:', message.content);
break;
case 'typing':
console.log('User is typing...');
break;
case 'history':
console.log('Chat history:', message.messages);
break;
}
}
private async handleReconnect() {
if (this.reconnectAttempts < this.maxReconnectAttempts) {
this.reconnectAttempts++;
console.log(`Reconnecting... Attempt ${this.reconnectAttempts}`);
await new Promise(resolve =>
setTimeout(resolve, this.reconnectDelay * this.reconnectAttempts)
);
try {
await this.connect();
} catch (error) {
console.error('Reconnection failed:', error);
}
}
}
sendMessage(content: string) {
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify({
type: 'message',
content
}));
}
}
sendTypingIndicator() {
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify({
type: 'typing'
}));
}
}
requestHistory() {
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify({
type: 'history'
}));
}
}
disconnect() {
if (this.ws) {
this.ws.close();
}
}
}
// Usage
const chat = new WebSocketChat('user123');
await chat.connect();
// Send a message
chat.sendMessage('Hello, how are you?');
// Request chat history
chat.requestHistory();
Request Validation with Pydantic
Robust input validation and error handling
Advanced Validation Models
# app/models/validation.py
# Note: validator/root_validator, regex=, min_items= and schema_extra are Pydantic v1-style;
# on Pydantic v2 use field_validator/model_validator, pattern=, min_length= and json_schema_extra
from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List, Literal, Union
import re
from datetime import datetime
class ChatMessageStrict(BaseModel):
"""Strict validation for chat messages"""
role: Literal["system", "user", "assistant"]
content: str = Field(..., min_length=1, max_length=10000)
name: Optional[str] = Field(None, regex="^[a-zA-Z0-9_-]+$", max_length=64)
@validator('content')
def validate_content(cls, v):
# Remove potentially harmful content
if re.search(r'<script|javascript:|onerror=', v, re.IGNORECASE):
raise ValueError("Content contains potentially harmful code")
return v.strip()
class AdvancedChatRequest(BaseModel):
"""Advanced chat request with comprehensive validation"""
messages: List[ChatMessageStrict] = Field(..., min_items=1, max_items=50)
model: Optional[str] = Field(None, regex="^[a-zA-Z0-9-_.]+$")
temperature: float = Field(0.7, ge=0, le=2)
max_tokens: int = Field(1000, ge=1, le=4000)
top_p: Optional[float] = Field(None, ge=0, le=1)
frequency_penalty: Optional[float] = Field(None, ge=-2, le=2)
presence_penalty: Optional[float] = Field(None, ge=-2, le=2)
stop: Optional[Union[str, List[str]]] = Field(None)
user_id: Optional[str] = Field(None, regex="^[a-zA-Z0-9-_]+$")
session_id: Optional[str] = Field(None, regex="^[a-zA-Z0-9-_]+$")
@root_validator
def validate_token_limits(cls, values):
messages = values.get('messages', [])
max_tokens = values.get('max_tokens', 1000)
# Estimate prompt tokens (rough calculation)
prompt_tokens = sum(len(msg.content.split()) * 1.3 for msg in messages)
if prompt_tokens + max_tokens > 4000:
raise ValueError(
f"Combined prompt ({int(prompt_tokens)}) and max_tokens ({max_tokens}) "
f"exceeds model limit"
)
return values
@validator('stop')
def validate_stop_sequences(cls, v):
if isinstance(v, list) and len(v) > 4:
raise ValueError("Maximum 4 stop sequences allowed")
return v
class Config:
schema_extra = {
"example": {
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
],
"temperature": 0.7,
"max_tokens": 150
}
}
class ValidationError(BaseModel):
"""Structured validation error response"""
loc: List[Union[str, int]]
msg: str
type: str
ctx: Optional[dict] = None
class ValidationErrorResponse(BaseModel):
"""Response for validation errors"""
detail: List[ValidationError]
error_type: str = "validation_error"
timestamp: datetime = Field(default_factory=datetime.utcnow)
# Custom validators
def validate_api_key(api_key: str) -> str:
"""Validate API key format"""
if not re.match(r'^pk_[a-zA-Z0-9]{32,}$', api_key):
raise ValueError("Invalid API key format")
return api_key
def validate_model_name(model: str) -> str:
"""Validate model name against allowed list"""
allowed_models = {
"gpt-3.5-turbo",
"gpt-4",
"claude-3-sonnet",
"claude-3-opus",
"llama-2-70b"
}
if model not in allowed_models:
raise ValueError(f"Model '{model}' not supported. Allowed: {allowed_models}")
return model
Validation Middleware
# app/middleware/validation.py
from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
import re
import logging
logger = logging.getLogger(__name__)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Custom handler for validation errors"""
# Log validation error
logger.warning(
f"Validation error on {request.method} {request.url.path}: "
f"{exc.errors()}"
)
# Format errors for response
formatted_errors = []
for error in exc.errors():
formatted_errors.append({
"loc": error["loc"],
"msg": error["msg"],
"type": error["type"],
"ctx": error.get("ctx")
})
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={
"detail": formatted_errors,
"error_type": "validation_error",
"path": str(request.url.path),
"method": request.method
}
)
# Input sanitization
from html import escape
import bleach
class InputSanitizer:
"""Sanitize user inputs"""
ALLOWED_TAGS = ['b', 'i', 'u', 'code', 'pre']
ALLOWED_ATTRIBUTES = {}
@staticmethod
def sanitize_text(text: str) -> str:
"""Sanitize text input"""
# Basic HTML escaping
text = escape(text)
# Allow some safe tags
text = bleach.clean(
text,
tags=InputSanitizer.ALLOWED_TAGS,
attributes=InputSanitizer.ALLOWED_ATTRIBUTES,
strip=True
)
# Remove any remaining suspicious patterns
suspicious_patterns = [
r'javascript:',
r'vbscript:',
r'onload=',
r'onerror=',
r'<iframe',
r'<object',
r'<embed'
]
for pattern in suspicious_patterns:
text = re.sub(pattern, '', text, flags=re.IGNORECASE)
return text.strip()
# Request size limiting
async def limit_request_size(request: Request, call_next):
"""Limit request body size"""
MAX_REQUEST_SIZE = 1_000_000 # 1MB
if request.headers.get("content-length"):
content_length = int(request.headers["content-length"])
if content_length > MAX_REQUEST_SIZE:
return JSONResponse(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content={
"detail": f"Request body too large. Maximum size: {MAX_REQUEST_SIZE} bytes"
}
)
response = await call_next(request)
return response
Authentication & Rate Limiting
Securing your API with auth and rate limits
JWT Authentication
# app/services/auth.py
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.config import settings
import redis
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()
class AuthService:
"""Handle authentication and authorization"""
def __init__(self, redis_client: redis.Redis):
self.redis_client = redis_client
def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None):
"""Create JWT access token"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
async def verify_token(self, credentials: HTTPAuthorizationCredentials = Depends(security)):
"""Verify JWT token"""
token = credentials.credentials
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
user_id: str = payload.get("sub")
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
# Check if token is blacklisted
if self.redis_client.get(f"blacklist:{token}"):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has been revoked"
)
return {"user_id": user_id, "token": token}
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
def revoke_token(self, token: str):
"""Add token to blacklist"""
# Token will be blacklisted until it expires
ttl = settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
self.redis_client.setex(f"blacklist:{token}", ttl, "1")
# API Key authentication
class APIKeyAuth:
"""API Key authentication"""
def __init__(self, redis_client: redis.Redis):
self.redis_client = redis_client
async def verify_api_key(self, api_key: HTTPAuthorizationCredentials = Depends(security)):
"""Verify API key"""
key = api_key.credentials
# Check if key exists and is valid
key_data = self.redis_client.hgetall(f"api_key:{key}")
if not key_data:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
)
# Check if key is active
if key_data.get(b"status") != b"active":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="API key is not active"
)
# Update last used timestamp
self.redis_client.hset(f"api_key:{key}", "last_used", datetime.utcnow().isoformat())
return {
"api_key": key,
"user_id": key_data.get(b"user_id", b"").decode(),
"rate_limit": int(key_data.get(b"rate_limit", b"60"))
}
Rate Limiting Implementation
# app/middleware/rate_limit.py
from fastapi import Request, Response, HTTPException, status
from fastapi.responses import JSONResponse
from typing import Callable, Optional
import time
import redis
from datetime import datetime, timedelta
class RateLimiter:
"""Token bucket rate limiter using Redis"""
def __init__(self, redis_client: redis.Redis):
self.redis = redis_client
async def check_rate_limit(
self,
key: str,
max_requests: int,
window_seconds: int
) -> tuple[bool, dict]:
"""
Check if request is within rate limit
Returns (allowed, info)
"""
now = time.time()
window_start = now - window_seconds
pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, window_start)
pipe.zadd(key, {str(now): now})
pipe.zcount(key, window_start, now)
pipe.expire(key, window_seconds + 1)
results = pipe.execute()
request_count = results[2]
info = {
"limit": max_requests,
"remaining": max(0, max_requests - request_count),
"reset": int(now + window_seconds)
}
return request_count <= max_requests, info
class RateLimitMiddleware:
"""Rate limiting middleware"""
def __init__(self, redis_client: redis.Redis):
self.limiter = RateLimiter(redis_client)
# Different rate limits for different endpoints
self.limits = {
"/api/v1/chat/completions": (60, 60), # 60 requests per minute
"/api/v1/stream/chat": (30, 60), # 30 requests per minute
"/ws/chat": (10, 60), # 10 connections per minute
"default": (100, 60) # Default: 100 per minute
}
async def __call__(self, request: Request, call_next: Callable):
# Skip rate limiting for health checks
if request.url.path in ["/health", "/metrics"]:
return await call_next(request)
# Get identifier (API key or IP)
identifier = self.get_identifier(request)
# Get rate limit for endpoint
limit, window = self.limits.get(request.url.path, self.limits["default"])
# Check rate limit
key = f"rate_limit:{identifier}:{request.url.path}"
allowed, info = await self.limiter.check_rate_limit(key, limit, window)
if not allowed:
# Return (rather than raise) so the 429 is produced even though this
# middleware runs outside FastAPI's exception handlers
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Rate limit exceeded"},
headers={
"X-RateLimit-Limit": str(info["limit"]),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(info["reset"]),
"Retry-After": str(info["reset"] - int(time.time()))
}
)
# Process request
response = await call_next(request)
# Copy rate limit headers to response
response.headers["X-RateLimit-Limit"] = str(info["limit"])
response.headers["X-RateLimit-Remaining"] = str(info["remaining"])
response.headers["X-RateLimit-Reset"] = str(info["reset"])
return response
def get_identifier(self, request: Request) -> str:
"""Get identifier for rate limiting"""
# Try to get API key
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
return f"key:{auth_header[7:]}"
# Fall back to IP address
return f"ip:{request.client.host}"
# Usage-based rate limiting
class UsageBasedLimiter:
"""Rate limit based on token usage"""
def __init__(self, redis_client: redis.Redis):
self.redis = redis_client
async def check_token_limit(
self,
user_id: str,
tokens_requested: int,
daily_limit: int = 100000
) -> tuple[bool, dict]:
"""Check if user has tokens remaining"""
today = datetime.utcnow().strftime("%Y-%m-%d")
key = f"usage:{user_id}:{today}"
# Get current usage
current_usage = int(self.redis.get(key) or 0)
if current_usage + tokens_requested > daily_limit:
return False, {
"daily_limit": daily_limit,
"used": current_usage,
"remaining": daily_limit - current_usage,
"requested": tokens_requested
}
# Increment usage
self.redis.incrby(key, tokens_requested)
self.redis.expire(key, 86400) # Expire after 24 hours
return True, {
"daily_limit": daily_limit,
"used": current_usage + tokens_requested,
"remaining": daily_limit - (current_usage + tokens_requested)
}
Always use HTTPS in production and store secrets securely. Consider implementing OAuth2 for more complex authentication scenarios.
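As a starting point for the OAuth2 suggestion, here is a minimal, hedged sketch using FastAPI's built-in OAuth2 password flow; fake_users_db and the placeholder token are illustrative, and issuing or verifying real JWTs would reuse the AuthService shown above.
# Hedged OAuth2 password-flow sketch; fake_users_db and the token placeholder are illustrative
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

router = APIRouter()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/token")
fake_users_db = {"alice": {"username": "alice", "hashed_password": "..."}}

@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    user = fake_users_db.get(form_data.username)
    if not user:  # A real implementation would also verify the password hash
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password")
    # Issue a JWT here, e.g. via AuthService.create_access_token shown above
    return {"access_token": "<jwt-from-AuthService>", "token_type": "bearer"}

@router.get("/me")
async def read_me(token: str = Depends(oauth2_scheme)):
    # Verify the token (e.g. AuthService.verify_token) before trusting it
    return {"token": token}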
Background Tasks for Async Processing
Handling long-running tasks without blocking responses
FastAPI Background Tasks
# app/routers/async_processing.py
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.models import ChatRequest
from app.services.llm import LLMService
from datetime import datetime
from typing import Dict
import httpx
import uuid
router = APIRouter()
# In-memory task storage (use Redis in production)
task_storage: Dict[str, dict] = {}
async def process_llm_task(task_id: str, request: ChatRequest):
"""Background task for processing LLM request"""
try:
# Update task status
task_storage[task_id]["status"] = "processing"
task_storage[task_id]["started_at"] = datetime.utcnow()
# Call the LLM with a context-managed client so the connection is closed
async with httpx.AsyncClient(timeout=30.0) as client:
llm_service = LLMService(client)
result = await llm_service.complete(
prompt=request.messages[-1].content,
model=request.model,
temperature=request.temperature,
max_tokens=request.max_tokens
)
# Update task with result
task_storage[task_id].update({
"status": "completed",
"result": result,
"completed_at": datetime.utcnow()
})
except Exception as e:
task_storage[task_id].update({
"status": "failed",
"error": str(e),
"failed_at": datetime.utcnow()
})
@router.post("/async-chat")
async def create_async_chat(
request: ChatRequest,
background_tasks: BackgroundTasks
):
"""Submit chat request for async processing"""
# Generate task ID
task_id = str(uuid.uuid4())
# Store initial task data
task_storage[task_id] = {
"id": task_id,
"status": "pending",
"created_at": datetime.utcnow(),
"request": request.dict()
}
# Add to background tasks
background_tasks.add_task(process_llm_task, task_id, request)
return {
"task_id": task_id,
"status": "pending",
"check_url": f"/api/v1/tasks/{task_id}"
}
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
"""Check status of async task"""
if task_id not in task_storage:
raise HTTPException(
status_code=404,
detail="Task not found"
)
task = task_storage[task_id]
# Calculate duration if completed
if task["status"] == "completed" and "started_at" in task:
duration = (task["completed_at"] - task["started_at"]).total_seconds()
task["duration_seconds"] = duration
return task
Celery Integration
# app/celery_app.py
from celery import Celery
from app.config import settings
import httpx
from typing import List
# Create Celery app
celery_app = Celery(
"llm_tasks",
broker=settings.REDIS_URL,
backend=settings.REDIS_URL
)
# Configure Celery
celery_app.conf.update(
task_serializer='json',
accept_content=['json'],
result_serializer='json',
timezone='UTC',
enable_utc=True,
result_expires=3600, # Results expire after 1 hour
task_soft_time_limit=300, # 5 minutes soft limit
task_time_limit=600, # 10 minutes hard limit
)
@celery_app.task(bind=True, name="process_llm_request")
def process_llm_request(self, prompt: str, **kwargs):
"""Celery task for processing LLM requests"""
try:
# Update task state
self.update_state(
state='PROCESSING',
meta={'current': 'Calling LLM API'}
)
# Make synchronous call (Celery doesn't support async)
with httpx.Client(timeout=30.0) as client:
response = client.post(
f"{settings.LLM_API_BASE_URL}/chat/completions",
headers={
"Authorization": f"Bearer {settings.LLM_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": kwargs.get("model", settings.LLM_MODEL),
"messages": [{"role": "user", "content": prompt}],
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 1000)
}
)
response.raise_for_status()
result = response.json()
return {
'success': True,
'response': result['choices'][0]['message']['content'],
'usage': result.get('usage', {}),
'model': result.get('model')
}
except Exception as e:
# Log error and re-raise
self.update_state(
state='FAILURE',
meta={'error': str(e)}
)
raise
# Batch processing task
@celery_app.task(name="batch_process_prompts")
def batch_process_prompts(prompts: List[str], **kwargs):
"""Process multiple prompts in batch"""
results = []
for i, prompt in enumerate(prompts):
# Update progress
celery_app.current_task.update_state(
state='PROGRESS',
meta={
'current': i + 1,
'total': len(prompts),
'percentage': ((i + 1) / len(prompts)) * 100
}
)
# Process each prompt. Calling .get() inside a task is discouraged (it can
# deadlock the worker pool); a Celery group/chord is more idiomatic
result = process_llm_request.apply_async(
args=[prompt],
kwargs=kwargs
).get(disable_sync_subtasks=False)
results.append(result)
return results
# FastAPI endpoints for Celery tasks (these belong in a router module that
# imports APIRouter, ChatRequest and the tasks above)
@router.post("/celery/submit")
async def submit_celery_task(request: ChatRequest):
"""Submit task to Celery"""
# Submit to Celery
task = process_llm_request.apply_async(
args=[request.messages[-1].content],
kwargs={
"model": request.model,
"temperature": request.temperature,
"max_tokens": request.max_tokens
}
)
return {
"task_id": task.id,
"status": "submitted",
"check_url": f"/api/v1/celery/status/{task.id}"
}
@router.get("/celery/status/{task_id}")
async def get_celery_task_status(task_id: str):
"""Get Celery task status"""
task = process_llm_request.AsyncResult(task_id)
if task.state == 'PENDING':
response = {
'state': task.state,
'status': 'Task is waiting to be processed'
}
elif task.state == 'PROCESSING':
response = {
'state': task.state,
'current': task.info.get('current', ''),
'status': 'Task is being processed'
}
elif task.state == 'SUCCESS':
response = {
'state': task.state,
'result': task.result,
'status': 'Task completed successfully'
}
else: # FAILURE
response = {
'state': task.state,
'error': str(task.info),
'status': 'Task failed'
}
return response
Error Handling and Retries
Robust error handling for production systems
Comprehensive Error Handling
# app/exceptions.py
from fastapi import Request, status
from fastapi.responses import JSONResponse
from typing import Union
import traceback
import logging
import asyncio
import httpx
from app.config import settings
logger = logging.getLogger(__name__)
class LLMException(Exception):
"""Base exception for LLM-related errors"""
def __init__(self, message: str, status_code: int = 500, details: dict = None):
self.message = message
self.status_code = status_code
self.details = details or {}
super().__init__(self.message)
class RateLimitException(LLMException):
"""Rate limit exceeded"""
def __init__(self, message: str = "Rate limit exceeded", retry_after: int = 60):
super().__init__(message, status_code=429)
self.retry_after = retry_after
class ModelNotFoundException(LLMException):
"""Model not found"""
def __init__(self, model: str):
super().__init__(
f"Model '{model}' not found",
status_code=404,
details={"model": model}
)
class TokenLimitException(LLMException):
"""Token limit exceeded"""
def __init__(self, requested: int, limit: int):
super().__init__(
f"Token limit exceeded. Requested: {requested}, Limit: {limit}",
status_code=400,
details={"requested": requested, "limit": limit}
)
# Exception handlers
async def llm_exception_handler(request: Request, exc: LLMException):
"""Handle LLM-specific exceptions"""
logger.error(
f"LLM error on {request.method} {request.url.path}: "
f"{exc.message} - Details: {exc.details}"
)
response = JSONResponse(
status_code=exc.status_code,
content={
"error": exc.message,
"details": exc.details,
"type": type(exc).__name__,
"path": str(request.url.path)
}
)
if isinstance(exc, RateLimitException):
response.headers["Retry-After"] = str(exc.retry_after)
return response
async def general_exception_handler(request: Request, exc: Exception):
"""Handle unexpected exceptions"""
logger.error(
f"Unexpected error on {request.method} {request.url.path}: "
f"{str(exc)}\n{traceback.format_exc()}"
)
# Don't expose internal errors in production (Settings above has no DEBUG field by default)
if getattr(settings, "DEBUG", False):
content = {
"error": "Internal server error",
"detail": str(exc),
"type": type(exc).__name__,
"traceback": traceback.format_exc().split("\n")
}
else:
content = {
"error": "Internal server error",
"detail": "An unexpected error occurred",
"request_id": request.state.request_id if hasattr(request.state, 'request_id') else None
}
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=content
)
# Retry mechanism
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log
)
def create_retry_decorator(max_attempts: int = 3):
"""Create a retry decorator with exponential backoff"""
return retry(
stop=stop_after_attempt(max_attempts),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((httpx.HTTPError, asyncio.TimeoutError)),
before_sleep=before_sleep_log(logger, logging.WARNING)
)
# Circuit breaker pattern
from typing import Callable
import time
class CircuitBreaker:
"""Circuit breaker for external service calls"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: int = 60,
expected_exception: type = Exception
):
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.expected_exception = expected_exception
self.failure_count = 0
self.last_failure_time = None
self.state = "closed" # closed, open, half-open
async def call(self, func: Callable, *args, **kwargs):
"""Execute function with circuit breaker"""
if self.state == "open":
if time.time() - self.last_failure_time > self.recovery_timeout:
self.state = "half-open"
else:
raise LLMException(
"Circuit breaker is open - service temporarily unavailable",
status_code=503
)
try:
result = await func(*args, **kwargs)
if self.state == "half-open":
self.state = "closed"
self.failure_count = 0
return result
except self.expected_exception as e:
self.failure_count += 1
self.last_failure_time = time.time()
if self.failure_count >= self.failure_threshold:
self.state = "open"
logger.error(f"Circuit breaker opened after {self.failure_count} failures")
raise
# Usage in service
class ResilientLLMService:
"""LLM service with retry and circuit breaker"""
def __init__(self):
self.circuit_breaker = CircuitBreaker(
failure_threshold=5,
recovery_timeout=60,
expected_exception=httpx.HTTPError
)
self.retry_decorator = create_retry_decorator(max_attempts=3)
@create_retry_decorator(max_attempts=3)
async def complete_with_retry(self, prompt: str, **kwargs):
"""Complete with automatic retry"""
async def _make_request():
# Your LLM API call here
pass
return await self.circuit_breaker.call(_make_request)
Production Deployment
Deploying FastAPI with Uvicorn and Gunicorn
Production Server Configuration
# gunicorn.conf.py
import multiprocessing
import os
# Server socket
bind = f"0.0.0.0:{os.environ.get('PORT', '8000')}"
backlog = 2048
# Worker processes
workers = multiprocessing.cpu_count() * 2 + 1
worker_class = "uvicorn.workers.UvicornWorker"
worker_connections = 1000
keepalive = 5
max_requests = 1000
max_requests_jitter = 50
# Timeouts
timeout = 30
graceful_timeout = 30
# Logging
accesslog = "-"
errorlog = "-"
loglevel = "info"
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" %(D)s'
# Process naming
proc_name = 'llm-api'
# Server mechanics
daemon = False
pidfile = None
user = None
group = None
tmp_upload_dir = None
# SSL (if needed)
# keyfile = "/path/to/keyfile"
# certfile = "/path/to/certfile"
# Server hooks
def when_ready(server):
server.log.info("Server is ready. Spawning workers")
def worker_int(worker):
worker.log.info("Worker received INT or QUIT signal")
def pre_fork(server, worker):
server.log.info(f"Worker spawned (pid: {worker.pid})")
def pre_exec(server):
server.log.info("Forked child, re-executing.")
def worker_abort(worker):
worker.log.info("Worker received SIGABRT signal")Docker Configuration
# Dockerfile
FROM python:3.11-slim as builder
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
g++ \
&& rm -rf /var/lib/apt/lists/*
# Set working directory
WORKDIR /app
# Copy requirements
COPY requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r requirements.txt
# Production stage
FROM python:3.11-slim
# Install runtime dependencies
RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd -m -u 1000 appuser
WORKDIR /app
# Copy Python dependencies from builder
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin
# Copy application code
COPY --chown=appuser:appuser . .
# Switch to non-root user
USER appuser
# Environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PORT=8000
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=40s --retries=3 \
CMD curl -f http://localhost:$PORT/health || exit 1
# Expose port
EXPOSE $PORT
# Run with gunicorn
CMD ["gunicorn", "app.main:app", "-c", "gunicorn.conf.py"]Docker Compose for Development
# docker-compose.yml
version: '3.8'
services:
api:
build: .
ports:
- "8000:8000"
environment:
- REDIS_URL=redis://redis:6379
- DATABASE_URL=postgresql://user:pass@postgres:5432/llm_db
- LLM_API_KEY=${LLM_API_KEY}
- SECRET_KEY=${SECRET_KEY}
depends_on:
- redis
- postgres
volumes:
- ./app:/app/app # For development hot reload
command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis_data:/data
postgres:
image: postgres:15-alpine
environment:
POSTGRES_USER: user
POSTGRES_PASSWORD: pass
POSTGRES_DB: llm_db
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
nginx:
image: nginx:alpine
ports:
- "80:80"
- "443:443"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf
- ./ssl:/etc/nginx/ssl
depends_on:
- api
prometheus:
image: prom/prometheus
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
- prometheus_data:/prometheus
grafana:
image: grafana/grafana
ports:
- "3000:3000"
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
volumes:
- grafana_data:/var/lib/grafana
volumes:
redis_data:
postgres_data:
prometheus_data:
grafana_data:
Nginx Configuration
# nginx.conf
events {
worker_connections 1024;
}
http {
upstream fastapi_backend {
least_conn;
server api:8000 max_fails=3 fail_timeout=30s;
keepalive 32;
}
# Rate limiting
limit_req_zone $binary_remote_addr zone=general:10m rate=10r/s;
limit_req_zone $binary_remote_addr zone=api:10m rate=100r/m;
# Connection limiting
limit_conn_zone $binary_remote_addr zone=addr:10m;
server {
listen 80;
server_name api.yourdomain.com;
return 301 https://$server_name$request_uri;
}
server {
listen 443 ssl http2;
server_name api.yourdomain.com;
# SSL configuration
ssl_certificate /etc/nginx/ssl/cert.pem;
ssl_certificate_key /etc/nginx/ssl/key.pem;
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384;
ssl_prefer_server_ciphers off;
# Security headers
add_header X-Frame-Options "SAMEORIGIN" always;
add_header X-Content-Type-Options "nosniff" always;
add_header X-XSS-Protection "1; mode=block" always;
add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;
# Timeouts
proxy_connect_timeout 60s;
proxy_send_timeout 60s;
proxy_read_timeout 60s;
# Buffer sizes
client_body_buffer_size 1M;
client_max_body_size 10M;
# Rate limiting
limit_req zone=general burst=20 nodelay;
limit_conn addr 10;
location / {
proxy_pass http://fastapi_backend;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# Disable buffering for SSE
proxy_buffering off;
proxy_cache off;
# Timeouts for streaming
proxy_read_timeout 3600s;
keepalive_timeout 3600s;
}
location /api/ {
limit_req zone=api burst=5 nodelay;
proxy_pass http://fastapi_backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
}
location /ws/ {
proxy_pass http://fastapi_backend;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
# WebSocket specific
proxy_read_timeout 3600s;
proxy_send_timeout 3600s;
}
location /health {
proxy_pass http://fastapi_backend;
access_log off;
}
}
}
Scaling Strategies
Horizontal and vertical scaling for high traffic
Kubernetes Deployment
# k8s-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: fastapi-llm
labels:
app: fastapi-llm
spec:
replicas: 3
selector:
matchLabels:
app: fastapi-llm
template:
metadata:
labels:
app: fastapi-llm
spec:
containers:
- name: fastapi
image: your-registry/fastapi-llm:latest
ports:
- containerPort: 8000
env:
- name: REDIS_URL
value: redis://redis-service:6379
- name: LLM_API_KEY
valueFrom:
secretKeyRef:
name: llm-secrets
key: api-key
resources:
requests:
memory: "512Mi"
cpu: "500m"
limits:
memory: "1Gi"
cpu: "1000m"
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 30
periodSeconds: 10
readinessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 5
periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
name: fastapi-service
spec:
selector:
app: fastapi-llm
ports:
- port: 80
targetPort: 8000
type: LoadBalancer
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: fastapi-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: fastapi-llm
minReplicas: 3
maxReplicas: 10
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 70
- type: Resource
resource:
name: memory
target:
type: Utilization
averageUtilization: 80
Horizontal Scaling
- Load balancer distribution
- Auto-scaling based on metrics
- Session affinity for WebSockets
- Shared Redis for state (see the pub/sub sketch below)
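For the shared-Redis point above, a hedged sketch of cross-instance WebSocket fan-out using redis-py's asyncio pub/sub; the channel name and the wiring into ConnectionManager are assumptions.
# Hedged sketch: cross-instance broadcast via Redis pub/sub (redis-py asyncio API)
import json
import redis.asyncio as aioredis

redis_client = aioredis.from_url("redis://redis:6379")

async def publish_message(message: dict) -> None:
    # Any replica can publish; every replica's listener receives the message
    await redis_client.publish("chat:broadcast", json.dumps(message))

async def broadcast_listener(manager) -> None:
    # Start once per replica (e.g. from the lifespan handler) and forward
    # published messages to the WebSocket clients connected to this instance
    pubsub = redis_client.pubsub()
    await pubsub.subscribe("chat:broadcast")
    async for item in pubsub.listen():
        if item["type"] == "message":
            await manager.broadcast(json.loads(item["data"]))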
Vertical Scaling
- Increase worker processes
- Optimize memory usage
- Use connection pooling (see the sketch below)
- Enable response caching
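For connection pooling and response caching, a hedged sketch: an httpx client with explicit pool limits (as could be created in the lifespan handler) plus a small Redis cache keyed on a prompt hash; the key format and TTL are illustrative choices.
# Hedged sketch: pooled HTTP client and a simple Redis response cache
import hashlib
import json
import httpx
import redis.asyncio as aioredis

http_client = httpx.AsyncClient(
    timeout=30.0,
    limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
)
cache = aioredis.from_url("redis://redis:6379")

def _cache_key(prompt: str) -> str:
    return "llm_cache:" + hashlib.sha256(prompt.encode()).hexdigest()

async def get_cached_completion(prompt: str):
    # Return a previously stored response for an identical prompt, if any
    hit = await cache.get(_cache_key(prompt))
    return json.loads(hit) if hit else None

async def store_completion(prompt: str, response: dict, ttl: int = 300) -> None:
    await cache.set(_cache_key(prompt), json.dumps(response), ex=ttl)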
Monitoring & Logging
Observability for production FastAPI applications
Prometheus Metrics
# app/monitoring.py
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_client import Counter, Histogram, Gauge, Info
from fastapi import Request
from app.config import settings
import logging
import time
import uuid
# Custom metrics
llm_requests_total = Counter(
'llm_requests_total',
'Total LLM API requests',
['model', 'endpoint', 'status']
)
llm_request_duration = Histogram(
'llm_request_duration_seconds',
'LLM request duration',
['model', 'endpoint'],
buckets=[0.1, 0.5, 1, 2, 5, 10, 30]
)
llm_tokens_used = Counter(
'llm_tokens_used_total',
'Total tokens used',
['model', 'type'] # type: prompt, completion
)
active_websocket_connections = Gauge(
'active_websocket_connections',
'Number of active WebSocket connections'
)
app_info = Info('fastapi_app', 'Application info')
app_info.info({
'version': '1.0.0',
'environment': getattr(settings, "ENVIRONMENT", "development")  # ENVIRONMENT is not defined in Settings above
})
# Instrumentator setup
instrumentator = Instrumentator(
should_group_status_codes=False,
should_ignore_untemplated=True,
should_respect_env_var=True,
should_instrument_requests_inprogress=True,
excluded_handlers=[".*health.*", ".*metrics.*"],
inprogress_name="fastapi_inprogress",
inprogress_labels=True
)
# Custom metric middleware
async def record_llm_metrics(request: Request, call_next):
"""Record custom LLM metrics"""
start_time = time.time()
# Add request ID for tracing
request.state.request_id = str(uuid.uuid4())
response = await call_next(request)
# Record metrics
if request.url.path.startswith("/api/v1/"):
duration = time.time() - start_time
endpoint = request.url.path
model = "unknown" # Extract from request/response
llm_requests_total.labels(
model=model,
endpoint=endpoint,
status=response.status_code
).inc()
llm_request_duration.labels(
model=model,
endpoint=endpoint
).observe(duration)
return response
# Structured logging
import structlog
from pythonjsonlogger import jsonlogger
def setup_logging():
"""Configure structured logging"""
# Configure structlog
structlog.configure(
processors=[
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.processors.JSONRenderer()
],
context_class=dict,
logger_factory=structlog.stdlib.LoggerFactory(),
cache_logger_on_first_use=True,
)
# Configure standard library logging
logHandler = logging.StreamHandler()
formatter = jsonlogger.JsonFormatter(
fmt="%(timestamp)s %(level)s %(name)s %(message)s",
timestamp=True
)
logHandler.setFormatter(formatter)
logging.basicConfig(
level=settings.LOG_LEVEL,
handlers=[logHandler]
)
# Request ID middleware
from asgi_correlation_id import CorrelationIdMiddleware
from asgi_correlation_id.context import correlation_id
app.add_middleware(
CorrelationIdMiddleware,
header_name='X-Request-ID',
generator=lambda: str(uuid.uuid4())
)
# Logging middleware
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""Log all requests with context"""
logger = structlog.get_logger()
# Log request
await logger.ainfo(
"request_started",
request_id=correlation_id.get(),
method=request.method,
path=request.url.path,
client=request.client.host if request.client else None
)
start_time = time.time()
response = await call_next(request)
duration = time.time() - start_time
# Log response
await logger.ainfo(
"request_completed",
request_id=correlation_id.get(),
method=request.method,
path=request.url.path,
status_code=response.status_code,
duration=duration
)
return response
Use centralized logging (ELK stack, Datadog, etc.) and distributed tracing (Jaeger, Zipkin) for production environments.
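For distributed tracing, a minimal, hedged OpenTelemetry sketch (it assumes the opentelemetry-sdk, opentelemetry-exporter-otlp and opentelemetry-instrumentation-fastapi packages and an OTLP-capable collector such as Jaeger at jaeger:4317):
# Hedged sketch: OpenTelemetry tracing for FastAPI via OTLP
from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

def setup_tracing(app) -> None:
    provider = TracerProvider(resource=Resource.create({"service.name": "llm-api"}))
    provider.add_span_processor(
        BatchSpanProcessor(OTLPSpanExporter(endpoint="http://jaeger:4317", insecure=True))
    )
    trace.set_tracer_provider(provider)
    # Every request then produces spans with route, status code and latency
    FastAPIInstrumentor.instrument_app(app)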
FastAPI's async capabilities make it ideal for LLM streaming applications. Always test your streaming implementations across different deployment environments.
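As one way to act on that advice, here is a hedged pytest sketch that exercises the /api/v1/stream/test SSE endpoint in-process through httpx's ASGI transport (it assumes pytest-asyncio; slow but simple, since the demo endpoint sleeps between events):
# Hedged test sketch: stream the SSE test endpoint in-process with httpx
import pytest
import httpx
from app.main import app

@pytest.mark.asyncio
async def test_sse_stream_emits_done_marker():
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        async with client.stream("GET", "/api/v1/stream/test") as response:
            assert response.status_code == 200
            lines = [line async for line in response.aiter_lines()]
    # The stream should end with the SSE done marker
    assert "data: [DONE]" in lines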