Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions apps/base_rag_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,40 @@ def get_llm_config(self, args) -> dict[str, Any]:

return config

@staticmethod
def _resolve_chunk_token_limit(args) -> int | None:
"""Resolve the embedding model's token limit for token-aware chunking.

Returns ``None`` if the limit cannot be determined (e.g. model unknown).
Apps can pass the result as ``max_tokens_per_chunk=`` to
``create_text_chunks()``.
"""
try:
from leann.embedding_compute import get_model_token_limit

base_url = getattr(args, "embedding_api_base", None)
return get_model_token_limit(args.embedding_model, base_url)
except Exception:
return None

async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
"""Build LEANN index from text chunks (dicts with 'text' and 'metadata' keys)."""
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")

print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")

# Warn if any chunks may exceed the embedding model's token limit
limit = self._resolve_chunk_token_limit(args)
if limit:
try:
from leann.chunking_utils import validate_chunk_token_limits

_texts = [t["text"] if isinstance(t, dict) else t for t in texts]
validate_chunk_token_limits(_texts, limit)
except Exception:
pass

embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
Expand Down
167 changes: 124 additions & 43 deletions packages/leann-core/src/leann/chunking_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,27 @@ def get_language_from_extension(file_path: str) -> Optional[str]:
return CODE_EXTENSIONS.get(ext)


def _parse_ast_chunk_output(chunk: Any) -> tuple[str | None, dict[str, Any]]:
"""Normalize the various chunk output formats from ASTChunkBuilder.

astchunk can return objects (with a ``.text`` attr), plain strings, or
dicts (``{"content": ..., "metadata": ...}``). This helper returns a
uniform ``(text, metadata)`` pair regardless of the input shape.
"""
if hasattr(chunk, "text"):
return (str(chunk.text) if chunk.text else None, {})
if isinstance(chunk, str):
return (chunk, {})
if isinstance(chunk, dict):
meta = chunk.get("metadata", {})
if "content" in chunk:
return (chunk["content"], meta)
if "text" in chunk:
return (chunk["text"], meta)
return (str(chunk), {})
return (str(chunk), {})


def create_ast_chunks(
documents,
max_chunk_size: int = 512,
Expand All @@ -190,14 +211,14 @@ def create_ast_chunks(
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)

all_chunks = []
for doc in documents:
language = doc.metadata.get("language")
if not language:
logger.warning("No language detected; falling back to traditional chunking")
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
continue

try:
Expand Down Expand Up @@ -240,24 +261,7 @@ def create_ast_chunks(

chunks = chunk_builder.chunkify(code_content)
for chunk in chunks:
chunk_text: str | None = None
astchunk_metadata: dict[str, Any] = {}

if hasattr(chunk, "text"):
chunk_text = str(chunk.text) if chunk.text else None
elif isinstance(chunk, str):
chunk_text = chunk
elif isinstance(chunk, dict):
# Handle astchunk format: {"content": "...", "metadata": {...}}
if "content" in chunk:
chunk_text = chunk["content"]
astchunk_metadata = chunk.get("metadata", {})
elif "text" in chunk:
chunk_text = chunk["text"]
else:
chunk_text = str(chunk) # Last resort
else:
chunk_text = str(chunk)
chunk_text, astchunk_metadata = _parse_ast_chunk_output(chunk)

if chunk_text and chunk_text.strip():
# Extract document-level metadata
Expand All @@ -282,26 +286,62 @@ def create_ast_chunks(
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))

return all_chunks


def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
documents,
chunk_size: int = 256,
chunk_overlap: int = 128,
max_tokens_per_chunk: int | None = None,
) -> list[dict[str, Any]]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter.

Args:
documents: LlamaIndex Document list.
chunk_size: Target chunk size in **characters** (approximate tokens).
chunk_overlap: Overlap between adjacent chunks in characters.
max_tokens_per_chunk: If set, auto-scale ``chunk_size`` so each
chunk stays within the model's token budget. Uses
``calculate_safe_chunk_size`` with a 10 % safety margin.
Additionaly runs ``validate_chunk_token_limits`` post-chunk.

Returns:
List of dicts with {"text": str, "metadata": dict}
List of dicts with ``{"text": str, "metadata": dict}``.
"""
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
if chunk_overlap < 0:
chunk_overlap = 0

# ── Token-aware auto-scaling ──────────────────────────────────────
if max_tokens_per_chunk and max_tokens_per_chunk > 0:
chunk_size = calculate_safe_chunk_size(
max_tokens_per_chunk, chunk_overlap, chunking_mode="traditional"
)
logger.info(
"Token-aware chunking: model limit=%d tokens → safe chunk_size=%d chars "
"(overlap=%d, safety=0.9)",
max_tokens_per_chunk,
chunk_size,
chunk_overlap,
)

# Revalidate after scaling
if chunk_overlap >= chunk_size:
chunk_overlap = chunk_size // 2
old_overlap = chunk_overlap
chunk_overlap = max(0, chunk_size // 2)
logger.warning(
"Token-aware scaling reduced chunk_size below chunk_overlap "
"(%d → %d); overlap reduced %d → %d",
chunk_size,
old_overlap,
old_overlap,
chunk_overlap,
)

node_parser = SentenceSplitter(
chunk_size=chunk_size,
Expand All @@ -312,8 +352,6 @@ def create_traditional_chunks(

result = []
for doc in documents:
# Propagate all document-level metadata to each chunk so custom fields
# (e.g. url/domain for browser_rag) remain available for metadata_filters.
doc_metadata = dict(doc.metadata) if doc.metadata else {}
doc_metadata.setdefault("file_path", "")
doc_metadata.setdefault("file_name", "")
Expand All @@ -330,17 +368,22 @@ def create_traditional_chunks(
if content and content.strip():
result.append({"text": content.strip(), "metadata": doc_metadata})

return result


def _traditional_chunks_as_dicts(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[dict[str, Any]]:
"""Helper: Traditional chunking that returns dict format for consistency.
# ── Post-chunk validation ─────────────────────────────────────────
if max_tokens_per_chunk and max_tokens_per_chunk > 0 and result:
_texts = [c["text"] for c in result]
_validated, _n = validate_chunk_token_limits(_texts, max_tokens_per_chunk)
if _n > 0:
logger.warning(
"Token-aware chunking: %d/%d chunks truncated to fit %d-token limit",
_n,
len(result),
max_tokens_per_chunk,
)
for i, chunk_dict in enumerate(result):
if i < len(_validated):
chunk_dict["text"] = _validated[i]

This is now just an alias for create_traditional_chunks for backwards compatibility.
"""
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
return result


def create_text_chunks(
Expand All @@ -352,11 +395,20 @@ def create_text_chunks(
ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
max_tokens_per_chunk: int | None = None,
) -> list[dict[str, Any]]:
"""Create text chunks from documents with optional AST support for code files.

Args:
documents: LlamaIndex Document list.
chunk_size: Characters per traditional chunk.
chunk_overlap: Character overlap between traditional chunks.
max_tokens_per_chunk: If set, auto-scale chunk_size to keep each
chunk within the embedding model's token budget. Also runs
post-chunk validation and truncation when necessary.

Returns:
List of dicts with {"text": str, "metadata": dict}
List of dicts with ``{"text": str, "metadata": dict}``.
"""
if not documents:
logger.warning("No documents provided for chunking")
Expand All @@ -383,8 +435,22 @@ def create_text_chunks(
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
if code_docs:
try:
# AST chunking: auto-scale if token limit given
ast_size = ast_chunk_size
if max_tokens_per_chunk and max_tokens_per_chunk > 0:
ast_size = calculate_safe_chunk_size(
max_tokens_per_chunk, ast_chunk_overlap, chunking_mode="ast"
)
logger.info(
"Token-aware AST chunking: limit=%d → safe ast_chunk_size=%d chars "
"(overlap=%d, safety=0.9)",
max_tokens_per_chunk,
ast_size,
ast_chunk_overlap,
)

ast_chunks = create_ast_chunks(
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
code_docs, max_chunk_size=ast_size, chunk_overlap=ast_chunk_overlap
)
# Prepend line numbers to code chunks for navigation
for chunk in ast_chunks:
Expand All @@ -401,17 +467,32 @@ def create_text_chunks(
logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional:
all_chunks.extend(
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
create_traditional_chunks(
code_docs,
chunk_size,
chunk_overlap,
max_tokens_per_chunk=max_tokens_per_chunk,
)
)
else:
raise
if text_docs:
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
all_chunks.extend(
create_traditional_chunks(
text_docs,
chunk_size,
chunk_overlap,
max_tokens_per_chunk=max_tokens_per_chunk,
)
)
else:
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
all_chunks = create_traditional_chunks(
documents,
chunk_size,
chunk_overlap,
max_tokens_per_chunk=max_tokens_per_chunk,
)

logger.info(f"Total chunks created: {len(all_chunks)}")

# Note: Token truncation is now handled at embedding time with dynamic model limits
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
return all_chunks
Loading