diff --git a/apps/base_rag_example.py b/apps/base_rag_example.py index 25e3b594..67e198b3 100644 --- a/apps/base_rag_example.py +++ b/apps/base_rag_example.py @@ -282,6 +282,22 @@ def get_llm_config(self, args) -> dict[str, Any]: return config + @staticmethod + def _resolve_chunk_token_limit(args) -> int | None: + """Resolve the embedding model's token limit for token-aware chunking. + + Returns ``None`` if the limit cannot be determined (e.g. model unknown). + Apps can pass the result as ``max_tokens_per_chunk=`` to + ``create_text_chunks()``. + """ + try: + from leann.embedding_compute import get_model_token_limit + + base_url = getattr(args, "embedding_api_base", None) + return get_model_token_limit(args.embedding_model, base_url) + except Exception: + return None + async def build_index(self, args, texts: list[dict[str, Any]]) -> str: """Build LEANN index from text chunks (dicts with 'text' and 'metadata' keys).""" index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann") @@ -289,6 +305,17 @@ async def build_index(self, args, texts: list[dict[str, Any]]) -> str: print(f"\n[Building Index] Creating {self.name} index...") print(f"Total text chunks: {len(texts)}") + # Warn if any chunks may exceed the embedding model's token limit + limit = self._resolve_chunk_token_limit(args) + if limit: + try: + from leann.chunking_utils import validate_chunk_token_limits + + _texts = [t["text"] if isinstance(t, dict) else t for t in texts] + validate_chunk_token_limits(_texts, limit) + except Exception: + pass + embedding_options: dict[str, Any] = {} if args.embedding_mode == "ollama": embedding_options["host"] = resolve_ollama_host(args.embedding_host) diff --git a/packages/leann-core/src/leann/chunking_utils.py b/packages/leann-core/src/leann/chunking_utils.py index 965828a9..62eae548 100644 --- a/packages/leann-core/src/leann/chunking_utils.py +++ b/packages/leann-core/src/leann/chunking_utils.py @@ -172,6 +172,27 @@ def get_language_from_extension(file_path: str) -> Optional[str]: return CODE_EXTENSIONS.get(ext) +def _parse_ast_chunk_output(chunk: Any) -> tuple[str | None, dict[str, Any]]: + """Normalize the various chunk output formats from ASTChunkBuilder. + + astchunk can return objects (with a ``.text`` attr), plain strings, or + dicts (``{"content": ..., "metadata": ...}``). This helper returns a + uniform ``(text, metadata)`` pair regardless of the input shape. + """ + if hasattr(chunk, "text"): + return (str(chunk.text) if chunk.text else None, {}) + if isinstance(chunk, str): + return (chunk, {}) + if isinstance(chunk, dict): + meta = chunk.get("metadata", {}) + if "content" in chunk: + return (chunk["content"], meta) + if "text" in chunk: + return (chunk["text"], meta) + return (str(chunk), {}) + return (str(chunk), {}) + + def create_ast_chunks( documents, max_chunk_size: int = 512, @@ -190,14 +211,14 @@ def create_ast_chunks( except ImportError as e: logger.error(f"astchunk not available: {e}") logger.info("Falling back to traditional chunking for code files") - return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap) + return create_traditional_chunks(documents, max_chunk_size, chunk_overlap) all_chunks = [] for doc in documents: language = doc.metadata.get("language") if not language: logger.warning("No language detected; falling back to traditional chunking") - all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap)) + all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap)) continue try: @@ -240,24 +261,7 @@ def create_ast_chunks( chunks = chunk_builder.chunkify(code_content) for chunk in chunks: - chunk_text: str | None = None - astchunk_metadata: dict[str, Any] = {} - - if hasattr(chunk, "text"): - chunk_text = str(chunk.text) if chunk.text else None - elif isinstance(chunk, str): - chunk_text = chunk - elif isinstance(chunk, dict): - # Handle astchunk format: {"content": "...", "metadata": {...}} - if "content" in chunk: - chunk_text = chunk["content"] - astchunk_metadata = chunk.get("metadata", {}) - elif "text" in chunk: - chunk_text = chunk["text"] - else: - chunk_text = str(chunk) # Last resort - else: - chunk_text = str(chunk) + chunk_text, astchunk_metadata = _parse_ast_chunk_output(chunk) if chunk_text and chunk_text.strip(): # Extract document-level metadata @@ -282,26 +286,62 @@ def create_ast_chunks( except Exception as e: logger.warning(f"AST chunking failed for {language} file: {e}") logger.info("Falling back to traditional chunking") - all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap)) + all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap)) return all_chunks def create_traditional_chunks( - documents, chunk_size: int = 256, chunk_overlap: int = 128 + documents, + chunk_size: int = 256, + chunk_overlap: int = 128, + max_tokens_per_chunk: int | None = None, ) -> list[dict[str, Any]]: """Create traditional text chunks using LlamaIndex SentenceSplitter. + Args: + documents: LlamaIndex Document list. + chunk_size: Target chunk size in **characters** (approximate tokens). + chunk_overlap: Overlap between adjacent chunks in characters. + max_tokens_per_chunk: If set, auto-scale ``chunk_size`` so each + chunk stays within the model's token budget. Uses + ``calculate_safe_chunk_size`` with a 10 % safety margin. + Additionaly runs ``validate_chunk_token_limits`` post-chunk. + Returns: - List of dicts with {"text": str, "metadata": dict} + List of dicts with ``{"text": str, "metadata": dict}``. """ if chunk_size <= 0: logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256") chunk_size = 256 if chunk_overlap < 0: chunk_overlap = 0 + + # ── Token-aware auto-scaling ────────────────────────────────────── + if max_tokens_per_chunk and max_tokens_per_chunk > 0: + chunk_size = calculate_safe_chunk_size( + max_tokens_per_chunk, chunk_overlap, chunking_mode="traditional" + ) + logger.info( + "Token-aware chunking: model limit=%d tokens → safe chunk_size=%d chars " + "(overlap=%d, safety=0.9)", + max_tokens_per_chunk, + chunk_size, + chunk_overlap, + ) + + # Revalidate after scaling if chunk_overlap >= chunk_size: - chunk_overlap = chunk_size // 2 + old_overlap = chunk_overlap + chunk_overlap = max(0, chunk_size // 2) + logger.warning( + "Token-aware scaling reduced chunk_size below chunk_overlap " + "(%d → %d); overlap reduced %d → %d", + chunk_size, + old_overlap, + old_overlap, + chunk_overlap, + ) node_parser = SentenceSplitter( chunk_size=chunk_size, @@ -312,8 +352,6 @@ def create_traditional_chunks( result = [] for doc in documents: - # Propagate all document-level metadata to each chunk so custom fields - # (e.g. url/domain for browser_rag) remain available for metadata_filters. doc_metadata = dict(doc.metadata) if doc.metadata else {} doc_metadata.setdefault("file_path", "") doc_metadata.setdefault("file_name", "") @@ -330,17 +368,22 @@ def create_traditional_chunks( if content and content.strip(): result.append({"text": content.strip(), "metadata": doc_metadata}) - return result - - -def _traditional_chunks_as_dicts( - documents, chunk_size: int = 256, chunk_overlap: int = 128 -) -> list[dict[str, Any]]: - """Helper: Traditional chunking that returns dict format for consistency. + # ── Post-chunk validation ───────────────────────────────────────── + if max_tokens_per_chunk and max_tokens_per_chunk > 0 and result: + _texts = [c["text"] for c in result] + _validated, _n = validate_chunk_token_limits(_texts, max_tokens_per_chunk) + if _n > 0: + logger.warning( + "Token-aware chunking: %d/%d chunks truncated to fit %d-token limit", + _n, + len(result), + max_tokens_per_chunk, + ) + for i, chunk_dict in enumerate(result): + if i < len(_validated): + chunk_dict["text"] = _validated[i] - This is now just an alias for create_traditional_chunks for backwards compatibility. - """ - return create_traditional_chunks(documents, chunk_size, chunk_overlap) + return result def create_text_chunks( @@ -352,11 +395,20 @@ def create_text_chunks( ast_chunk_overlap: int = 64, code_file_extensions: Optional[list[str]] = None, ast_fallback_traditional: bool = True, + max_tokens_per_chunk: int | None = None, ) -> list[dict[str, Any]]: """Create text chunks from documents with optional AST support for code files. + Args: + documents: LlamaIndex Document list. + chunk_size: Characters per traditional chunk. + chunk_overlap: Character overlap between traditional chunks. + max_tokens_per_chunk: If set, auto-scale chunk_size to keep each + chunk within the embedding model's token budget. Also runs + post-chunk validation and truncation when necessary. + Returns: - List of dicts with {"text": str, "metadata": dict} + List of dicts with ``{"text": str, "metadata": dict}``. """ if not documents: logger.warning("No documents provided for chunking") @@ -383,8 +435,22 @@ def create_text_chunks( code_docs, text_docs = detect_code_files(documents, local_code_extensions) if code_docs: try: + # AST chunking: auto-scale if token limit given + ast_size = ast_chunk_size + if max_tokens_per_chunk and max_tokens_per_chunk > 0: + ast_size = calculate_safe_chunk_size( + max_tokens_per_chunk, ast_chunk_overlap, chunking_mode="ast" + ) + logger.info( + "Token-aware AST chunking: limit=%d → safe ast_chunk_size=%d chars " + "(overlap=%d, safety=0.9)", + max_tokens_per_chunk, + ast_size, + ast_chunk_overlap, + ) + ast_chunks = create_ast_chunks( - code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap + code_docs, max_chunk_size=ast_size, chunk_overlap=ast_chunk_overlap ) # Prepend line numbers to code chunks for navigation for chunk in ast_chunks: @@ -401,17 +467,32 @@ def create_text_chunks( logger.error(f"AST chunking failed: {e}") if ast_fallback_traditional: all_chunks.extend( - _traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap) + create_traditional_chunks( + code_docs, + chunk_size, + chunk_overlap, + max_tokens_per_chunk=max_tokens_per_chunk, + ) ) else: raise if text_docs: - all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap)) + all_chunks.extend( + create_traditional_chunks( + text_docs, + chunk_size, + chunk_overlap, + max_tokens_per_chunk=max_tokens_per_chunk, + ) + ) else: - all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap) + all_chunks = create_traditional_chunks( + documents, + chunk_size, + chunk_overlap, + max_tokens_per_chunk=max_tokens_per_chunk, + ) logger.info(f"Total chunks created: {len(all_chunks)}") - # Note: Token truncation is now handled at embedding time with dynamic model limits - # See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py return all_chunks