Part 11 of 12
Ghostwritten by Claude Opus 4.5 · Curated by Tom Hundley
This article was written by Claude Opus 4.5 and curated for publication by Tom Hundley.
Your demo works perfectly. Now make it survive 10,000 users.
Every RAG demo looks the same. Developer asks a question. System retrieves context. LLM generates a thoughtful response. The audience applauds. The VP of Engineering asks when it can go live.
Then reality hits.
The demo that worked beautifully with 100 documents chokes on 100,000. The 2-second response time becomes 15 seconds under load. The embedding costs that seemed negligible in development trigger budget alerts in production. The clever prompt that impressed stakeholders occasionally leaks customer data.
If you have followed this series through LangChain (Part 2), LlamaIndex (Part 3), and the platform-specific implementations, you have built RAG systems that work. This article teaches you how to make them work at scale.
The following table represents what we have observed across dozens of production RAG deployments:
| Metric | Demo Environment | Production Reality |
|---|---|---|
| Document count | 100-1,000 | 100,000-10,000,000 |
| Queries per day | 10-50 | 10,000-1,000,000 |
| Latency tolerance | "A few seconds is fine" | P95 < 2 seconds required |
| Uptime requirement | "Just restart it" | 99.9% SLA |
| Cost visibility | "We'll optimize later" | CFO reviews monthly |
| Security requirements | "We trust our users" | SOC 2, HIPAA, audit logs |
The techniques that work in development often break at scale. Synchronous embedding generation blocks requests. Uncached queries hammer your vector database. Naive prompts leak context across users. This article addresses each of these challenges.
Before optimizing, define your targets. Here are industry benchmarks for production RAG systems:
Production RAG latency targets:

| Component | P50 | P95 | P99 |
|---|---|---|---|
| Query embedding | 50ms | 100ms | 200ms |
| Vector search | 20ms | 50ms | 100ms |
| Document fetch | 10ms | 30ms | 50ms |
| LLM generation | 1-3s | 5s | 8s |
| Total (no cache) | 1.5-3.5s | 5.2s | 8.4s |
| Total (with cache) | 100-500ms | 1.5s | 3s |

Streaming changes the perception. While total time-to-complete may be 3-5 seconds, streaming the LLM response so users see the first token in 500ms dramatically improves perceived performance.
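If you want to check this against your own stack, time-to-first-token is easy to measure directly. A minimal sketch using the OpenAI streaming API (the model and prompt are placeholders):

import time
from openai import OpenAI

client = OpenAI()

def measure_ttft(prompt: str, model: str = "gpt-4o-mini") -> tuple[float, float]:
    """Return (time_to_first_token, total_time) in seconds for a streamed completion."""
    start = time.perf_counter()
    first_token_at = None
    stream = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content and first_token_at is None:
            first_token_at = time.perf_counter()
    total = time.perf_counter() - start
    ttft = (first_token_at - start) if first_token_at else total
    return ttft, total

ttft, total = measure_ttft("Summarize our vacation policy in two sentences.")
print(f"First token after {ttft:.2f}s, full answer after {total:.2f}s")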
The most common production surprise is cost. Here is what a naive RAG implementation costs at scale:
# Cost calculation for unoptimized RAG
# Assumptions
daily_queries = 50_000
avg_query_tokens = 30 # User query
avg_context_tokens = 4_000 # Retrieved chunks
avg_response_tokens = 500 # LLM output
embedding_dimensions = 1536
# OpenAI pricing (as of late 2025)
embedding_cost_per_1m = 0.02 # text-embedding-3-small
gpt4o_input_per_1m = 2.50
gpt4o_output_per_1m = 10.00
# Daily costs
embedding_cost = (daily_queries * avg_query_tokens / 1_000_000) * embedding_cost_per_1m
# $0.03/day for embeddings
llm_input_cost = (daily_queries * (avg_query_tokens + avg_context_tokens) / 1_000_000) * gpt4o_input_per_1m
# $503.75/day for LLM input
llm_output_cost = (daily_queries * avg_response_tokens / 1_000_000) * gpt4o_output_per_1m
# $250.00/day for LLM output
total_daily = embedding_cost + llm_input_cost + llm_output_cost
# $753.78/day = $22,613/month
# With optimization (60% cache hit, smaller model for simple queries)
optimized_daily = total_daily * 0.35
# $263.82/day = $7,915/month
print(f"Unoptimized monthly cost: ${total_daily * 30:,.2f}")
print(f"Optimized monthly cost: ${optimized_daily * 30:,.2f}")
print(f"Annual savings: ${(total_daily - optimized_daily) * 365:,.2f}")The difference between naive and optimized implementations can be $15,000-50,000 annually for a single RAG system. The techniques in this article will help you achieve those savings.
Caching is the single most impactful optimization for production RAG. A well-designed cache can reduce costs by 50-70% and cut latency by 80%.
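The arithmetic behind that claim: with hit rate h, the expected per-query cost is roughly h times the near-zero cache-lookup cost plus (1 - h) times the full pipeline cost. A quick sketch reusing the daily figures from the cost model above (the hit rates are illustrative):

# Expected per-query cost as a function of cache hit rate, using the numbers above
cost_per_uncached_query = 753.78 / 50_000   # ~$0.0151 per query without caching
cost_per_cached_query = 0.000001            # roughly one embedding call plus a Redis lookup
for hit_rate in (0.4, 0.6, 0.7):
    expected = hit_rate * cost_per_cached_query + (1 - hit_rate) * cost_per_uncached_query
    saved = 1 - expected / cost_per_uncached_query
    print(f"hit rate {hit_rate:.0%}: ~${expected:.4f}/query ({saved:.0%} saved)")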
The simplest cache: if someone asked this exact question before, return the previous answer.
import hashlib
import json
import redis
from datetime import timedelta
from typing import Optional, Any
class ExactMatchCache:
"""Cache for exact query matches."""
def __init__(
self,
redis_url: str = "redis://localhost:6379",
ttl_hours: int = 24,
prefix: str = "rag:exact:"
):
self.redis = redis.from_url(redis_url)
self.ttl = timedelta(hours=ttl_hours)
self.prefix = prefix
def _make_key(self, query: str, context_hash: str) -> str:
"""Generate cache key from query and context."""
# Include context hash to invalidate cache when documents change
content = f"{query.lower().strip()}:{context_hash}"
query_hash = hashlib.sha256(content.encode()).hexdigest()[:16]
return f"{self.prefix}{query_hash}"
def get(
self,
query: str,
context_hash: str
) -> Optional[dict[str, Any]]:
"""Retrieve cached response if available."""
key = self._make_key(query, context_hash)
cached = self.redis.get(key)
if cached:
return json.loads(cached)
return None
def set(
self,
query: str,
context_hash: str,
response: dict[str, Any]
) -> None:
"""Cache a response."""
key = self._make_key(query, context_hash)
self.redis.setex(
key,
self.ttl,
json.dumps(response)
)
def invalidate_all(self) -> int:
"""Invalidate all cached responses. Returns count deleted."""
keys = self.redis.keys(f"{self.prefix}*")
if keys:
return self.redis.delete(*keys)
return 0
# Usage in RAG pipeline
class CachedRAGPipeline:
def __init__(self, rag_chain, cache: ExactMatchCache):
self.rag_chain = rag_chain
self.cache = cache
self.context_version = "v1" # Bump when documents change
async def query(self, question: str) -> dict:
# Check cache first
cached = self.cache.get(question, self.context_version)
if cached:
return {**cached, "cache_hit": True}
# Generate fresh response
response = await self.rag_chain.ainvoke(question)
# Cache for future requests
self.cache.set(
question,
self.context_version,
{"answer": response["answer"], "sources": response["sources"]}
)
        return {**response, "cache_hit": False}
Exact match caching works well for repeated, identically phrased queries: the FAQ-style questions that many users ask word for word.
Limitation: "What's our vacation policy?" and "vacation policy?" are cache misses despite being semantically identical.
Semantic caching uses embeddings to find similar previous queries. If someone asked a sufficiently similar question, return the cached answer.
import hashlib
import json
import redis
import numpy as np
from openai import OpenAI
from typing import Optional
class SemanticCache:
"""Cache that matches semantically similar queries."""
def __init__(
self,
redis_url: str = "redis://localhost:6379",
similarity_threshold: float = 0.95,
ttl_hours: int = 24,
prefix: str = "rag:semantic:"
):
self.redis = redis.from_url(redis_url)
self.openai = OpenAI()
self.threshold = similarity_threshold
self.ttl_seconds = ttl_hours * 3600
self.prefix = prefix
self.index_key = f"{prefix}index"
def _embed(self, text: str) -> list[float]:
"""Generate embedding for text."""
response = self.openai.embeddings.create(
model="text-embedding-3-small",
input=text
)
return response.data[0].embedding
def _cosine_similarity(
self,
vec1: list[float],
vec2: list[float]
) -> float:
"""Calculate cosine similarity between two vectors."""
a = np.array(vec1)
b = np.array(vec2)
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
def get(self, query: str) -> Optional[dict]:
"""Find semantically similar cached response."""
query_embedding = self._embed(query)
# Get all cached entries
cache_keys = self.redis.smembers(self.index_key)
best_match = None
best_similarity = 0.0
for key in cache_keys:
cached_data = self.redis.get(key)
if not cached_data:
continue
entry = json.loads(cached_data)
similarity = self._cosine_similarity(
query_embedding,
entry["embedding"]
)
if similarity > best_similarity and similarity >= self.threshold:
best_similarity = similarity
best_match = entry
if best_match:
return {
"answer": best_match["answer"],
"sources": best_match["sources"],
"similarity": best_similarity,
"original_query": best_match["query"]
}
return None
def set(
self,
query: str,
answer: str,
sources: list[str]
) -> None:
"""Cache a response with its embedding."""
embedding = self._embed(query)
cache_key = f"{self.prefix}{hash(query)}"
entry = {
"query": query,
"answer": answer,
"sources": sources,
"embedding": embedding
}
# Store entry
self.redis.setex(
cache_key,
self.ttl_seconds,
json.dumps(entry)
)
# Add to index
        self.redis.sadd(self.index_key, cache_key)
Performance consideration: Naive semantic caching requires comparing the query embedding against all cached embeddings. For production, use a proper vector index:
# Production semantic cache with Redis Vector Search
from redis.commands.search.field import VectorField, TextField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
class ProductionSemanticCache:
"""Semantic cache using Redis Stack vector search."""
def __init__(
self,
redis_url: str,
index_name: str = "semantic_cache_idx",
similarity_threshold: float = 0.95,
ttl_hours: int = 24
):
self.redis = redis.from_url(redis_url)
self.openai = OpenAI()
self.index_name = index_name
self.threshold = similarity_threshold
self.ttl = ttl_hours * 3600
self._ensure_index()
def _ensure_index(self):
"""Create vector index if it doesn't exist."""
try:
self.redis.ft(self.index_name).info()
except redis.ResponseError:
# Create index
schema = [
TextField("query"),
TextField("answer"),
VectorField(
"embedding",
"HNSW",
{
"TYPE": "FLOAT32",
"DIM": 1536,
"DISTANCE_METRIC": "COSINE"
}
)
]
self.redis.ft(self.index_name).create_index(
schema,
definition=IndexDefinition(
prefix=["cache:"],
index_type=IndexType.HASH
)
)
def get(self, query: str) -> Optional[dict]:
"""Find similar cached response using vector search."""
embedding = self._embed(query)
# KNN query for most similar
q = Query(
f"*=>[KNN 1 @embedding $vec AS score]"
).return_fields(
"query", "answer", "sources", "score"
).dialect(2)
results = self.redis.ft(self.index_name).search(
q,
query_params={"vec": np.array(embedding).astype(np.float32).tobytes()}
)
if results.docs:
doc = results.docs[0]
similarity = 1 - float(doc.score) # Convert distance to similarity
if similarity >= self.threshold:
return {
"answer": doc.answer,
"sources": json.loads(doc.sources),
"similarity": similarity,
"original_query": doc.query
}
        return None
Sometimes you want to cache not just the final answer but the entire response context: retrieved documents, confidence scores, and metadata.
import hashlib
import json
import redis
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Optional
@dataclass
class CachedRAGResponse:
"""Complete cached RAG response."""
query: str
answer: str
sources: list[dict]
retrieved_chunks: list[dict]
confidence_score: float
model_used: str
tokens_used: int
created_at: str
context_version: str
def to_json(self) -> str:
return json.dumps(asdict(self))
@classmethod
def from_json(cls, data: str) -> "CachedRAGResponse":
return cls(**json.loads(data))
class FullResponseCache:
"""Cache complete RAG responses with metadata."""
def __init__(
self,
redis_url: str,
ttl_hours: int = 24
):
self.redis = redis.from_url(redis_url)
self.ttl = ttl_hours * 3600
def store(
self,
query: str,
response: CachedRAGResponse,
context_version: str
) -> None:
"""Store complete response."""
key = self._make_key(query, context_version)
self.redis.setex(key, self.ttl, response.to_json())
# Also store in a secondary index for analytics
analytics_key = f"analytics:{datetime.now().strftime('%Y-%m-%d')}"
self.redis.lpush(analytics_key, json.dumps({
"query": query,
"tokens": response.tokens_used,
"model": response.model_used,
"cached": False,
"timestamp": response.created_at
}))
def retrieve(
self,
query: str,
context_version: str
) -> Optional[CachedRAGResponse]:
"""Retrieve cached response."""
key = self._make_key(query, context_version)
data = self.redis.get(key)
if data:
return CachedRAGResponse.from_json(data)
return None
def _make_key(self, query: str, context_version: str) -> str:
content = f"{query.lower().strip()}:{context_version}"
return f"rag:full:{hashlib.sha256(content.encode()).hexdigest()[:16]}"Production systems often combine caching strategies in a tiered approach:
class TieredRAGCache:
"""Multi-tier caching strategy for production RAG."""
def __init__(
self,
redis_url: str,
semantic_threshold: float = 0.95,
context_version: str = "v1"
):
self.exact_cache = ExactMatchCache(redis_url)
self.semantic_cache = ProductionSemanticCache(
redis_url,
similarity_threshold=semantic_threshold
)
self.full_cache = FullResponseCache(redis_url)
self.context_version = context_version
# Metrics
self.hits = {"exact": 0, "semantic": 0, "miss": 0}
async def get(self, query: str) -> tuple[Optional[dict], str]:
"""
Check caches in order of cost.
Returns (response, cache_tier) where tier is 'exact', 'semantic', or 'miss'.
"""
# Tier 1: Exact match (cheapest - no embedding needed)
exact_hit = self.exact_cache.get(query, self.context_version)
if exact_hit:
self.hits["exact"] += 1
return exact_hit, "exact"
# Tier 2: Semantic match (requires embedding)
semantic_hit = self.semantic_cache.get(query)
if semantic_hit:
self.hits["semantic"] += 1
return semantic_hit, "semantic"
# Tier 3: Cache miss
self.hits["miss"] += 1
return None, "miss"
async def set(
self,
query: str,
response: CachedRAGResponse
) -> None:
"""Store response in all cache tiers."""
# Store in exact cache
self.exact_cache.set(
query,
self.context_version,
{"answer": response.answer, "sources": response.sources}
)
# Store in semantic cache
self.semantic_cache.set(
query,
response.answer,
response.sources
)
# Store full response
self.full_cache.store(query, response, self.context_version)
def get_stats(self) -> dict:
"""Return cache hit statistics."""
total = sum(self.hits.values())
if total == 0:
return {"hit_rate": 0, "breakdown": self.hits}
return {
"hit_rate": (self.hits["exact"] + self.hits["semantic"]) / total,
"exact_rate": self.hits["exact"] / total,
"semantic_rate": self.hits["semantic"] / total,
"breakdown": self.hits
        }
For a production-ready semantic caching solution, consider GPTCache:
from gptcache import cache
from gptcache.adapter import openai as gptcache_openai
from gptcache.embedding import OpenAI as OpenAIEmbedding
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.manager import get_data_manager, CacheBase, VectorBase
def setup_gptcache():
"""Configure GPTCache for RAG responses."""
# Use OpenAI embeddings for similarity
embedding = OpenAIEmbedding()
# Configure storage backends
cache_base = CacheBase("sqlite", sql_url="sqlite:///gptcache.db")
vector_base = VectorBase(
"faiss",
dimension=embedding.dimension,
top_k=1
)
data_manager = get_data_manager(cache_base, vector_base)
# Initialize cache
cache.init(
embedding_func=embedding.to_embeddings,
data_manager=data_manager,
similarity_evaluation=SearchDistanceEvaluation()
)
return cache
# Usage with OpenAI
cache = setup_gptcache()
# This call will be cached semantically
response = gptcache_openai.ChatCompletion.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "Answer based on the context provided."},
{"role": "user", "content": f"Context: {context}\n\nQuestion: {query}"}
]
)
After caching, latency optimization targets each component of the RAG pipeline individually.
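Before tuning any single component, measure where the time actually goes. A minimal sketch of per-stage timing (the pipeline calls in the comments are placeholders):

import time
from contextlib import contextmanager

timings: dict[str, float] = {}

@contextmanager
def timed(stage: str):
    """Record wall-clock time for one pipeline stage, in milliseconds."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[stage] = (time.perf_counter() - start) * 1000

# Hypothetical usage inside a request handler:
# with timed("embed"):
#     vector = embed(query)
# with timed("search"):
#     chunks = vector_db.search(vector, top_k=5)
# with timed("generate"):
#     answer = llm.generate(context, query)
# print(timings)  # e.g. {'embed': 62.1, 'search': 38.4, 'generate': 2140.7}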
Embedding generation is often a bottleneck, especially for long queries or when batch processing is unavailable.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from openai import OpenAI, AsyncOpenAI
from functools import lru_cache
import time
class OptimizedEmbeddingService:
"""Embedding service with latency optimizations."""
def __init__(
self,
model: str = "text-embedding-3-small",
batch_size: int = 100,
max_concurrent: int = 10
):
self.sync_client = OpenAI()
self.async_client = AsyncOpenAI()
self.model = model
self.batch_size = batch_size
self.semaphore = asyncio.Semaphore(max_concurrent)
self.executor = ThreadPoolExecutor(max_workers=4)
# Cache for repeated embeddings
self._cache: dict[str, list[float]] = {}
@lru_cache(maxsize=1000)
def embed_cached(self, text: str) -> tuple[float, ...]:
"""Embed with LRU cache. Returns tuple for hashability."""
response = self.sync_client.embeddings.create(
model=self.model,
input=text
)
return tuple(response.data[0].embedding)
async def embed_async(self, text: str) -> list[float]:
"""Async embedding with rate limiting."""
async with self.semaphore:
response = await self.async_client.embeddings.create(
model=self.model,
input=text
)
return response.data[0].embedding
async def embed_batch_async(
self,
texts: list[str]
) -> list[list[float]]:
"""Batch embed with optimal chunking."""
all_embeddings = []
# Process in batches to stay within API limits
for i in range(0, len(texts), self.batch_size):
batch = texts[i:i + self.batch_size]
async with self.semaphore:
response = await self.async_client.embeddings.create(
model=self.model,
input=batch
)
batch_embeddings = [d.embedding for d in response.data]
all_embeddings.extend(batch_embeddings)
return all_embeddings
def embed_with_timeout(
self,
text: str,
timeout_ms: int = 200
) -> list[float]:
"""Embed with timeout, falling back to cached if available."""
cache_key = text[:100] # Truncate for cache key
if cache_key in self._cache:
return self._cache[cache_key]
try:
future = self.executor.submit(
self.sync_client.embeddings.create,
model=self.model,
input=text
)
response = future.result(timeout=timeout_ms / 1000)
embedding = response.data[0].embedding
# Cache for future use
self._cache[cache_key] = embedding
return embedding
except TimeoutError:
# Return None or raise - depends on your error handling strategy
raise TimeoutError(f"Embedding generation exceeded {timeout_ms}ms")Vector search performance depends heavily on index configuration. Here are optimizations for popular vector stores:
# Pinecone optimization
from pinecone import Pinecone, PodSpec

pc = Pinecone()  # assumes PINECONE_API_KEY is set in the environment
def create_optimized_pinecone_index(
index_name: str,
dimension: int = 1536,
metric: str = "cosine"
):
"""Create a Pinecone index optimized for latency."""
# Use pod-based index for consistent latency
# Serverless is cheaper but can have cold starts
    pc.create_index(
name=index_name,
dimension=dimension,
metric=metric,
        spec=PodSpec(
environment="us-east-1-aws",
pod_type="p1.x1", # Balanced performance
pods=1,
replicas=2, # Replicas for read throughput
shards=1
)
)
def optimized_pinecone_query(
index,
query_vector: list[float],
top_k: int = 5,
namespace: str = "default"
) -> dict:
"""Query Pinecone with latency optimizations."""
return index.query(
vector=query_vector,
top_k=top_k,
namespace=namespace,
include_metadata=True,
include_values=False, # Don't return vectors (saves bandwidth)
)
# pgvector optimization
async def optimize_pgvector_index(conn):
"""Configure pgvector for low-latency search."""
await conn.execute("""
-- Use IVFFlat for faster approximate search
-- lists = sqrt(n) where n is number of vectors
CREATE INDEX IF NOT EXISTS idx_embeddings_ivfflat
ON documents USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);
-- Set probes for query time (higher = more accurate but slower)
-- Start with sqrt(lists) and tune based on recall requirements
SET ivfflat.probes = 10;
-- For very large datasets, consider HNSW instead
-- CREATE INDEX idx_embeddings_hnsw
-- ON documents USING hnsw (embedding vector_cosine_ops)
-- WITH (m = 16, ef_construction = 64);
""")
async def fast_pgvector_query(
pool,
query_embedding: list[float],
top_k: int = 5,
threshold: float = 0.7
) -> list[dict]:
"""Optimized pgvector query with connection pooling."""
async with pool.acquire() as conn:
# Set probes for this session
await conn.execute("SET ivfflat.probes = 10")
rows = await conn.fetch("""
SELECT
id,
content,
metadata,
1 - (embedding <=> $1::vector) as similarity
FROM documents
WHERE 1 - (embedding <=> $1::vector) > $3
ORDER BY embedding <=> $1::vector
LIMIT $2
""", query_embedding, top_k, threshold)
return [dict(row) for row in rows]
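The query above assumes an asyncpg pool with the pgvector codec registered so Python lists can be bound as vector parameters. A minimal sketch (the DSN and pool sizes are placeholders):

import asyncpg
from pgvector.asyncpg import register_vector

async def make_pool() -> asyncpg.Pool:
    # register_vector runs on each new connection so embeddings can be passed directly
    return await asyncpg.create_pool(
        "postgresql://user:pass@localhost:5432/rag",
        min_size=5,
        max_size=20,
        init=register_vector,
    )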
# Qdrant optimization
from qdrant_client import QdrantClient
from qdrant_client.models import (
VectorParams, Distance, OptimizersConfigDiff,
HnswConfigDiff, SearchParams
)
def create_optimized_qdrant_collection(
client: QdrantClient,
collection_name: str,
dimension: int = 1536
):
"""Create Qdrant collection optimized for low latency."""
client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(
size=dimension,
distance=Distance.COSINE,
on_disk=False # Keep in memory for speed
),
hnsw_config=HnswConfigDiff(
m=16, # Number of edges per node
ef_construct=100, # Build-time accuracy
full_scan_threshold=10000, # Use index above this size
max_indexing_threads=0, # Use all available cores
on_disk=False
),
optimizers_config=OptimizersConfigDiff(
indexing_threshold=20000,
memmap_threshold=50000
)
)
def fast_qdrant_search(
client: QdrantClient,
collection_name: str,
query_vector: list[float],
top_k: int = 5
) -> list[dict]:
"""Search Qdrant with latency-optimized parameters."""
results = client.search(
collection_name=collection_name,
query_vector=query_vector,
limit=top_k,
search_params=SearchParams(
hnsw_ef=50, # Query-time accuracy (higher = slower but better)
exact=False # Use approximate search
),
with_payload=True,
with_vectors=False # Don't return vectors
)
return [
{
"id": r.id,
"score": r.score,
"content": r.payload.get("content"),
"metadata": r.payload.get("metadata")
}
for r in results
    ]
The LLM is typically the largest latency contributor. Optimize with streaming, model selection, and parallel processing.
import asyncio
import re
from typing import AsyncIterator

import anthropic
from openai import AsyncOpenAI
class OptimizedLLMService:
"""LLM service optimized for latency."""
def __init__(self):
self.openai = AsyncOpenAI()
self.anthropic = anthropic.AsyncAnthropic()
async def stream_response(
self,
system_prompt: str,
context: str,
query: str,
model: str = "gpt-4o"
) -> AsyncIterator[str]:
"""Stream response for perceived low latency."""
if model.startswith("gpt"):
async for chunk in await self.openai.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
],
stream=True,
max_tokens=1000
):
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
elif model.startswith("claude"):
async with self.anthropic.messages.stream(
model=model,
max_tokens=1000,
system=system_prompt,
messages=[
{"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
]
) as stream:
async for text in stream.text_stream:
yield text
async def generate_with_timeout(
self,
prompt: str,
model: str = "gpt-4o-mini",
timeout_seconds: float = 10.0
) -> str:
"""Generate with hard timeout."""
try:
response = await asyncio.wait_for(
self.openai.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
max_tokens=500
),
timeout=timeout_seconds
)
return response.choices[0].message.content
except asyncio.TimeoutError:
raise TimeoutError(f"LLM generation exceeded {timeout_seconds}s")
class ModelRouter:
"""Route queries to appropriate models based on complexity."""
def __init__(self):
self.llm = OptimizedLLMService()
# Simple patterns that can use smaller models
self.simple_patterns = [
r"^what is",
r"^define",
r"^list",
r"^when did",
r"^who is",
]
def classify_complexity(self, query: str) -> str:
"""Classify query complexity."""
query_lower = query.lower()
# Check for simple patterns
for pattern in self.simple_patterns:
if re.match(pattern, query_lower):
return "simple"
# Check query length and structure
if len(query.split()) < 10 and "?" in query:
return "simple"
# Complex queries: multi-step reasoning, comparisons, analysis
complex_indicators = [
"compare", "analyze", "explain why", "what are the implications",
"how does", "evaluate", "synthesize"
]
if any(ind in query_lower for ind in complex_indicators):
return "complex"
return "medium"
def select_model(self, query: str, context_length: int) -> str:
"""Select optimal model for query."""
complexity = self.classify_complexity(query)
if complexity == "simple":
return "gpt-4o-mini" # Fast and cheap
elif complexity == "medium":
return "gpt-4o" # Balanced
else:
return "gpt-4o" # Full capabilityWhen you need to search multiple sources, do it in parallel:
import asyncio
from dataclasses import dataclass
@dataclass
class RetrievalResult:
source: str
chunks: list[dict]
latency_ms: float
class ParallelRetriever:
"""Retrieve from multiple sources in parallel."""
def __init__(self, sources: dict):
"""
sources: Dict mapping source names to retriever functions
Example: {"pinecone": pinecone_search, "postgres": pg_search}
"""
self.sources = sources
async def retrieve_all(
self,
query_embedding: list[float],
top_k_per_source: int = 5,
timeout_ms: float = 500
) -> list[RetrievalResult]:
"""Retrieve from all sources in parallel with timeout."""
async def search_source(name: str, searcher) -> RetrievalResult:
start = time.time()
try:
chunks = await asyncio.wait_for(
searcher(query_embedding, top_k_per_source),
timeout=timeout_ms / 1000
)
return RetrievalResult(
source=name,
chunks=chunks,
latency_ms=(time.time() - start) * 1000
)
except asyncio.TimeoutError:
return RetrievalResult(
source=name,
chunks=[],
latency_ms=timeout_ms
)
tasks = [
search_source(name, searcher)
for name, searcher in self.sources.items()
]
results = await asyncio.gather(*tasks)
return results
async def retrieve_fastest(
self,
query_embedding: list[float],
top_k: int = 5,
min_sources: int = 1
) -> list[dict]:
"""Return results from the fastest responding sources."""
async def search_with_timing(name, searcher):
start = time.time()
chunks = await searcher(query_embedding, top_k)
return name, chunks, time.time() - start
# Create tasks
tasks = [
asyncio.create_task(search_with_timing(name, searcher))
for name, searcher in self.sources.items()
]
all_chunks = []
completed = 0
# Return as soon as we have enough results
for coro in asyncio.as_completed(tasks):
name, chunks, latency = await coro
all_chunks.extend(chunks)
completed += 1
if completed >= min_sources and len(all_chunks) >= top_k:
# Cancel remaining tasks
for task in tasks:
task.cancel()
break
# Deduplicate and rank
return self._deduplicate_and_rank(all_chunks, top_k)
def _deduplicate_and_rank(
self,
chunks: list[dict],
top_k: int
) -> list[dict]:
"""Remove duplicates and return top results."""
seen_ids = set()
unique_chunks = []
for chunk in chunks:
chunk_id = chunk.get("id") or hash(chunk.get("content", ""))
if chunk_id not in seen_ids:
seen_ids.add(chunk_id)
unique_chunks.append(chunk)
# Sort by score and return top k
unique_chunks.sort(key=lambda x: x.get("score", 0), reverse=True)
        return unique_chunks[:top_k]
Pre-compute expensive operations during off-peak hours:
import asyncio
import threading
import time  # note: importing `time` from datetime would shadow time.sleep below

import schedule
class PrecomputationService:
"""Pre-compute embeddings and popular query results."""
def __init__(self, cache: TieredRAGCache, rag_pipeline):
self.cache = cache
self.rag = rag_pipeline
self.popular_queries: list[str] = []
def track_query(self, query: str):
"""Track query for popularity analysis."""
# In production, use a proper analytics system
self.popular_queries.append(query)
async def precompute_popular_queries(self, top_n: int = 100):
"""Pre-warm cache with popular queries."""
# Analyze query frequency
query_counts = {}
for q in self.popular_queries:
normalized = q.lower().strip()
query_counts[normalized] = query_counts.get(normalized, 0) + 1
# Sort by frequency
popular = sorted(
query_counts.items(),
key=lambda x: x[1],
reverse=True
)[:top_n]
# Pre-compute responses
for query, count in popular:
cached = await self.cache.get(query)
if not cached:
print(f"Pre-computing: {query} (asked {count} times)")
response = await self.rag.generate(query)
await self.cache.set(query, response)
async def precompute_document_embeddings(
self,
documents: list[dict],
batch_size: int = 100
):
"""Pre-compute embeddings for new documents."""
embedding_service = OptimizedEmbeddingService()
texts = [doc["content"] for doc in documents]
embeddings = await embedding_service.embed_batch_async(texts)
# Store embeddings with documents
for doc, embedding in zip(documents, embeddings):
doc["embedding"] = embedding
return documents
def schedule_precomputation(self):
"""Schedule pre-computation during off-peak hours."""
def run_precomputation():
asyncio.run(self.precompute_popular_queries())
# Run at 3 AM daily
schedule.every().day.at("03:00").do(run_precomputation)
# Start scheduler in background thread
def run_scheduler():
while True:
schedule.run_pending()
time.sleep(60)
thread = threading.Thread(target=run_scheduler, daemon=True)
        thread.start()
Cost optimization requires understanding where your money goes and implementing targeted strategies.
Tokens are the primary cost driver. Optimize at every stage:
class TokenOptimizer:
"""Strategies to reduce token usage."""
def __init__(self, tokenizer_model: str = "gpt-4o"):
import tiktoken
self.encoder = tiktoken.encoding_for_model(tokenizer_model)
def count_tokens(self, text: str) -> int:
"""Count tokens in text."""
return len(self.encoder.encode(text))
def truncate_to_tokens(self, text: str, max_tokens: int) -> str:
"""Truncate text to fit within token limit."""
tokens = self.encoder.encode(text)
if len(tokens) <= max_tokens:
return text
truncated_tokens = tokens[:max_tokens]
return self.encoder.decode(truncated_tokens)
def optimize_context(
self,
chunks: list[dict],
max_context_tokens: int = 3000,
include_metadata: bool = False
) -> str:
"""Build optimized context from retrieved chunks."""
context_parts = []
current_tokens = 0
for i, chunk in enumerate(chunks):
# Build chunk text
if include_metadata:
chunk_text = f"[Source {i+1}: {chunk.get('source', 'unknown')}]\n{chunk['content']}\n"
else:
chunk_text = f"{chunk['content']}\n\n"
chunk_tokens = self.count_tokens(chunk_text)
if current_tokens + chunk_tokens > max_context_tokens:
# Try to fit partial chunk
remaining_tokens = max_context_tokens - current_tokens
if remaining_tokens > 100: # Only include if meaningful
truncated = self.truncate_to_tokens(chunk_text, remaining_tokens)
context_parts.append(truncated)
break
context_parts.append(chunk_text)
current_tokens += chunk_tokens
return "".join(context_parts)
def compress_prompt(self, prompt: str) -> str:
"""Apply basic prompt compression techniques."""
# Remove excessive whitespace
import re
prompt = re.sub(r'\n{3,}', '\n\n', prompt)
prompt = re.sub(r' {2,}', ' ', prompt)
# Remove common filler phrases (be careful with this)
fillers = [
"Please note that ",
"It's important to understand that ",
"As you may know, ",
]
for filler in fillers:
prompt = prompt.replace(filler, "")
        return prompt.strip()
Use smaller, cheaper models when possible:
class CostAwareModelRouter:
"""Route queries to cost-appropriate models."""
# Cost per 1M tokens (as of late 2025)
MODEL_COSTS = {
"gpt-4o": {"input": 2.50, "output": 10.00},
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
"gpt-3.5-turbo": {"input": 0.50, "output": 1.50},
"claude-3-5-sonnet": {"input": 3.00, "output": 15.00},
"claude-3-5-haiku": {"input": 0.25, "output": 1.25},
}
def __init__(self, default_model: str = "gpt-4o-mini"):
self.default = default_model
self.query_classifier = ModelRouter()
def select_model(
self,
query: str,
context_length: int,
quality_requirement: str = "standard" # "basic", "standard", "high"
) -> tuple[str, float]:
"""Select model and estimate cost."""
complexity = self.query_classifier.classify_complexity(query)
# Model selection matrix
selection_matrix = {
("simple", "basic"): "gpt-4o-mini",
("simple", "standard"): "gpt-4o-mini",
("simple", "high"): "gpt-4o",
("medium", "basic"): "gpt-4o-mini",
("medium", "standard"): "gpt-4o",
("medium", "high"): "gpt-4o",
("complex", "basic"): "gpt-4o-mini",
("complex", "standard"): "gpt-4o",
("complex", "high"): "gpt-4o",
}
model = selection_matrix.get(
(complexity, quality_requirement),
self.default
)
# Estimate cost
estimated_input = context_length + len(query.split()) * 1.3
estimated_output = 300 # Average response
costs = self.MODEL_COSTS[model]
estimated_cost = (
(estimated_input / 1_000_000) * costs["input"] +
(estimated_output / 1_000_000) * costs["output"]
)
return model, estimated_cost
def estimate_monthly_cost(
self,
daily_queries: int,
avg_context_tokens: int = 3000,
avg_response_tokens: int = 400,
model_distribution: dict = None
) -> dict:
"""Estimate monthly costs with model distribution."""
if model_distribution is None:
model_distribution = {
"gpt-4o-mini": 0.70, # 70% simple queries
"gpt-4o": 0.30, # 30% complex queries
}
monthly_queries = daily_queries * 30
total_cost = 0
breakdown = {}
for model, percentage in model_distribution.items():
queries = monthly_queries * percentage
costs = self.MODEL_COSTS[model]
input_cost = (queries * avg_context_tokens / 1_000_000) * costs["input"]
output_cost = (queries * avg_response_tokens / 1_000_000) * costs["output"]
model_total = input_cost + output_cost
breakdown[model] = {
"queries": queries,
"input_cost": input_cost,
"output_cost": output_cost,
"total": model_total
}
total_cost += model_total
return {
"monthly_total": total_cost,
"cost_per_query": total_cost / monthly_queries,
"breakdown": breakdown
        }
Batch embedding is significantly cheaper and faster per document:
import asyncio
import uuid

from openai import OpenAI

class BatchEmbeddingStrategy:
"""Strategies for batch vs real-time embedding."""
def __init__(self):
self.openai = OpenAI()
self.queue: list[dict] = []
self.queue_lock = asyncio.Lock()
async def add_to_queue(self, document: dict) -> str:
"""Add document to batch queue, return job ID."""
job_id = str(uuid.uuid4())
async with self.queue_lock:
self.queue.append({
"id": job_id,
"content": document["content"],
"metadata": document.get("metadata", {}),
"status": "queued"
})
return job_id
async def process_queue(self, batch_size: int = 100) -> int:
"""Process queued documents in batches."""
processed = 0
async with self.queue_lock:
pending = [d for d in self.queue if d["status"] == "queued"]
for i in range(0, len(pending), batch_size):
batch = pending[i:i + batch_size]
texts = [d["content"] for d in batch]
# Batch API call
response = self.openai.embeddings.create(
model="text-embedding-3-small",
input=texts
)
# Update documents with embeddings
for doc, data in zip(batch, response.data):
doc["embedding"] = data.embedding
doc["status"] = "completed"
processed += 1
return processed
def compare_costs(
self,
document_count: int,
avg_tokens_per_doc: int = 500
) -> dict:
"""Compare batch vs real-time embedding costs."""
total_tokens = document_count * avg_tokens_per_doc
cost_per_1m = 0.02 # text-embedding-3-small
# Base embedding cost is the same
embedding_cost = (total_tokens / 1_000_000) * cost_per_1m
# But real-time has overhead
real_time_overhead = 1.2 # API call overhead, no batching benefits
return {
"batch_cost": embedding_cost,
"realtime_cost": embedding_cost * real_time_overhead,
"recommendation": "batch" if document_count > 10 else "realtime",
"batch_api_calls": (document_count // 100) + 1,
"realtime_api_calls": document_count
}def compare_vector_db_costs(
vector_count: int,
queries_per_month: int,
dimension: int = 1536
) -> dict:
"""Compare vector database costs."""
# Storage in GB (rough estimate)
storage_gb = (vector_count * dimension * 4) / (1024**3) # 4 bytes per float32
comparisons = {
"pinecone_serverless": {
"storage_per_gb": 0.33,
"read_per_million": 2.00,
"write_per_million": 2.00,
"monthly": storage_gb * 0.33 + (queries_per_month / 1_000_000) * 2.00,
"notes": "Best for variable workloads"
},
"pinecone_pod_s1": {
"base_monthly": 70, # s1.x1 pod
"included_vectors": 1_000_000,
"monthly": 70 if vector_count <= 1_000_000 else 70 * ((vector_count // 1_000_000) + 1),
"notes": "Predictable pricing, good for steady workloads"
},
"qdrant_cloud": {
"per_million_vectors": 25,
"monthly": 25 * ((vector_count // 1_000_000) + 1),
"notes": "Simple pricing, good performance"
},
"weaviate_cloud": {
"base_monthly": 25, # Sandbox
"production_monthly": 135, # Standard
"monthly": 135 if vector_count > 100_000 else 25,
"notes": "Good for hybrid search"
},
"supabase_pgvector": {
"base_monthly": 25, # Pro plan
"storage_per_gb": 0.125,
"monthly": 25 + (storage_gb * 0.125),
"notes": "Great if already using Supabase"
},
"self_hosted_pgvector": {
"server_monthly": 50, # 4GB RAM VPS
"monthly": 50,
"notes": "Most control, requires ops expertise"
}
}
# Add rankings
sorted_by_cost = sorted(
comparisons.items(),
key=lambda x: x[1]["monthly"]
)
return {
"vector_count": vector_count,
"storage_gb": storage_gb,
"queries_per_month": queries_per_month,
"comparisons": comparisons,
"cheapest": sorted_by_cost[0][0],
"most_expensive": sorted_by_cost[-1][0]
    }

Track what each query actually costs so you can verify that optimizations work:

class CostTracker:
"""Track and report per-query costs."""
def __init__(self):
self.queries: list[dict] = []
def track(
self,
query_id: str,
embedding_tokens: int,
llm_input_tokens: int,
llm_output_tokens: int,
model: str,
cache_hit: bool,
vector_db_queries: int = 1
):
"""Track a single query's costs."""
# Calculate costs
embedding_cost = (embedding_tokens / 1_000_000) * 0.02 # text-embedding-3-small
model_costs = CostAwareModelRouter.MODEL_COSTS.get(model, {"input": 2.50, "output": 10.00})
llm_input_cost = (llm_input_tokens / 1_000_000) * model_costs["input"]
llm_output_cost = (llm_output_tokens / 1_000_000) * model_costs["output"]
# Vector DB cost (varies by provider)
vector_cost = (vector_db_queries / 1_000_000) * 2.00 # Rough average
total_cost = embedding_cost + llm_input_cost + llm_output_cost + vector_cost
# Reduce cost if cache hit
if cache_hit:
total_cost = embedding_cost * 0.1 # Only minimal lookup cost
self.queries.append({
"id": query_id,
"timestamp": datetime.now().isoformat(),
"embedding_cost": embedding_cost,
"llm_input_cost": llm_input_cost,
"llm_output_cost": llm_output_cost,
"vector_cost": vector_cost,
"total_cost": total_cost,
"cache_hit": cache_hit,
"model": model
})
def get_report(self, period_days: int = 30) -> dict:
"""Generate cost report."""
cutoff = datetime.now() - timedelta(days=period_days)
recent = [q for q in self.queries if datetime.fromisoformat(q["timestamp"]) > cutoff]
if not recent:
return {"error": "No queries in period"}
total_cost = sum(q["total_cost"] for q in recent)
cache_hits = sum(1 for q in recent if q["cache_hit"])
# Cost by model
by_model = {}
for q in recent:
model = q["model"]
if model not in by_model:
by_model[model] = {"count": 0, "cost": 0}
by_model[model]["count"] += 1
by_model[model]["cost"] += q["total_cost"]
return {
"period_days": period_days,
"total_queries": len(recent),
"total_cost": total_cost,
"avg_cost_per_query": total_cost / len(recent),
"cache_hit_rate": cache_hits / len(recent),
"estimated_savings_from_cache": (total_cost / (1 - cache_hits/len(recent))) - total_cost if cache_hits < len(recent) else 0,
"by_model": by_model
        }
As query volume grows, you need patterns that scale horizontally.
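The manifests below assume the retrieval service is a stateless HTTP API exposing a /health endpoint for the readiness and liveness probes. A minimal sketch of that endpoint, assuming FastAPI:

from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
async def health() -> dict:
    # Keep this cheap: probes fire every few seconds, so avoid touching the vector DB here
    return {"status": "ok"}

# Run with: uvicorn app:app --host 0.0.0.0 --port 8000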
# Kubernetes deployment for scaled retrieval service
retrieval_deployment = """
apiVersion: apps/v1
kind: Deployment
metadata:
name: rag-retrieval
spec:
replicas: 3
selector:
matchLabels:
app: rag-retrieval
template:
metadata:
labels:
app: rag-retrieval
spec:
containers:
- name: retrieval
image: your-registry/rag-retrieval:latest
ports:
- containerPort: 8000
resources:
requests:
memory: "2Gi"
cpu: "1000m"
limits:
memory: "4Gi"
cpu: "2000m"
env:
- name: VECTOR_DB_URL
valueFrom:
secretKeyRef:
name: rag-secrets
key: vector-db-url
- name: REDIS_URL
valueFrom:
secretKeyRef:
name: rag-secrets
key: redis-url
readinessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 5
periodSeconds: 10
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 15
periodSeconds: 20
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: rag-retrieval-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: rag-retrieval
minReplicas: 2
maxReplicas: 10
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 70
- type: Resource
resource:
name: memory
target:
type: Utilization
averageUtilization: 80
"""import asyncio
from redis import Redis
from rq import Queue
import json
class QueuedRAGProcessor:
"""Process RAG queries through a job queue."""
def __init__(self, redis_url: str):
self.redis = Redis.from_url(redis_url)
self.high_priority = Queue('rag_high', connection=self.redis)
self.normal_priority = Queue('rag_normal', connection=self.redis)
self.low_priority = Queue('rag_low', connection=self.redis)
def enqueue(
self,
query: str,
user_id: str,
priority: str = "normal",
callback_url: str = None
) -> str:
"""Enqueue a RAG query for processing."""
job_data = {
"query": query,
"user_id": user_id,
"callback_url": callback_url,
"enqueued_at": datetime.now().isoformat()
}
queue_map = {
"high": self.high_priority,
"normal": self.normal_priority,
"low": self.low_priority
}
queue = queue_map.get(priority, self.normal_priority)
job = queue.enqueue(
'workers.rag_worker.process_query',
job_data,
job_timeout=60
)
return job.id
def get_status(self, job_id: str) -> dict:
"""Get job status."""
from rq.job import Job
job = Job.fetch(job_id, connection=self.redis)
return {
"id": job_id,
"status": job.get_status(),
"result": job.result if job.is_finished else None,
"error": str(job.exc_info) if job.is_failed else None,
"enqueued_at": job.enqueued_at.isoformat() if job.enqueued_at else None,
"started_at": job.started_at.isoformat() if job.started_at else None,
"ended_at": job.ended_at.isoformat() if job.ended_at else None
}
# Worker implementation (workers/rag_worker.py)
async def process_query(job_data: dict) -> dict:
"""Process a single RAG query."""
query = job_data["query"]
user_id = job_data["user_id"]
# Initialize RAG pipeline
pipeline = ProductionRAGPipeline()
# Process with caching
result = await pipeline.query(query)
# If callback URL provided, send result
if job_data.get("callback_url"):
async with aiohttp.ClientSession() as session:
await session.post(
job_data["callback_url"],
json={
"query": query,
"result": result,
"user_id": user_id
}
)
    return result

Read replicas let retrieval scale independently of writes:

class ReplicatedVectorStore:
"""Vector store with read replica support."""
def __init__(
self,
primary_url: str,
replica_urls: list[str],
read_preference: str = "replica" # "primary", "replica", "nearest"
):
self.primary = self._connect(primary_url)
self.replicas = [self._connect(url) for url in replica_urls]
self.read_preference = read_preference
self.current_replica_idx = 0
def _connect(self, url: str):
"""Connect to vector store instance."""
# Implementation depends on vector store
pass
def _get_read_connection(self):
"""Get connection for read operations."""
if self.read_preference == "primary":
return self.primary
if self.read_preference == "replica" and self.replicas:
# Round-robin through replicas
conn = self.replicas[self.current_replica_idx]
self.current_replica_idx = (self.current_replica_idx + 1) % len(self.replicas)
return conn
if self.read_preference == "nearest":
# In production, implement latency-based selection
return self.replicas[0] if self.replicas else self.primary
return self.primary
async def search(
self,
query_vector: list[float],
top_k: int = 5
) -> list[dict]:
"""Search using read replica."""
conn = self._get_read_connection()
return await conn.search(query_vector, top_k)
async def upsert(
self,
vectors: list[dict]
):
"""Write to primary only."""
return await self.primary.upsert(vectors)
async def delete(self, ids: list[str]):
"""Delete from primary only."""
return await self.primary.delete(ids)
# pgvector with read replicas
async def setup_pgvector_replication():
"""Example pgvector read replica configuration."""
from asyncpg import create_pool
# Primary for writes
primary_pool = await create_pool(
"postgresql://user:pass@primary:5432/rag",
min_size=5,
max_size=20
)
# Read replicas
replica_pool = await create_pool(
"postgresql://user:pass@replica1:5432/rag",
min_size=10,
max_size=50,
command_timeout=5 # Faster timeout for reads
)
    return primary_pool, replica_pool

# AWS Lambda for serverless RAG scaling
lambda_config = """
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Resources:
RAGFunction:
Type: AWS::Serverless::Function
Properties:
Handler: app.handler
Runtime: python3.11
MemorySize: 1024
Timeout: 30
Environment:
Variables:
VECTOR_DB_URL: !Ref VectorDBURL
OPENAI_API_KEY: !Ref OpenAIKey
Events:
API:
Type: Api
Properties:
Path: /query
Method: post
# Auto-scaling configuration
ProvisionedConcurrencyConfig:
ProvisionedConcurrentExecutions: 5
AutoPublishAlias: live
DeploymentPreference:
Type: AllAtOnce
RAGFunctionScaling:
Type: AWS::ApplicationAutoScaling::ScalableTarget
Properties:
MaxCapacity: 100
MinCapacity: 5
ResourceId: !Sub function:${RAGFunction}:live
RoleARN: !GetAtt AutoScalingRole.Arn
ScalableDimension: lambda:function:ProvisionedConcurrency
ServiceNamespace: lambda
RAGScalingPolicy:
Type: AWS::ApplicationAutoScaling::ScalingPolicy
Properties:
PolicyName: RAGUtilizationPolicy
PolicyType: TargetTrackingScaling
ScalingTargetId: !Ref RAGFunctionScaling
TargetTrackingScalingPolicyConfiguration:
PredefinedMetricSpecification:
PredefinedMetricType: LambdaProvisionedConcurrencyUtilization
TargetValue: 70
"""
# Kubernetes KEDA scaler for queue-based scaling
keda_config = """
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: rag-queue-scaler
spec:
scaleTargetRef:
name: rag-worker
minReplicaCount: 1
maxReplicaCount: 20
triggers:
- type: redis
metadata:
address: redis:6379
listName: rag_normal
listLength: "10"
- type: redis
metadata:
address: redis:6379
listName: rag_high
listLength: "5"
"""Production RAG systems handle sensitive data. Security cannot be an afterthought.
import re
from typing import Optional
import html
class InputSanitizer:
"""Sanitize user inputs before processing."""
# Patterns that might indicate injection attempts
INJECTION_PATTERNS = [
r"ignore previous instructions",
r"disregard (all |any )?prior",
r"forget (everything|what)",
r"you are now",
r"new persona",
r"pretend (you are|to be)",
r"act as if",
r"system:\s*",
r"\[INST\]",
r"<\|im_start\|>",
]
def __init__(self, max_length: int = 2000):
self.max_length = max_length
self.injection_regex = re.compile(
"|".join(self.INJECTION_PATTERNS),
re.IGNORECASE
)
def sanitize(self, query: str) -> tuple[str, list[str]]:
"""
Sanitize user query.
Returns (sanitized_query, list_of_warnings).
"""
warnings = []
# Length check
if len(query) > self.max_length:
query = query[:self.max_length]
warnings.append(f"Query truncated to {self.max_length} characters")
# HTML escape
query = html.escape(query)
# Check for injection patterns
if self.injection_regex.search(query):
warnings.append("Potential prompt injection detected")
# Don't block, but log and flag for review
# Remove control characters
query = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', query)
# Normalize whitespace
query = ' '.join(query.split())
return query, warnings
def sanitize_context(self, context: str) -> str:
"""Sanitize retrieved context before including in prompt."""
# Remove potential instruction overrides from retrieved content
# This is important if your documents might contain adversarial content
cleaned = context
# Remove markdown code blocks that might contain instructions
cleaned = re.sub(r'```[^`]*```', '[code block removed]', cleaned)
# Remove URLs that might be used for exfiltration
cleaned = re.sub(
r'https?://[^\s]+',
'[URL removed]',
cleaned
)
        return cleaned

Filter model output before it reaches users:

class OutputFilter:
"""Filter LLM outputs before returning to users."""
def __init__(self):
# Patterns that should never appear in output
self.forbidden_patterns = [
r"api[_-]?key\s*[:=]\s*\S+",
r"password\s*[:=]\s*\S+",
r"secret\s*[:=]\s*\S+",
r"Bearer\s+[A-Za-z0-9\-_]+",
r"sk-[A-Za-z0-9]+", # OpenAI keys
]
self.forbidden_regex = re.compile(
"|".join(self.forbidden_patterns),
re.IGNORECASE
)
def filter(self, response: str) -> tuple[str, bool]:
"""
Filter response for sensitive data.
Returns (filtered_response, was_modified).
"""
was_modified = False
# Check for forbidden patterns
if self.forbidden_regex.search(response):
response = self.forbidden_regex.sub('[REDACTED]', response)
was_modified = True
# Check for potential data leakage
response, pii_modified = self._redact_pii(response)
was_modified = was_modified or pii_modified
return response, was_modified
def _redact_pii(self, text: str) -> tuple[str, bool]:
"""Redact common PII patterns."""
modified = False
patterns = {
# SSN
r'\b\d{3}-\d{2}-\d{4}\b': '[SSN REDACTED]',
# Credit card (basic pattern)
r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b': '[CARD REDACTED]',
# Email (only if looks like internal/sensitive)
r'\b[A-Za-z0-9._%+-]+@(company|internal|corp)\.[A-Za-z]{2,}\b': '[EMAIL REDACTED]',
}
for pattern, replacement in patterns.items():
if re.search(pattern, text):
text = re.sub(pattern, replacement, text)
modified = True
        return text, modified

For more thorough PII detection and masking, use a dedicated library such as Microsoft Presidio:

from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine
from presidio_anonymizer.entities import OperatorConfig
class PIIProtector:
"""Detect and mask PII using Microsoft Presidio."""
def __init__(self):
self.analyzer = AnalyzerEngine()
self.anonymizer = AnonymizerEngine()
# Entities to detect
self.entities = [
"PERSON",
"EMAIL_ADDRESS",
"PHONE_NUMBER",
"CREDIT_CARD",
"US_SSN",
"US_PASSPORT",
"LOCATION",
"DATE_TIME",
"NRP", # Nationality, religious, political group
"MEDICAL_LICENSE",
"US_BANK_NUMBER",
]
def analyze(self, text: str) -> list[dict]:
"""Analyze text for PII."""
results = self.analyzer.analyze(
text=text,
entities=self.entities,
language='en'
)
return [
{
"type": r.entity_type,
"start": r.start,
"end": r.end,
"score": r.score,
"text": text[r.start:r.end]
}
for r in results
]
def anonymize(
self,
text: str,
strategy: str = "replace" # "replace", "hash", "mask"
) -> str:
"""Anonymize PII in text."""
# Analyze first
analyzer_results = self.analyzer.analyze(
text=text,
entities=self.entities,
language='en'
)
if not analyzer_results:
return text
# Configure anonymization operators
if strategy == "replace":
operators = {
"DEFAULT": OperatorConfig("replace", {"new_value": "<REDACTED>"}),
"PERSON": OperatorConfig("replace", {"new_value": "<PERSON>"}),
"EMAIL_ADDRESS": OperatorConfig("replace", {"new_value": "<EMAIL>"}),
"PHONE_NUMBER": OperatorConfig("replace", {"new_value": "<PHONE>"}),
}
elif strategy == "hash":
operators = {
"DEFAULT": OperatorConfig("hash", {"hash_type": "sha256"})
}
elif strategy == "mask":
operators = {
"DEFAULT": OperatorConfig("mask", {"chars_to_mask": 4, "masking_char": "*"})
}
else:
operators = {"DEFAULT": OperatorConfig("replace", {"new_value": "<REDACTED>"})}
# Anonymize
anonymized = self.anonymizer.anonymize(
text=text,
analyzer_results=analyzer_results,
operators=operators
)
return anonymized.text
def check_query(self, query: str) -> dict:
"""Check if query contains PII that shouldn't be searched."""
findings = self.analyze(query)
# High-risk PII that should block search
high_risk = ["US_SSN", "CREDIT_CARD", "US_PASSPORT", "US_BANK_NUMBER"]
high_risk_found = [f for f in findings if f["type"] in high_risk]
return {
"contains_pii": len(findings) > 0,
"high_risk": len(high_risk_found) > 0,
"findings": findings,
"recommendation": "block" if high_risk_found else "allow"
        }

Data residency requirements mean queries, embeddings, and documents must stay in the user's region:

from enum import Enum
from dataclasses import dataclass
class Region(Enum):
US = "us"
EU = "eu"
APAC = "apac"
UK = "uk"
@dataclass
class DataResidencyConfig:
"""Configuration for data residency compliance."""
region: Region
vector_db_endpoint: str
llm_endpoint: str
embedding_endpoint: str
cache_endpoint: str
# Compliance requirements
allow_cross_region_fallback: bool = False
require_encryption_at_rest: bool = True
require_encryption_in_transit: bool = True
audit_all_queries: bool = True
class RegionAwareRAG:
"""RAG system with data residency enforcement."""
REGION_CONFIGS = {
Region.US: DataResidencyConfig(
region=Region.US,
vector_db_endpoint="https://us.vectordb.example.com",
llm_endpoint="https://api.openai.com/v1",
embedding_endpoint="https://api.openai.com/v1",
cache_endpoint="redis://us-cache.example.com",
),
Region.EU: DataResidencyConfig(
region=Region.EU,
vector_db_endpoint="https://eu.vectordb.example.com",
llm_endpoint="https://eu.api.openai.com/v1", # If available
embedding_endpoint="https://eu.api.openai.com/v1",
cache_endpoint="redis://eu-cache.example.com",
# GDPR requirements
require_encryption_at_rest=True,
require_encryption_in_transit=True,
audit_all_queries=True,
),
}
def __init__(self, region: Region):
self.config = self.REGION_CONFIGS[region]
self._validate_config()
def _validate_config(self):
"""Validate region configuration meets requirements."""
if self.config.require_encryption_in_transit:
for endpoint in [
self.config.vector_db_endpoint,
self.config.llm_endpoint,
self.config.embedding_endpoint
]:
if not endpoint.startswith("https"):
raise ValueError(f"Endpoint {endpoint} must use HTTPS")
async def query(
self,
query: str,
user_region: Region
) -> dict:
"""Process query with region enforcement."""
# Verify user region matches data region
if user_region != self.config.region:
if not self.config.allow_cross_region_fallback:
raise ValueError(
f"Cross-region queries not allowed. "
f"User region: {user_region}, Data region: {self.config.region}"
)
# Process query using region-specific endpoints
# Implementation here...
return {"region": self.config.region.value, "result": "..."}import json
import hashlib
from datetime import datetime
from enum import Enum
class AuditEventType(Enum):
QUERY = "query"
RETRIEVAL = "retrieval"
GENERATION = "generation"
CACHE_HIT = "cache_hit"
ACCESS_DENIED = "access_denied"
PII_DETECTED = "pii_detected"
ERROR = "error"
@dataclass
class AuditEvent:
"""Audit log event."""
event_id: str
event_type: AuditEventType
timestamp: str
user_id: str
session_id: str
query_hash: str # Don't log actual query for privacy
metadata: dict
def to_json(self) -> str:
return json.dumps({
"event_id": self.event_id,
"event_type": self.event_type.value,
"timestamp": self.timestamp,
"user_id": self.user_id,
"session_id": self.session_id,
"query_hash": self.query_hash,
"metadata": self.metadata
})
class AuditLogger:
"""Production audit logging for RAG systems."""
def __init__(
self,
log_destination: str = "cloudwatch", # or "elasticsearch", "file"
retention_days: int = 90
):
self.destination = log_destination
self.retention = retention_days
def _hash_query(self, query: str) -> str:
"""Hash query for privacy-preserving logging."""
return hashlib.sha256(query.encode()).hexdigest()[:16]
def log(
self,
event_type: AuditEventType,
user_id: str,
session_id: str,
query: str,
**metadata
):
"""Log an audit event."""
event = AuditEvent(
event_id=str(uuid.uuid4()),
event_type=event_type,
timestamp=datetime.utcnow().isoformat(),
user_id=user_id,
session_id=session_id,
query_hash=self._hash_query(query),
metadata=metadata
)
self._write(event)
def _write(self, event: AuditEvent):
"""Write event to destination."""
if self.destination == "cloudwatch":
self._write_cloudwatch(event)
elif self.destination == "elasticsearch":
self._write_elasticsearch(event)
elif self.destination == "file":
self._write_file(event)
def _write_cloudwatch(self, event: AuditEvent):
"""Write to CloudWatch Logs."""
import boto3
client = boto3.client('logs')
client.put_log_events(
logGroupName='/rag/audit',
logStreamName=datetime.utcnow().strftime('%Y/%m/%d'),
logEvents=[{
'timestamp': int(datetime.utcnow().timestamp() * 1000),
'message': event.to_json()
}]
)
def query_logs(
self,
user_id: str = None,
start_time: datetime = None,
end_time: datetime = None,
event_types: list[AuditEventType] = None
) -> list[AuditEvent]:
"""Query audit logs for compliance reporting."""
# Implementation depends on destination
        pass

Role-based access control restricts which documents each user's queries can reach:

from dataclasses import dataclass
from enum import Enum
from functools import wraps

from fastapi import HTTPException  # assuming a FastAPI app, per request.state usage below
class Permission(Enum):
READ = "read"
WRITE = "write"
ADMIN = "admin"
SEARCH_ALL = "search_all"
SEARCH_RESTRICTED = "search_restricted"
@dataclass
class User:
id: str
role: str
permissions: list[Permission]
allowed_namespaces: list[str]
department: str
class AccessController:
"""Control access to RAG resources."""
def __init__(self):
self.role_permissions = {
"admin": [Permission.READ, Permission.WRITE, Permission.ADMIN, Permission.SEARCH_ALL],
"analyst": [Permission.READ, Permission.SEARCH_ALL],
"user": [Permission.READ, Permission.SEARCH_RESTRICTED],
"guest": [Permission.SEARCH_RESTRICTED],
}
def check_permission(
self,
user: User,
required: Permission
) -> bool:
"""Check if user has required permission."""
user_permissions = self.role_permissions.get(user.role, [])
return required in user_permissions or required in user.permissions
def filter_namespaces(
self,
user: User,
requested_namespaces: list[str]
) -> list[str]:
"""Filter namespaces to those user can access."""
if self.check_permission(user, Permission.SEARCH_ALL):
return requested_namespaces
return [ns for ns in requested_namespaces if ns in user.allowed_namespaces]
def filter_results(
self,
user: User,
results: list[dict]
) -> list[dict]:
"""Filter search results based on user access."""
if self.check_permission(user, Permission.SEARCH_ALL):
return results
filtered = []
for result in results:
# Check document-level access
doc_namespace = result.get("metadata", {}).get("namespace", "public")
doc_department = result.get("metadata", {}).get("department")
# User can see public docs, their department's docs, or explicitly allowed namespaces
if (doc_namespace == "public" or
doc_namespace in user.allowed_namespaces or
doc_department == user.department):
filtered.append(result)
return filtered
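# Illustrative usage (the variable names below are assumptions for this sketch):
# run retrieved chunks through the access controller before they reach the
# prompt, so restricted documents cannot leak into the generated answer.
#
#   controller = AccessController()
#   namespaces = controller.filter_namespaces(current_user, ["public", "finance", "hr"])
#   visible_chunks = controller.filter_results(current_user, raw_results)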
def require_permission(permission: Permission):
"""Decorator for permission-protected endpoints."""
def decorator(func):
@wraps(func)
async def wrapper(request, *args, **kwargs):
user = request.state.user
access = AccessController()
if not access.check_permission(user, permission):
raise HTTPException(
status_code=403,
detail=f"Permission denied: {permission.value} required"
)
return await func(request, *args, **kwargs)
return wrapper
return decorator

Production systems must handle failures gracefully. Build reliability into your RAG architecture from the start.
class FallbackRAGPipeline:
"""RAG pipeline with multiple fallback levels."""
def __init__(self):
self.primary_vector_db = PineconeClient()
self.fallback_vector_db = LocalFAISS()
self.primary_llm = "gpt-4o"
self.fallback_llm = "gpt-4o-mini"
self.emergency_llm = "gpt-3.5-turbo"
self.cache = TieredRAGCache()
async def query(
self,
question: str,
context_version: str = "v1"
) -> dict:
"""Query with automatic fallbacks."""
# Level 0: Check cache first
cached, cache_tier = await self.cache.get(question)
if cached:
return {**cached, "source": f"cache_{cache_tier}"}
# Level 1: Try primary pipeline
try:
return await self._query_primary(question)
except Exception as e:
logger.warning(f"Primary pipeline failed: {e}")
# Level 2: Try fallback vector DB with primary LLM
try:
return await self._query_fallback_retrieval(question)
except Exception as e:
logger.warning(f"Fallback retrieval failed: {e}")
# Level 3: Try fallback LLM with primary retrieval
try:
return await self._query_fallback_llm(question)
except Exception as e:
logger.warning(f"Fallback LLM failed: {e}")
# Level 4: Emergency mode - basic response
return await self._emergency_response(question)
async def _query_primary(self, question: str) -> dict:
"""Primary pipeline: full capability."""
embedding = await self._embed(question)
chunks = await self.primary_vector_db.search(embedding, top_k=5)
context = self._build_context(chunks)
response = await self._generate(context, question, self.primary_llm)
return {
"answer": response,
"sources": [c["source"] for c in chunks],
"source": "primary"
}
async def _query_fallback_retrieval(self, question: str) -> dict:
"""Fallback retrieval with local vector DB."""
embedding = await self._embed(question)
chunks = await self.fallback_vector_db.search(embedding, top_k=5)
context = self._build_context(chunks)
response = await self._generate(context, question, self.primary_llm)
return {
"answer": response,
"sources": [c["source"] for c in chunks],
"source": "fallback_retrieval"
}
async def _query_fallback_llm(self, question: str) -> dict:
"""Fallback LLM for when primary is unavailable."""
embedding = await self._embed(question)
chunks = await self.primary_vector_db.search(embedding, top_k=3) # Fewer chunks
context = self._build_context(chunks)
response = await self._generate(context, question, self.fallback_llm)
return {
"answer": response,
"sources": [c["source"] for c in chunks],
"source": "fallback_llm",
"warning": "Using fallback model - response quality may vary"
}
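    # _embed() and _generate() are elided in this article. A minimal
    # _build_context() might look like the sketch below; the "text" chunk key
    # is an assumption (only "source" appears in the returns above).
    def _build_context(self, chunks: list[dict]) -> str:
        """Concatenate retrieved chunks into a single prompt context (sketch)."""
        return "\n\n".join(
            f"[{c.get('source', 'unknown')}]\n{c.get('text', '')}" for c in chunks
        )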
async def _emergency_response(self, question: str) -> dict:
"""Emergency response when all else fails."""
return {
"answer": (
"I apologize, but I'm currently experiencing technical difficulties "
"and cannot provide a complete answer. Please try again in a few minutes, "
"or contact support if the issue persists."
),
"sources": [],
"source": "emergency",
"error": True
}

from datetime import datetime, timedelta
from enum import Enum
import asyncio
class CircuitState(Enum):
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, rejecting requests
HALF_OPEN = "half_open" # Testing if service recovered
class CircuitBreaker:
"""Circuit breaker for external service calls."""
def __init__(
self,
name: str,
failure_threshold: int = 5,
reset_timeout: int = 30,
half_open_max_calls: int = 3
):
self.name = name
self.failure_threshold = failure_threshold
self.reset_timeout = reset_timeout
self.half_open_max_calls = half_open_max_calls
self.state = CircuitState.CLOSED
self.failures = 0
self.successes = 0
self.last_failure_time: datetime | None = None
self.half_open_calls = 0
self._lock = asyncio.Lock()
async def call(self, func, *args, **kwargs):
"""Execute function with circuit breaker protection."""
async with self._lock:
if self.state == CircuitState.OPEN:
if self._should_attempt_reset():
self.state = CircuitState.HALF_OPEN
self.half_open_calls = 0
else:
raise CircuitBreakerOpen(f"Circuit {self.name} is open")
if self.state == CircuitState.HALF_OPEN:
if self.half_open_calls >= self.half_open_max_calls:
raise CircuitBreakerOpen(f"Circuit {self.name} half-open limit reached")
self.half_open_calls += 1
try:
result = await func(*args, **kwargs)
await self._record_success()
return result
except Exception as e:
await self._record_failure()
raise
def _should_attempt_reset(self) -> bool:
"""Check if enough time has passed to attempt reset."""
if self.last_failure_time is None:
return True
return datetime.now() - self.last_failure_time > timedelta(seconds=self.reset_timeout)
async def _record_success(self):
"""Record successful call."""
async with self._lock:
if self.state == CircuitState.HALF_OPEN:
self.successes += 1
if self.successes >= self.half_open_max_calls:
self.state = CircuitState.CLOSED
self.failures = 0
self.successes = 0
elif self.state == CircuitState.CLOSED:
self.failures = 0
async def _record_failure(self):
"""Record failed call."""
async with self._lock:
self.failures += 1
self.last_failure_time = datetime.now()
if self.state == CircuitState.HALF_OPEN:
self.state = CircuitState.OPEN
elif self.state == CircuitState.CLOSED and self.failures >= self.failure_threshold:
self.state = CircuitState.OPEN
class CircuitBreakerOpen(Exception):
"""Raised when circuit breaker is open."""
pass
# Usage with RAG components
class ResilientRAG:
"""RAG with circuit breakers for all external services."""
def __init__(self):
self.embedding_circuit = CircuitBreaker("embedding", failure_threshold=3)
self.vector_db_circuit = CircuitBreaker("vector_db", failure_threshold=5)
self.llm_circuit = CircuitBreaker("llm", failure_threshold=3, reset_timeout=60)
async def query(self, question: str) -> dict:
"""Query with circuit breaker protection."""
# Embedding with circuit breaker
try:
embedding = await self.embedding_circuit.call(
self._embed,
question
)
except CircuitBreakerOpen:
# Fall back to keyword search
return await self._keyword_search_fallback(question)
# Vector search with circuit breaker
try:
chunks = await self.vector_db_circuit.call(
self._vector_search,
embedding
)
except CircuitBreakerOpen:
return await self._no_context_response(question)
# LLM generation with circuit breaker
try:
response = await self.llm_circuit.call(
self._generate,
chunks,
question
)
except CircuitBreakerOpen:
# Return raw chunks without generation
return {
"answer": "Unable to generate response. Here are relevant documents:",
"raw_chunks": chunks,
"error": "llm_unavailable"
}
return response

class DegradationLevel(Enum):
FULL = "full" # All features available
REDUCED = "reduced" # Core features only
MINIMAL = "minimal" # Basic functionality
MAINTENANCE = "maintenance" # Read-only, cached responses
class GracefulDegradation:
"""Manage service degradation levels."""
def __init__(self):
self.current_level = DegradationLevel.FULL
self.level_features = {
DegradationLevel.FULL: {
"semantic_search": True,
"llm_generation": True,
"streaming": True,
"caching": True,
"analytics": True,
},
DegradationLevel.REDUCED: {
"semantic_search": True,
"llm_generation": True,
"streaming": False, # Disabled to reduce load
"caching": True,
"analytics": False,
},
DegradationLevel.MINIMAL: {
"semantic_search": True,
"llm_generation": False, # Return raw chunks only
"streaming": False,
"caching": True,
"analytics": False,
},
DegradationLevel.MAINTENANCE: {
"semantic_search": False, # Cache only
"llm_generation": False,
"streaming": False,
"caching": True,
"analytics": False,
},
}
def set_level(self, level: DegradationLevel, reason: str):
"""Set degradation level."""
self.current_level = level
logger.warning(f"Degradation level set to {level.value}: {reason}")
def is_feature_available(self, feature: str) -> bool:
"""Check if feature is available at current level."""
return self.level_features[self.current_level].get(feature, False)
async def query(
self,
question: str,
rag_pipeline,
cache: TieredRAGCache
) -> dict:
"""Execute query at appropriate degradation level."""
features = self.level_features[self.current_level]
# Always check cache first
if features["caching"]:
cached, tier = await cache.get(question)
if cached:
return {**cached, "degradation_level": self.current_level.value}
# Maintenance mode: cache only
if self.current_level == DegradationLevel.MAINTENANCE:
return {
"answer": "Service is in maintenance mode. Only cached responses available.",
"degradation_level": "maintenance",
"cached": False
}
# Minimal mode: retrieval only
if not features["llm_generation"]:
chunks = await rag_pipeline.retrieve(question)
return {
"answer": "Here are relevant documents:",
"chunks": chunks,
"degradation_level": "minimal",
"note": "LLM generation temporarily unavailable"
}
# Reduced/Full mode: normal operation
response = await rag_pipeline.query(
question,
stream=features["streaming"]
)
return {
**response,
"degradation_level": self.current_level.value
}

import asyncio
import random
from functools import wraps

from openai import RateLimitError  # assumes the OpenAI SDK already used in this pipeline; needed by LLM_RETRY below
class RetryConfig:
"""Configuration for retry behavior."""
def __init__(
self,
max_retries: int = 3,
base_delay: float = 1.0,
max_delay: float = 60.0,
exponential_base: float = 2.0,
jitter: bool = True,
retryable_exceptions: tuple = (Exception,)
):
self.max_retries = max_retries
self.base_delay = base_delay
self.max_delay = max_delay
self.exponential_base = exponential_base
self.jitter = jitter
self.retryable_exceptions = retryable_exceptions
def with_retry(config: RetryConfig = None):
"""Decorator for retrying async functions with exponential backoff."""
if config is None:
config = RetryConfig()
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(config.max_retries + 1):
try:
return await func(*args, **kwargs)
except config.retryable_exceptions as e:
last_exception = e
if attempt == config.max_retries:
break
# Calculate delay with exponential backoff
delay = min(
config.base_delay * (config.exponential_base ** attempt),
config.max_delay
)
# Add jitter to prevent thundering herd
if config.jitter:
delay = delay * (0.5 + random.random())
logger.warning(
f"Attempt {attempt + 1} failed for {func.__name__}: {e}. "
f"Retrying in {delay:.2f}s"
)
await asyncio.sleep(delay)
raise last_exception
return wrapper
return decorator
# Specific retry configs for different services
EMBEDDING_RETRY = RetryConfig(
max_retries=3,
base_delay=0.5,
retryable_exceptions=(TimeoutError, ConnectionError)
)
LLM_RETRY = RetryConfig(
max_retries=2,
base_delay=1.0,
max_delay=10.0,
retryable_exceptions=(TimeoutError, RateLimitError)
)
VECTOR_DB_RETRY = RetryConfig(
max_retries=5,
base_delay=0.2,
retryable_exceptions=(ConnectionError, TimeoutError)
)
# Usage
class ReliableRAG:
@with_retry(EMBEDDING_RETRY)
async def embed(self, text: str) -> list[float]:
response = await self.openai.embeddings.create(
model="text-embedding-3-small",
input=text
)
return response.data[0].embedding
@with_retry(LLM_RETRY)
async def generate(self, context: str, question: str) -> str:
response = await self.openai.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": f"Context:\n{context}"},
{"role": "user", "content": question}
]
)
return response.choices[0].message.content
@with_retry(VECTOR_DB_RETRY)
async def search(self, embedding: list[float]) -> list[dict]:
return await self.vector_db.query(
vector=embedding,
top_k=5
)

Here is a complete production RAG pipeline that incorporates all the patterns from this article:
from dataclasses import dataclass
from typing import AsyncIterator
import asyncio
import logging
import time
import uuid
from datetime import datetime
logger = logging.getLogger(__name__)
@dataclass
class RAGConfig:
"""Configuration for production RAG."""
# Caching
redis_url: str
cache_ttl_hours: int = 24
semantic_cache_threshold: float = 0.95
# Performance
max_context_tokens: int = 4000
target_latency_ms: int = 5000
# Cost
default_model: str = "gpt-4o-mini"
premium_model: str = "gpt-4o"
quality_threshold: str = "standard"
# Security
enable_pii_detection: bool = True
audit_all_queries: bool = True
# Reliability
max_retries: int = 3
circuit_breaker_threshold: int = 5
class ProductionRAGPipeline:
"""Complete production RAG pipeline."""
def __init__(self, config: RAGConfig):
self.config = config
# Initialize components
self.cache = TieredRAGCache(
redis_url=config.redis_url,
semantic_threshold=config.semantic_cache_threshold
)
self.token_optimizer = TokenOptimizer()
self.model_router = CostAwareModelRouter(config.default_model)
self.access_controller = AccessController()
self.audit_logger = AuditLogger()
self.pii_protector = PIIProtector()
self.input_sanitizer = InputSanitizer()
self.output_filter = OutputFilter()
self.degradation = GracefulDegradation()
# Circuit breakers
self.embedding_circuit = CircuitBreaker("embedding")
self.vector_circuit = CircuitBreaker("vector_db")
self.llm_circuit = CircuitBreaker("llm")
async def query(
self,
question: str,
user: User,
stream: bool = True
) -> AsyncIterator[str] | dict:
"""Execute a production RAG query."""
session_id = str(uuid.uuid4())
start_time = time.time()
# 1. Input sanitization
question, warnings = self.input_sanitizer.sanitize(question)
if warnings:
logger.warning(f"Input warnings: {warnings}")
# 2. PII check
if self.config.enable_pii_detection:
pii_check = self.pii_protector.check_query(question)
if pii_check["high_risk"]:
self.audit_logger.log(
AuditEventType.PII_DETECTED,
user.id, session_id, question,
pii_findings=pii_check["findings"]
)
return {"error": "Query contains sensitive information", "blocked": True}
# 3. Access control
if not self.access_controller.check_permission(user, Permission.READ):
self.audit_logger.log(
AuditEventType.ACCESS_DENIED,
user.id, session_id, question
)
raise PermissionError("Access denied")
# 4. Check cache
cached, cache_tier = await self.cache.get(question)
if cached:
self.audit_logger.log(
AuditEventType.CACHE_HIT,
user.id, session_id, question,
cache_tier=cache_tier
)
return cached
# 5. Check degradation level
if not self.degradation.is_feature_available("semantic_search"):
return {
"error": "Service degraded",
"message": "Only cached responses available"
}
# 6. Generate embedding
try:
embedding = await self.embedding_circuit.call(
self._embed_with_retry,
question
)
except CircuitBreakerOpen:
logger.error("Embedding circuit open")
return await self._fallback_response(question)
# 7. Retrieve documents
try:
chunks = await self.vector_circuit.call(
self._search_with_retry,
embedding,
user
)
except CircuitBreakerOpen:
logger.error("Vector DB circuit open")
return await self._fallback_response(question)
# 8. Filter results by access control
chunks = self.access_controller.filter_results(user, chunks)
if not chunks:
return {"answer": "No relevant documents found.", "sources": []}
# 9. Build optimized context
context = self.token_optimizer.optimize_context(
chunks,
max_context_tokens=self.config.max_context_tokens
)
# 10. Select model
model, estimated_cost = self.model_router.select_model(
question,
len(context.split()),
self.config.quality_threshold
)
# 11. Generate response
try:
if stream and self.degradation.is_feature_available("streaming"):
return self._stream_response(
context, question, model, user, session_id, chunks
)
else:
response = await self.llm_circuit.call(
self._generate_with_retry,
context, question, model
)
except CircuitBreakerOpen:
logger.error("LLM circuit open")
return {
"answer": "Unable to generate response.",
"chunks": chunks,
"error": "llm_unavailable"
}
# 12. Filter output
response, was_filtered = self.output_filter.filter(response)
if was_filtered:
logger.warning("Response was filtered for sensitive content")
# 13. Build result
result = {
"answer": response,
"sources": [c.get("source") for c in chunks],
"model": model,
"latency_ms": (time.time() - start_time) * 1000
}
# 14. Cache result
await self.cache.set(question, CachedRAGResponse(
query=question,
answer=response,
sources=result["sources"],
retrieved_chunks=chunks,
confidence_score=0.9, # Calculate based on similarity scores
model_used=model,
tokens_used=len(context.split()) + len(response.split()),
created_at=datetime.now().isoformat(),
context_version="v1"
))
# 15. Audit log
self.audit_logger.log(
AuditEventType.QUERY,
user.id, session_id, question,
model=model,
latency_ms=result["latency_ms"],
chunks_retrieved=len(chunks)
)
return result
async def _stream_response(
self,
context: str,
question: str,
model: str,
user: User,
session_id: str,
chunks: list[dict]
) -> AsyncIterator[str]:
"""Stream response tokens."""
full_response = ""
async for token in self._generate_stream(context, question, model):
full_response += token
yield token
# Filter final response
filtered, _ = self.output_filter.filter(full_response)
# Cache after streaming completes
await self.cache.set(question, CachedRAGResponse(
query=question,
answer=filtered,
sources=[c.get("source") for c in chunks],
retrieved_chunks=chunks,
confidence_score=0.9,
model_used=model,
tokens_used=len(context.split()) + len(filtered.split()),
created_at=datetime.now().isoformat(),
context_version="v1"
))
self.audit_logger.log(
AuditEventType.GENERATION,
user.id, session_id, question,
model=model,
streamed=True
)
@with_retry(EMBEDDING_RETRY)
async def _embed_with_retry(self, text: str) -> list[float]:
"""Embed with retry logic."""
return await self.embedding_service.embed_async(text)
@with_retry(VECTOR_DB_RETRY)
async def _search_with_retry(
self,
embedding: list[float],
user: User
) -> list[dict]:
"""Search with retry logic and namespace filtering."""
namespaces = self.access_controller.filter_namespaces(
user,
["default", "public", user.department]
)
return await self.vector_db.search(embedding, namespaces=namespaces)
@with_retry(LLM_RETRY)
async def _generate_with_retry(
self,
context: str,
question: str,
model: str
) -> str:
"""Generate with retry logic."""
return await self.llm_service.generate(context, question, model)
async def _fallback_response(self, question: str) -> dict:
"""Generate fallback response when primary path fails."""
return {
"answer": (
"I apologize, but I'm experiencing technical difficulties. "
"Please try again in a moment."
),
"sources": [],
"error": True,
"fallback": True
}

Building production RAG systems requires thinking beyond the happy path:
- **Caching is critical:** Implement tiered caching (exact, semantic, full response) to reduce costs by 50-70% and latency by 80%.
- **Optimize at every layer:** Embedding generation, vector search, and LLM calls each have distinct optimization strategies. Address all three.
- **Route intelligently:** Use smaller models for simple queries. The cost difference between GPT-4o and GPT-4o-mini can save tens of thousands of dollars annually.
- **Plan for failure:** Circuit breakers, fallbacks, and graceful degradation ensure your system survives component failures.
- **Security is not optional:** Input sanitization, output filtering, PII detection, and access control protect both your users and your organization.
- **Measure everything:** Track costs per query, cache hit rates, latency percentiles, and error rates (a minimal tracker is sketched below). You cannot optimize what you do not measure.
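To make that last point concrete, here is a minimal in-memory metrics tracker. It is an illustrative sketch, not part of the pipeline above: the RAGMetrics name and its fields are assumptions, and in production you would more likely push these numbers to Prometheus, CloudWatch, or Datadog rather than holding them in memory.

import statistics

class RAGMetrics:
    """In-memory rollup of the numbers that matter for a RAG service (sketch)."""

    def __init__(self):
        self.latencies_ms: list[float] = []
        self.cache_hits = 0
        self.cache_misses = 0
        self.errors = 0
        self.total_cost_usd = 0.0

    def record_query(
        self,
        latency_ms: float,
        cache_hit: bool,
        cost_usd: float = 0.0,
        error: bool = False
    ):
        """Record the outcome of a single query."""
        self.latencies_ms.append(latency_ms)
        if cache_hit:
            self.cache_hits += 1
        else:
            self.cache_misses += 1
        if error:
            self.errors += 1
        self.total_cost_usd += cost_usd

    def summary(self) -> dict:
        """Roll up hit rate, P95 latency, error rate, and cost per query."""
        total = self.cache_hits + self.cache_misses
        p95 = None
        if len(self.latencies_ms) >= 2:
            # quantiles(n=100) returns 99 cut points; index 94 is the 95th percentile
            p95 = statistics.quantiles(self.latencies_ms, n=100)[94]
        return {
            "queries": total,
            "cache_hit_rate": self.cache_hits / total if total else 0.0,
            "p95_latency_ms": p95,
            "error_rate": self.errors / total if total else 0.0,
            "cost_per_query_usd": self.total_cost_usd / total if total else 0.0,
        }

# Example: metrics.record_query(latency_ms=850, cache_hit=False, cost_usd=0.012)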
The patterns in this article have been battle-tested across production deployments handling millions of queries. Start with caching and cost tracking, then layer in reliability patterns as your scale demands.
In Part 12, we will cover evaluation and testing, showing how to measure RAG quality and catch regressions before they reach production.
This article is Part 11 of the "Building RAG Systems: A Platform-by-Platform Guide" series. The techniques and patterns described have been refined through production deployments at Elegant Software Solutions and our client organizations.