diff --git a/ast_serialize.pyi b/ast_serialize.pyi
index cb4b0bb..97982a4 100644
--- a/ast_serialize.pyi
+++ b/ast_serialize.pyi
@@ -22,7 +22,7 @@ class _ASTData(TypedDict):
 
 def parse(
     fnam: str,
-    source: str | None = None,
+    source: str | bytes | None = None,
    skip_function_bodies: bool = False,
     python_version: tuple[int, int] | None = None,
     platform: str | None = None,
diff --git a/src/lib.rs b/src/lib.rs
index b9c9d99..2a8eb68 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -45,7 +45,7 @@ pub mod type_comment;
 fn parse(
     py: Python,
     fnam: String,
-    source: Option<String>,
+    source: Option<serialize_ast::Source>,
     skip_function_bodies: bool,
     python_version: Option<(u32, u32)>,
     platform: Option<String>,
@@ -66,14 +66,14 @@ fn parse(
     let always_true = always_true.unwrap_or_default();
     let always_false = always_false.unwrap_or_default();
-
-    if source.is_some() {
-        return Err(PyErr::new::<PyNotImplementedError, _>(
-            "Source parsing is not supported yet",
-        ));
-    }
-
     let path = Path::new(&fnam);
+    let options = options::Options::new(
+        python_version,
+        platform,
+        always_true,
+        always_false,
+        cache_version,
+    );
 
     let (
         ast_bytes,
         syntax_errors,
@@ -85,17 +85,15 @@ fn parse(
         source_hash,
     ) = py
         .detach(|| {
-            serialize_ast::serialize_python_file(
-                path,
-                skip_function_bodies,
-                options::Options::new(
-                    python_version,
-                    platform,
-                    always_true,
-                    always_false,
-                    cache_version,
-                ),
-            )
+            if let Some(src) = source {
+                let s = match src {
+                    serialize_ast::Source::Text(s) => s,
+                    serialize_ast::Source::Bytes(s) => String::from_utf8(s)?,
+                };
+                serialize_ast::serialize_python_source(s, skip_function_bodies, options)
+            } else {
+                serialize_ast::serialize_python_file(path, skip_function_bodies, options)
+            }
         })
         .map_err(|e| PyErr::new::<PyRuntimeError, _>(e.to_string()))?;
 
diff --git a/src/serialize_ast.rs b/src/serialize_ast.rs
index b2be7ec..c96d6d0 100644
--- a/src/serialize_ast.rs
+++ b/src/serialize_ast.rs
@@ -5,6 +5,7 @@ use std::fmt::Write;
 use std::path::Path;
 
 use anyhow::Result;
+use pyo3::types::PyAnyMethods;
 use ruff_python_ast::token::{TokenKind, Tokens};
 use ruff_python_ast::{self as ast, AnyParameterRef, Number, PySourceType, StmtFunctionDef};
 use ruff_python_parser::{Mode, ParseOptions, parse_unchecked};
@@ -190,136 +191,40 @@ pub(crate) fn serialize_python_file(
     bool,
     String,
 )> {
-    let source_type = PySourceType::from(file_path);
-    let source_text = std::fs::read_to_string(file_path)?;
-
-    // Compute SHA1 hash of the source text (same as mypy's compute_hash)
-    let hash_hex = {
-        let hash = Sha1::digest(source_text.as_bytes());
-        let mut hex = String::with_capacity(40);
-        for byte in hash {
-            write!(hex, "{byte:02x}").unwrap();
-        }
-        hex
-    };
-    let line_index = LineIndex::from_source_text(&source_text);
-    let is_stub_package = match file_path.file_name() {
-        Some(file) => file.as_encoded_bytes() == b"__init__.pyi",
-        _ => false,
-    };
-
-    // Check if file is all ASCII and build per-line non-ASCII flags if needed
-    let is_all_ascii = source_text.is_ascii();
-    let lines_with_non_ascii = if is_all_ascii {
-        Vec::new() // No need to track per-line if whole file is ASCII
-    } else {
-        // Build a Vec indicating which lines have non-ASCII characters
-        source_text.lines().map(|line| !line.is_ascii()).collect()
-    };
-
-    // Parse the file - this always returns a result, even with syntax errors
-    let parsed = parse_unchecked(&source_text, ParseOptions::from(source_type));
-
-    // Extract syntax errors with location information
-    let mut syntax_errors: Vec<SyntaxError> = parsed
-        .errors()
-        .iter()
-        .map(|error| {
-            let location = line_index.line_column(error.location.start(), &source_text);
-            SyntaxError {
-                line: location.line.get(),
-                column: location.column.get(),
-                message: error.error.to_string(),
-                blocker: true,
-            }
-        })
-        .collect();
-
-    // Extract both type: ignore comments and type annotation comments in a single pass
-    let (mut type_ignore_lines, mut mypy_ignore_lines, type_comments) =
-        extract_type_comments_and_ignores(parsed.tokens(), &source_text, &line_index);
-
-    let mut top_unreachable = false;
-    let first_ignore = type_ignore_lines.get(0).cloned();
-    let first_statement_line = first_statement_line(parsed.syntax(), &source_text, &line_index);
-
-    if first_ignore.is_some() {
-        let (first_line, codes) = first_ignore.unwrap();
-        if first_line < first_statement_line {
-            top_unreachable = true;
-            type_ignore_lines = Vec::new();
-            if !codes.is_empty() {
-                let joined = codes.join(", ");
-                let error = format!(
-                    "Type ignore with error code is not supported for modules; \
-                     use `# mypy: disable-error-code=\"{}\"`",
-                    joined
-                );
-                syntax_errors.push(SyntaxError {
-                    line: first_line,
-                    column: 0,
-                    message: error,
-                    blocker: false,
-                })
-            }
-        }
-    }
-
-    // Serialize the AST (even if partial due to syntax errors)
-    let mut ser = Serializer {
-        bytes: Vec::new(),
-        imports: Vec::new(),
-        line_index,
-        tokens: Some(parsed.tokens()),
-        text: &source_text,
+    serialize_module(
+        std::fs::read_to_string(file_path)?,
+        PySourceType::from(file_path),
         skip_function_bodies,
-        in_class: false,
-        in_function: false,
-        is_all_ascii,
-        lines_with_non_ascii,
-        type_comments,
         options,
-        current_unreachable: false,
-        current_mypy_only: false,
-        top_level_getattr: false,
-        is_evaluated: true,
-        extra_errors: Vec::new(),
-        skipped_lines: HashSet::new(),
-        uses_template_strings: false,
-    };
-    if top_unreachable {
-        // Module is ignored completely.
-        ser.write_tagged_int(0);
-    } else {
-        parsed.syntax().serialize(&mut ser);
-    }
-
-    // Serialize the collected imports, reusing the moved state from serializer
-    let import_bytes = serialize_imports(
-        &ser.imports,
-        &source_text,
-        Some(ser.line_index),
-        Some(is_all_ascii),
-        Some(ser.lines_with_non_ascii),
-    );
-
-    // Return this directly to caller, so that it can check this without deserialization
-    let is_partial_package = is_stub_package && ser.top_level_getattr;
+        match file_path.file_name() {
+            Some(file) => file.as_encoded_bytes() == b"__init__.pyi",
+            _ => false,
+        },
+    )
+}
-
-    syntax_errors.extend(ser.extra_errors);
-    // Skip type ignores on unreachable lines, so that they are not flagged as unused.
-    type_ignore_lines.retain(|(line, _)| !ser.skipped_lines.contains(line));
-    mypy_ignore_lines.retain(|(line, _)| !ser.skipped_lines.contains(line));
-    Ok((
-        ser.bytes,
-        syntax_errors,
-        type_ignore_lines,
-        mypy_ignore_lines,
-        import_bytes,
-        is_partial_package,
-        ser.uses_template_strings,
-        hash_hex,
-    ))
+
+/// Serialize Python source code to mypy AST format
+pub(crate) fn serialize_python_source(
+    source: String,
+    skip_function_bodies: bool,
+    options: Options,
+) -> Result<(
+    Vec<u8>,
+    Vec<SyntaxError>,
+    Vec<(usize, Vec<String>)>,
+    Vec<(usize, Vec<String>)>,
+    Vec<u8>,
+    bool,
+    bool,
+    String,
+)> {
+    serialize_module(
+        source,
+        PySourceType::Python,
+        skip_function_bodies,
+        options,
+        false,
+    )
 }
 
 // Bit flags for import statement metadata
@@ -358,6 +263,28 @@ pub(crate) enum ParsedTypeComment {
     Invalid(String), // Error message for invalid type comment
 }
 
+// Represents Python source code originating from either string or bytes
+pub(crate) enum Source {
+    Text(String),
+    Bytes(Vec<u8>),
+}
+
+// Implementation for converting a Python object into `Source`
+impl<'a, 'py> pyo3::FromPyObject<'a, 'py> for Source {
+    type Error = pyo3::PyErr;
+
+    fn extract(obj: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> Result<Self, Self::Error> {
+        if obj.is_instance_of::<pyo3::types::PyString>() {
+            return Ok(Source::Text(obj.extract::<String>()?));
+        } else if obj.is_instance_of::<pyo3::types::PyBytes>() {
+            return Ok(Source::Bytes(obj.extract::<Vec<u8>>()?));
+        }
+        Err(Self::Error::new::<pyo3::exceptions::PyTypeError, _>(
+            "Source must be str or bytes",
+        ))
+    }
+}
+
 struct Serializer<'a> {
     bytes: Vec<u8>,
     imports: Vec<Import>, // Encountered import statements
@@ -2534,6 +2461,147 @@ impl Ser for ast::Pattern {
     }
 }
 
+fn serialize_module(
+    source_text: String,
+    source_type: PySourceType,
+    skip_function_bodies: bool,
+    options: Options,
+    is_stub_package: bool,
+) -> Result<(
+    Vec<u8>,
+    Vec<SyntaxError>,
+    Vec<(usize, Vec<String>)>,
+    Vec<(usize, Vec<String>)>,
+    Vec<u8>,
+    bool,
+    bool,
+    String,
+)> {
+    // Compute SHA1 hash of the source text (same as mypy's compute_hash)
+    let hash_hex = {
+        let hash = Sha1::digest(source_text.as_bytes());
+        let mut hex = String::with_capacity(40);
+        for byte in hash {
+            write!(hex, "{byte:02x}").unwrap();
+        }
+        hex
+    };
+    let line_index = LineIndex::from_source_text(&source_text);
+
+    // Check if file is all ASCII and build per-line non-ASCII flags if needed
+    let is_all_ascii = source_text.is_ascii();
+    let lines_with_non_ascii = if is_all_ascii {
+        Vec::new() // No need to track per-line if whole file is ASCII
+    } else {
+        // Build a Vec indicating which lines have non-ASCII characters
+        source_text.lines().map(|line| !line.is_ascii()).collect()
+    };
+
+    // Parse the file - this always returns a result, even with syntax errors
+    let parsed = parse_unchecked(&source_text, ParseOptions::from(source_type));
+
+    // Extract syntax errors with location information
+    let mut syntax_errors: Vec<SyntaxError> = parsed
+        .errors()
+        .iter()
+        .map(|error| {
+            let location = line_index.line_column(error.location.start(), &source_text);
+            SyntaxError {
+                line: location.line.get(),
+                column: location.column.get(),
+                message: error.error.to_string(),
+                blocker: true,
+            }
+        })
+        .collect();
+
+    // Extract both type: ignore comments and type annotation comments in a single pass
+    let (mut type_ignore_lines, mut mypy_ignore_lines, type_comments) =
+        extract_type_comments_and_ignores(parsed.tokens(), &source_text, &line_index);
+
+    let mut top_unreachable = false;
+    let first_ignore = type_ignore_lines.get(0).cloned();
+    let first_statement_line = first_statement_line(parsed.syntax(), &source_text, &line_index);
+
+    if first_ignore.is_some() {
+        let (first_line, codes) = first_ignore.unwrap();
+        if first_line < first_statement_line {
+            top_unreachable = true;
+            type_ignore_lines = Vec::new();
+            if !codes.is_empty() {
+                let joined = codes.join(", ");
+                let error = format!(
+                    "Type ignore with error code is not supported for modules; \
+                     use `# mypy: disable-error-code=\"{}\"`",
+                    joined
+                );
+                syntax_errors.push(SyntaxError {
+                    line: first_line,
+                    column: 0,
+                    message: error,
+                    blocker: false,
+                })
+            }
+        }
+    }
+
+    // Serialize the AST (even if partial due to syntax errors)
+    let mut ser = Serializer {
+        bytes: Vec::new(),
+        imports: Vec::new(),
+        line_index,
+        tokens: Some(parsed.tokens()),
+        text: &source_text,
+        skip_function_bodies,
+        in_class: false,
+        in_function: false,
+        is_all_ascii,
+        lines_with_non_ascii,
+        type_comments,
+        options,
+        current_unreachable: false,
+        current_mypy_only: false,
+        top_level_getattr: false,
+        is_evaluated: true,
+        extra_errors: Vec::new(),
+        skipped_lines: HashSet::new(),
+        uses_template_strings: false,
+    };
+    if top_unreachable {
+        // Module is ignored completely.
+        ser.write_tagged_int(0);
+    } else {
+        parsed.syntax().serialize(&mut ser);
+    }
+
+    // Serialize the collected imports, reusing the moved state from serializer
+    let import_bytes = serialize_imports(
+        &ser.imports,
+        &source_text,
+        Some(ser.line_index),
+        Some(is_all_ascii),
+        Some(ser.lines_with_non_ascii),
+    );
+
+    // Return this directly to caller, so that it can check this without deserialization
+    let is_partial_package = is_stub_package && ser.top_level_getattr;
+
+    syntax_errors.extend(ser.extra_errors);
+    // Skip type ignores on unreachable lines, so that they are not flagged as unused.
+    type_ignore_lines.retain(|(line, _)| !ser.skipped_lines.contains(line));
+    mypy_ignore_lines.retain(|(line, _)| !ser.skipped_lines.contains(line));
+    Ok((
+        ser.bytes,
+        syntax_errors,
+        type_ignore_lines,
+        mypy_ignore_lines,
+        import_bytes,
+        is_partial_package,
+        ser.uses_template_strings,
+        hash_hex,
+    ))
+}
+
 fn serialize_fstring_elements(ser: &mut Serializer, elems: Vec<&ast::InterpolatedStringElement>) {
     ser.write_tagged_int(elems.len() as i64);
     for elem in elems {