diff --git a/src/spider/CMakeLists.txt b/src/spider/CMakeLists.txt index b81e2dff..a50817ce 100644 --- a/src/spider/CMakeLists.txt +++ b/src/spider/CMakeLists.txt @@ -211,6 +211,7 @@ set(SPIDER_TDL_ANTLR_GENERATED_SOURCES ) set(SPIDER_TDL_SHARED_SOURCES + tdl/code_gen/python/PyGenerator.cpp tdl/parser/ast/Node.cpp tdl/parser/ast/node_impl/Function.cpp tdl/parser/ast/node_impl/Identifier.cpp @@ -227,12 +228,16 @@ set(SPIDER_TDL_SHARED_SOURCES tdl/parser/ast/node_impl/type_impl/Struct.cpp tdl/parser/ast/utils.cpp tdl/parser/parse.cpp + tdl/pass/analysis/DetectStructCircularDependency.cpp + tdl/pass/analysis/DetectUndefinedStruct.cpp tdl/pass/analysis/StructSpecDependencyGraph.cpp CACHE INTERNAL "spider task definition language shared source files" ) set(SPIDER_TDL_SHARED_HEADERS + tdl/code_gen/python/PyGenerator.hpp + tdl/code_gen/Generator.hpp tdl/Error.hpp tdl/parser/ast/Node.hpp tdl/parser/ast/FloatSpec.hpp @@ -259,7 +264,11 @@ set(SPIDER_TDL_SHARED_HEADERS tdl/parser/Exception.hpp tdl/parser/parse.hpp tdl/parser/SourceLocation.hpp + tdl/pass/analysis/DetectStructCircularDependency.hpp + tdl/pass/analysis/DetectUndefinedStruct.hpp tdl/pass/analysis/StructSpecDependencyGraph.hpp + tdl/pass/Pass.hpp + tdl/pass/utils.hpp CACHE INTERNAL "spider task definition language shared header files" ) diff --git a/src/spider/tdl/code_gen/Generator.hpp b/src/spider/tdl/code_gen/Generator.hpp new file mode 100644 index 00000000..de951862 --- /dev/null +++ b/src/spider/tdl/code_gen/Generator.hpp @@ -0,0 +1,65 @@ +#ifndef SPIDER_TDL_CODE_GEN_GENERATOR_HPP +#define SPIDER_TDL_CODE_GEN_GENERATOR_HPP + +#include +#include +#include + +#include + +#include +#include +#include + +namespace spider::tdl::code_gen { +/** + * Abstract base class for generating code from a translation unit to a target language. + */ +class Generator { +public: + // Constructor + Generator( + std::unique_ptr translation_unit, + std::shared_ptr struct_spec_dependency_graph + ) + : m_translation_unit{std::move(translation_unit)}, + m_struct_spec_dependency_graph{std::move(struct_spec_dependency_graph)} {} + + // Delete copy constructor and copy assignment operator + Generator(Generator const&) = delete; + auto operator=(Generator const&) -> Generator& = delete; + + // Default move constructor and move assignment operator + Generator(Generator&&) noexcept = default; + auto operator=(Generator&&) -> Generator& = delete; + + // Destructor + virtual ~Generator() = default; + + // Methods + /** + * Generates code from the translation unit to the target language. + * @param out_stream Output stream to write the generated code. + * @return A void result on success, or an error specified by an `Error` instance on failure. + */ + [[nodiscard]] virtual auto generate(std::ostream& out_stream) + -> boost::outcome_v2::std_checked + = 0; + +protected: + [[nodiscard]] auto get_translation_unit() const -> parser::ast::TranslationUnit const* { + return m_translation_unit.get(); + } + + [[nodiscard]] auto get_struct_spec_dependency_graph() const + -> std::shared_ptr { + return m_struct_spec_dependency_graph; + } + +private: + std::unique_ptr m_translation_unit; + std::shared_ptr m_struct_spec_dependency_graph; +}; +} // namespace spider::tdl::code_gen + +#endif // SPIDER_TDL_CODE_GEN_GENERATOR_HPP diff --git a/src/spider/tdl/code_gen/python/PyGenerator.cpp b/src/spider/tdl/code_gen/python/PyGenerator.cpp new file mode 100644 index 00000000..d00a076b --- /dev/null +++ b/src/spider/tdl/code_gen/python/PyGenerator.cpp @@ -0,0 +1,573 @@ +#include "PyGenerator.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace spider::tdl::code_gen::python { +namespace { +// The visitor pattern implementation uses recursion. Iterative implementation would take too long +// to implement. We will defer this to future PRs. +// NOLINTBEGIN(misc-no-recursion) + +/** + * Visitor for traversing the translation unit AST and generating Python code. + */ +class Visitor { +public: + // Constructor + explicit Visitor( + std::shared_ptr struct_spec_dependency_graph, + std::ostream& out_stream + ) + : m_struct_spec_dependency_graph{std::move(struct_spec_dependency_graph)}, + m_out_stream{&out_stream} {} + + // Delete copy & move constructor and copy assignment operator + Visitor(Visitor const&) = delete; + Visitor(Visitor&&) noexcept = delete; + auto operator=(Visitor const&) -> Visitor& = delete; + auto operator=(Visitor&&) -> Visitor& = delete; + + // Destructor + ~Visitor() = default; + + // Methods + /** + * Visits a translation unit node and generates code for it. + * @param tu The translation unit node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return An `Error` instance if: + * - The struct specs have circular dependencies. + * - The struct specs' dependency graph is invalid. + * @return Forwards `visit_struct_spec`'s return values on failure. + * @return Forwards `visit_child`'s return values on failure. + */ + [[nodiscard]] auto visit_translation_unit(parser::ast::TranslationUnit const* tu) + -> boost::outcome_v2::std_checked; + +private: + // Methods + /** + * Visits a named variable node and generates code for it. + * @param named_var The named variable node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return Forwards `visit_named_var`'s return values on failure. + */ + [[nodiscard]] auto visit_struct_spec(parser::ast::StructSpec const* struct_spec) + -> boost::outcome_v2::std_checked; + + /** + * Visits a node by calling the appropriate visit function based on the node type. + * @param node The node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return An `Error` instance if: + * - The node is a nullptr. + * - The node is a `TranslationUnit` or a `StructSpec` (these nodes should not show up as child + * nodes). + * - The node's type is unrecognized. + * @return Forwards `visit_namespace`'s return values on failure. + * @return Forwards `visit_function`'s return values on failure. + * @return Forwards `visit_named_var`'s return values on failure. + * @return Forwards `visit_identifier`'s return values on failure. + * @return Forwards `visit_type`'s return values on failure. + */ + [[nodiscard]] auto visit_node(parser::ast::Node const* node) + -> boost::outcome_v2::std_checked; + + /** + * Visits a namespace node and generates code for it. + * @param ns The namespace node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return Forwards `visit_children`'s return values on failure. + */ + [[nodiscard]] auto visit_namespace(parser::ast::Namespace const* ns) + -> boost::outcome_v2::std_checked; + + /** + * Visits a function node and generates code for it. + * @param func The function node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return Forwards `visit_named_var`'s return values on failure. + * @return Forwards `visit_type`'s return values on failure. + */ + [[nodiscard]] auto visit_function(parser::ast::Function const* func) + -> boost::outcome_v2::std_checked; + + /** + * Visits a named variable node and generates code for it. + * @param named_var The named variable node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return Forwards `visit_identifier`'s return values on failure. + * @return Forwards `visit_type`'s return values on failure. + */ + [[nodiscard]] auto visit_named_var(parser::ast::NamedVar const* named_var) + -> boost::outcome_v2::std_checked; + + /** + * Visits an identifier node and generates code for it. + * @param identifier The identifier node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. This function always succeeds. + */ + [[nodiscard]] auto visit_identifier(parser::ast::Identifier const* identifier) + -> boost::outcome_v2::std_checked { + *m_out_stream << identifier->get_name(); + return ystdlib::error_handling::success(); + } + + /** + * Visits a type node and generates code for it. + * @param type The type node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return An `Error` instance if the type is unknown. + * @return Forwards `visit_primitive_type`'s return values on failure. + * @return Forwards `visit_struct_type`'s return values on failure. + * @return Forwards `visit_list_type`'s return values on failure. + * @return Forwards `visit_map_type`'s return values on failure. + * @return Forwards `visit_tuple_type`'s return values on failure. + */ + [[nodiscard]] auto visit_type(parser::ast::Type const* type) + -> boost::outcome_v2::std_checked; + + /** + * Visits a primitive type node and generates code for it. + * @param primitive_type The primitive type node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return An `Error` instance if: + * - The node is an `Int` but its spec is unsupported. + * - The node is an `Float` but its spec is unsupported. + * - The node's type is unrecognized. + */ + [[nodiscard]] auto visit_primitive_type(parser::ast::Primitive const* primitive_type) + -> boost::outcome_v2::std_checked; + + /** + * Visits a struct type node and generates code for it. + * @param struct_type The struct type node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. This function always succeeds. + */ + [[nodiscard]] auto visit_struct_type(parser::ast::Struct const* struct_type) + -> boost::outcome_v2::std_checked; + + /** + * Visits a list type node and generates code for it. + * @param list_type The list type node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return Forwards `visit_type`'s return values on failure. + */ + [[nodiscard]] auto visit_list_type(parser::ast::List const* list_type) + -> boost::outcome_v2::std_checked; + + /** + * visits a map type node and generates code for it. + * @param map_type the map type node to visit. + * @param out_stream the output stream to write the generated code. + * @return A void result on success. + * @return Forwards `visit_type`'s return values on failure. + */ + [[nodiscard]] auto visit_map_type(parser::ast::Map const* map_type) + -> boost::outcome_v2::std_checked; + + /** + * Visits a tuple type node and generates code for it. + * @param tuple_type The tuple type node to visit. + * @param out_stream The output stream to write the generated code. + * @return A void result on success. + * @return Forwards `visit_type`'s return values on failure. + */ + [[nodiscard]] auto visit_tuple_type(parser::ast::Tuple const* tuple_type) + -> boost::outcome_v2::std_checked; + + /** + * Visits all children of a node by calling `visit_node`. + * @param node The node whose children to visit. + * @param start_from The index of the child to start visiting from. Default is 0. + * @return A void result on success. + * @return Forwards `visit_node`'s return values on failure. + */ + [[nodiscard]] auto visit_children(parser::ast::Node const* node, size_t start_from = 0) + -> boost::outcome_v2::std_checked { + for (size_t child_id{start_from}; child_id < node->get_num_children(); ++child_id) { + YSTDLIB_ERROR_HANDLING_TRYV(visit_node(node->get_child(child_id).value())); + } + return boost::outcome_v2::success(); + } + + auto increase_indent() noexcept -> void { ++m_indent_level; } + + auto decrease_indent() noexcept -> void { + if (m_indent_level > 0) { + --m_indent_level; + } + } + + auto reset_indent() noexcept -> void { m_indent_level = 0; } + + auto generate_indentation() noexcept -> void { + for (size_t i{0}; i < m_indent_level; ++i) { + *m_out_stream << " "; + } + } + + auto generate_newline() noexcept -> void { *m_out_stream << "\n"; } + + // Variables + std::shared_ptr m_struct_spec_dependency_graph; + std::ostream* m_out_stream; + size_t m_indent_level{0}; +}; + +auto Visitor::visit_translation_unit(parser::ast::TranslationUnit const* tu) + -> boost::outcome_v2::std_checked { + reset_indent(); + *m_out_stream << "# Auto-generated Python code from TDL"; + generate_newline(); + generate_newline(); + + *m_out_stream << "from dataclasses import dataclass"; + generate_newline(); + + *m_out_stream << "import spider_py"; + generate_newline(); + + generate_newline(); + generate_newline(); + + auto const optional_topological_ordering{ + m_struct_spec_dependency_graph->get_struct_specs_in_topological_ordering() + }; + if (false == optional_topological_ordering.has_value()) { + return Error{ + "Cannot generate Python code for TDL files with cyclic struct definitions.", + tu->get_source_location() + }; + } + for (auto const struct_spec_id : *optional_topological_ordering) { + auto const struct_spec{ + m_struct_spec_dependency_graph->get_struct_spec_from_id(struct_spec_id) + }; + if (nullptr == struct_spec) { + return Error{ + "Internal error: `StructSpec` ID does not map to a valid `StructSpec`.", + tu->get_source_location() + }; + } + YSTDLIB_ERROR_HANDLING_TRYV(visit_struct_spec(struct_spec.get())); + generate_newline(); + } + + return visit_children(tu); +} + +auto Visitor::visit_struct_spec(parser::ast::StructSpec const* struct_spec) + -> boost::outcome_v2::std_checked { + generate_indentation(); + *m_out_stream << "@dataclass"; + generate_newline(); + *m_out_stream << "class " << struct_spec->get_name() << ":"; + generate_newline(); + + increase_indent(); + std::vector fields; + fields.reserve(struct_spec->get_num_fields()); + std::ignore = struct_spec->visit_fields( + [&](parser::ast::NamedVar const& field) -> ystdlib::error_handling::Result { + fields.emplace_back(&field); + return ystdlib::error_handling::success(); + } + ); + for (auto const& field : fields) { + generate_indentation(); + YSTDLIB_ERROR_HANDLING_TRYV(visit_named_var(field)); + generate_newline(); + } + decrease_indent(); + + generate_newline(); + return ystdlib::error_handling::success(); +} + +auto Visitor::visit_node(parser::ast::Node const* node) + -> boost::outcome_v2::std_checked { + if (nullptr == node) { + return Error{"Internal error: NULL AST node encountered.", parser::SourceLocation{0, 0}}; + } + + if (auto const* tu{dynamic_cast(node)}; nullptr != tu) { + return Error{ + "Internal error: Unexpected `TranslationUnit` node.", + tu->get_source_location() + }; + } + + if (auto const* struct_spec{dynamic_cast(node)}; + nullptr != struct_spec) + { + return Error{ + "Internal error: Unexpected `StructSpec` node.", + struct_spec->get_source_location() + }; + } + + if (auto const* ns{dynamic_cast(node)}; nullptr != ns) { + return visit_namespace(ns); + } + + if (auto const* func{dynamic_cast(node)}; nullptr != func) { + return visit_function(func); + } + + if (auto const* named_var{dynamic_cast(node)}; + nullptr != named_var) + { + return visit_named_var(named_var); + } + + if (auto const* identifier{dynamic_cast(node)}; + nullptr != identifier) + { + return visit_identifier(identifier); + } + + if (auto const* type{dynamic_cast(node)}; nullptr != type) { + return visit_type(type); + } + + return Error{"Internal error: Unknown AST node type.", node->get_source_location()}; +} + +auto Visitor::visit_namespace(parser::ast::Namespace const* ns) + -> boost::outcome_v2::std_checked { + generate_indentation(); + *m_out_stream << "class " << ns->get_name() << ":"; + generate_newline(); + + increase_indent(); + auto const&& result{visit_children(ns, 1)}; + decrease_indent(); + + generate_newline(); + return result; +} + +auto Visitor::visit_function(parser::ast::Function const* func) + -> boost::outcome_v2::std_checked { + generate_indentation(); + *m_out_stream << "@staticmethod"; + generate_newline(); + + generate_indentation(); + *m_out_stream << "def " << func->get_name() << "("; + + // Params + if (0 != func->get_num_params()) { + generate_newline(); + increase_indent(); + std::vector params; + params.reserve(func->get_num_params()); + std::ignore = func->visit_params( + [&](parser::ast::NamedVar const& param) -> ystdlib::error_handling::Result { + params.emplace_back(¶m); + return ystdlib::error_handling::success(); + } + ); + for (auto const& param : params) { + generate_indentation(); + YSTDLIB_ERROR_HANDLING_TRYV(visit_named_var(param)); + *m_out_stream << ","; + generate_newline(); + } + decrease_indent(); + generate_indentation(); + } + + // Return + *m_out_stream << ")"; + if (func->has_return()) { + *m_out_stream << " -> "; + YSTDLIB_ERROR_HANDLING_TRYV(visit_type(func->get_return_type())); + } + *m_out_stream << ":"; + generate_newline(); + + // Body + increase_indent(); + generate_indentation(); + *m_out_stream << "pass"; + generate_newline(); + decrease_indent(); + + generate_newline(); + return ystdlib::error_handling::success(); +} + +auto Visitor::visit_named_var(parser::ast::NamedVar const* named_var) + -> boost::outcome_v2::std_checked { + YSTDLIB_ERROR_HANDLING_TRYV(visit_identifier(named_var->get_id())); + *m_out_stream << ": "; + YSTDLIB_ERROR_HANDLING_TRYV(visit_type(named_var->get_type())); + return ystdlib::error_handling::success(); +} + +auto Visitor::visit_type(parser::ast::Type const* type) + -> boost::outcome_v2::std_checked { + if (auto const* primitive{dynamic_cast(type)}; + nullptr != primitive) + { + return visit_primitive_type(primitive); + } + + if (auto const* struct_type{dynamic_cast(type)}; + nullptr != struct_type) + { + return visit_struct_type(struct_type); + } + + if (auto const* list_type{dynamic_cast(type)}; nullptr != list_type) { + return visit_list_type(list_type); + } + + if (auto const* map_type{dynamic_cast(type)}; nullptr != map_type) { + return visit_map_type(map_type); + } + + if (auto const* tuple_type{dynamic_cast(type)}; + nullptr != tuple_type) + { + return visit_tuple_type(tuple_type); + } + + return Error{"Unknown `Type` node type.", type->get_source_location()}; +} + +auto Visitor::visit_primitive_type(parser::ast::Primitive const* primitive_type) + -> boost::outcome_v2::std_checked { + if (auto const* int_type{dynamic_cast(primitive_type)}; + nullptr != int_type) + { + switch (int_type->get_spec()) { + case parser::ast::IntSpec::Int8: + *m_out_stream << "spider_py.Int8"; + break; + case parser::ast::IntSpec::Int16: + *m_out_stream << "spider_py.Int16"; + break; + case parser::ast::IntSpec::Int32: + *m_out_stream << "spider_py.Int32"; + break; + case parser::ast::IntSpec::Int64: + *m_out_stream << "spider_py.Int64"; + break; + default: + return Error{"Unsupported integer type.", int_type->get_source_location()}; + } + return boost::outcome_v2::success(); + } + + if (auto const* float_type{dynamic_cast(primitive_type)}; + nullptr != float_type) + { + switch (float_type->get_spec()) { + case parser::ast::FloatSpec::Float: + *m_out_stream << "spider_py.Float"; + break; + case parser::ast::FloatSpec::Double: + *m_out_stream << "spider_py.Double"; + break; + default: + return Error{"Unsupported float type.", float_type->get_source_location()}; + } + return boost::outcome_v2::success(); + } + + if (auto const* bool_type{dynamic_cast(primitive_type)}; + nullptr != bool_type) + { + *m_out_stream << "bool"; + return boost::outcome_v2::success(); + } + + return Error{"Unknown `Primitive` type.", primitive_type->get_source_location()}; +} + +auto Visitor::visit_struct_type(parser::ast::Struct const* struct_type) + -> boost::outcome_v2::std_checked { + *m_out_stream << struct_type->get_name(); + return boost::outcome_v2::success(); +} + +auto Visitor::visit_list_type(parser::ast::List const* list_type) + -> boost::outcome_v2::std_checked { + *m_out_stream << "list["; + YSTDLIB_ERROR_HANDLING_TRYV(visit_type(list_type->get_element_type())); + *m_out_stream << "]"; + return boost::outcome_v2::success(); +} + +auto Visitor::visit_map_type(parser::ast::Map const* map_type) + -> boost::outcome_v2::std_checked { + *m_out_stream << "dict["; + YSTDLIB_ERROR_HANDLING_TRYV(visit_type(map_type->get_key_type())); + *m_out_stream << ", "; + YSTDLIB_ERROR_HANDLING_TRYV(visit_type(map_type->get_value_type())); + *m_out_stream << "]"; + return boost::outcome_v2::success(); +} + +auto Visitor::visit_tuple_type(parser::ast::Tuple const* tuple_type) + -> boost::outcome_v2::std_checked { + *m_out_stream << "("; + std::vector element_types; + element_types.reserve(tuple_type->get_num_children()); + std::ignore = tuple_type->visit_children( + [&](parser::ast::Node const& child) -> ystdlib::error_handling::Result { + // The factory function ensures that all children are of type `Type`. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + element_types.emplace_back(static_cast(&child)); + return ystdlib::error_handling::success(); + } + ); + for (size_t i{0}; i < element_types.size(); ++i) { + YSTDLIB_ERROR_HANDLING_TRYV(visit_type(element_types[i])); + if (i + 1 < element_types.size()) { + *m_out_stream << ", "; + } + } + *m_out_stream << ")"; + return boost::outcome_v2::success(); +} + +// NOLINTEND(misc-no-recursion) +} // namespace + +auto PyGenerator::generate(std::ostream& out_stream) + -> boost::outcome_v2::std_checked { + return Visitor{get_struct_spec_dependency_graph(), out_stream}.visit_translation_unit( + get_translation_unit() + ); +} +} // namespace spider::tdl::code_gen::python diff --git a/src/spider/tdl/code_gen/python/PyGenerator.hpp b/src/spider/tdl/code_gen/python/PyGenerator.hpp new file mode 100644 index 00000000..51233673 --- /dev/null +++ b/src/spider/tdl/code_gen/python/PyGenerator.hpp @@ -0,0 +1,31 @@ +#ifndef SPIDER_TDL_CODE_GEN_PY_GENERATOR_HPP +#define SPIDER_TDL_CODE_GEN_PY_GENERATOR_HPP + +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace spider::tdl::code_gen::python { +class PyGenerator : public Generator { +public: + // Constructor + PyGenerator( + std::unique_ptr translation_unit, + std::shared_ptr dependency_graph + ) + : Generator{std::move(translation_unit), std::move(dependency_graph)} {} + + // Methods implementing `Generator`. + [[nodiscard]] auto generate(std::ostream& out_stream) + -> boost::outcome_v2::std_checked override; +}; +} // namespace spider::tdl::code_gen::python + +#endif // SPIDER_TDL_CODE_GEN_PY_GENERATOR_HPP diff --git a/src/spider/tdl/parser/SourceLocation.hpp b/src/spider/tdl/parser/SourceLocation.hpp index 7307998c..f5f50b70 100644 --- a/src/spider/tdl/parser/SourceLocation.hpp +++ b/src/spider/tdl/parser/SourceLocation.hpp @@ -1,6 +1,7 @@ #ifndef SPIDER_TDL_PARSER_SOURCELOCATION_HPP #define SPIDER_TDL_PARSER_SOURCELOCATION_HPP +#include #include #include @@ -25,6 +26,16 @@ class SourceLocation { return m_line == other.m_line && m_column == other.m_column; } + [[nodiscard]] auto operator<=>(SourceLocation const& other) const noexcept + -> std::strong_ordering { + if (auto const lint_comparison{m_line <=> other.m_line}; + std::strong_ordering::equal != lint_comparison) + { + return lint_comparison; + } + return m_column <=> other.m_column; + } + private: // Variables size_t m_line; diff --git a/src/spider/tdl/parser/ast/Node.cpp b/src/spider/tdl/parser/ast/Node.cpp index 4dfb40c2..11047c52 100644 --- a/src/spider/tdl/parser/ast/Node.cpp +++ b/src/spider/tdl/parser/ast/Node.cpp @@ -9,15 +9,15 @@ #include using spider::tdl::parser::ast::Node; -using NodeErrorCodeCategory = ystdlib::error_handling::ErrorCategory; +using PyGeneratorErrorCodeCategory = ystdlib::error_handling::ErrorCategory; template <> -auto NodeErrorCodeCategory::name() const noexcept -> char const* { +auto PyGeneratorErrorCodeCategory::name() const noexcept -> char const* { return "spider::tdl::parser::ast::Node"; } template <> -auto NodeErrorCodeCategory ::message(Node::ErrorCodeEnum error_enum) const -> std::string { +auto PyGeneratorErrorCodeCategory ::message(Node::ErrorCodeEnum error_enum) const -> std::string { switch (error_enum) { case Node::ErrorCodeEnum::ChildIndexOutOfBounds: return "The child index is out of bounds."; diff --git a/src/spider/tdl/parser/ast/node_impl/TranslationUnit.hpp b/src/spider/tdl/parser/ast/node_impl/TranslationUnit.hpp index a556c647..9611649a 100644 --- a/src/spider/tdl/parser/ast/node_impl/TranslationUnit.hpp +++ b/src/spider/tdl/parser/ast/node_impl/TranslationUnit.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -47,6 +48,28 @@ class TranslationUnit : public Node { -> ystdlib::error_handling::Result override; // Methods + /** + * Visits all struct specs in the struct spec table in an unspecified order, invoking the given + * `visitor` for each struct spec. + * @tparam StructSpecVisitor + * @param visitor + * @return A void result on success, or an error code indicating the failure: + * - Forwards `visitor`'s return values. + */ + template + requires(std::is_invocable_r_v< + ystdlib::error_handling::Result, + StructSpecVisitor, + StructSpec const* + >) + [[nodiscard]] auto visit_struct_specs(StructSpecVisitor visitor) const + -> ystdlib::error_handling::Result { + for (auto const& [_, struct_spec] : m_struct_spec_table) { + YSTDLIB_ERROR_HANDLING_TRYV(visitor(struct_spec.get())); + } + return ystdlib::error_handling::success(); + } + /** * @param name * @return A shared pointer to the `StructSpec` with the given name if it exists in the struct @@ -82,12 +105,12 @@ class TranslationUnit : public Node { -> ystdlib::error_handling::Result; /** - * @return A newly constructed dependency graph of struct specs defined in this translation - * unit. + * @return A shared pointer pointing to a newly constructed dependency graph of struct specs + * defined in this translation unit. */ [[nodiscard]] auto create_struct_spec_dependency_graph() const - -> pass::analysis::StructSpecDependencyGraph { - return pass::analysis::StructSpecDependencyGraph{m_struct_spec_table}; + -> std::shared_ptr { + return std::make_shared(m_struct_spec_table); } private: diff --git a/src/spider/tdl/pass/Pass.hpp b/src/spider/tdl/pass/Pass.hpp new file mode 100644 index 00000000..70b55f8f --- /dev/null +++ b/src/spider/tdl/pass/Pass.hpp @@ -0,0 +1,63 @@ +#ifndef SPIDER_TDL_PASS_PASS_HPP +#define SPIDER_TDL_PASS_PASS_HPP + +#include +#include + +#include + +namespace spider::tdl::pass { +/** + * Represents an abstract pass over a TDL AST. + */ +class Pass { +public: + // Types + /** + * Represents an abstract error that can occur during the execution of a pass. + */ + class Error { + public: + // Constructor + Error() = default; + + // Delete copy constructor and assignment operator + Error(Error const&) = delete; + auto operator=(Error const&) -> Error& = delete; + + // Default move constructor and assignment operator + Error(Error&&) = default; + auto operator=(Error&&) -> Error& = default; + + // Destructor + virtual ~Error() = default; + + // Methods + [[nodiscard]] virtual auto to_string() const -> std::string = 0; + }; + + // Constructors + Pass() = default; + + // Delete copy constructor and assignment operator + Pass(Pass const&) = delete; + auto operator=(Pass const&) -> Pass& = delete; + + // Default move constructor and assignment operator + Pass(Pass&&) = default; + auto operator=(Pass&&) -> Pass& = default; + + // Destructor + virtual ~Pass() = default; + + // Methods + /** + * Executes the pass. + * @return A void result on success, or a pointer to the error on failure. + */ + [[nodiscard]] virtual auto run() -> boost::outcome_v2::std_checked> + = 0; +}; +} // namespace spider::tdl::pass + +#endif // SPIDER_TDL_PASS_PASS_HPP diff --git a/src/spider/tdl/pass/analysis/DetectStructCircularDependency.cpp b/src/spider/tdl/pass/analysis/DetectStructCircularDependency.cpp new file mode 100644 index 00000000..9afe0479 --- /dev/null +++ b/src/spider/tdl/pass/analysis/DetectStructCircularDependency.cpp @@ -0,0 +1,85 @@ +#include "DetectStructCircularDependency.hpp" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace spider::tdl::pass::analysis { +auto DetectStructCircularDependency::Error::to_string() const -> std::string { + std::vector circular_dependency_group_error_messages; + circular_dependency_group_error_messages.reserve(m_strongly_connected_components.size()); + for (auto const& group : m_strongly_connected_components) { + std::vector struct_descriptions; + struct_descriptions.reserve(group.size()); + for (auto const& struct_spec : group) { + struct_descriptions.emplace_back( + fmt::format( + " `{}` at {}", + struct_spec->get_name(), + struct_spec->get_source_location().serialize_to_str() + ) + ); + } + circular_dependency_group_error_messages.emplace_back( + fmt::format( + "Found a circular dependency group of {} struct spec(s):\n{}", + group.size(), + fmt::join(struct_descriptions, "\n") + ) + ); + } + return fmt::format( + "Found {} circular dependency group(s):\n{}", + m_strongly_connected_components.size(), + fmt::join(circular_dependency_group_error_messages, "\n") + ); +} + +auto DetectStructCircularDependency::run() + -> boost::outcome_v2::std_checked> { + auto const& strongly_connected_components{ + m_struct_spec_dependency_graph->get_strongly_connected_components() + }; + if (strongly_connected_components.empty()) { + return boost::outcome_v2::success(); + } + + std::vector>> + circular_dependency_groups; + circular_dependency_groups.reserve(strongly_connected_components.size()); + for (auto const& scc : strongly_connected_components) { + std::vector> group; + group.reserve(scc.size()); + for (auto const id : scc) { + group.emplace_back(m_struct_spec_dependency_graph->get_struct_spec_from_id(id)); + } + std::ranges::sort(group, [](auto const& lhs, auto const& rhs) -> bool { + return lhs->get_source_location() < rhs->get_source_location(); + }); + circular_dependency_groups.emplace_back(std::move(group)); + } + + std::ranges::sort(circular_dependency_groups, [](auto const& lhs, auto const& rhs) -> bool { + // Compare by the source location of the first struct spec in each group. This is safe + // because: + // - Each group is guaranteed to be non-empty. + // - Each struct spec should only appear in one SCC, which guarantees the source locations + // are unique. + return lhs.front()->get_source_location() < rhs.front()->get_source_location(); + }); + + return boost::outcome_v2::failure( + std::make_unique(std::move(circular_dependency_groups)) + ); +} +} // namespace spider::tdl::pass::analysis diff --git a/src/spider/tdl/pass/analysis/DetectStructCircularDependency.hpp b/src/spider/tdl/pass/analysis/DetectStructCircularDependency.hpp new file mode 100644 index 00000000..7ffeb775 --- /dev/null +++ b/src/spider/tdl/pass/analysis/DetectStructCircularDependency.hpp @@ -0,0 +1,63 @@ +#ifndef SPIDER_TDL_PASS_ANALYSIS_DETECTSTRUCTCIRCULARDEPENDENCY_HPP +#define SPIDER_TDL_PASS_ANALYSIS_DETECTSTRUCTCIRCULARDEPENDENCY_HPP + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace spider::tdl::pass::analysis { +/** + * Wrapper of `StructSpecDependencyGraph` to detect circular dependencies among struct specs. + */ +class DetectStructCircularDependency : public Pass { +public: + // Types + /** + * Represents an error including all circular dependency groups (reported as strongly connected + * components). + */ + class Error : public Pass::Error { + public: + // Constructor + explicit Error( + std::vector>> + strongly_connected_components + ) + : m_strongly_connected_components{std::move(strongly_connected_components)} {} + + // Methods implementing `Pass::Error` + [[nodiscard]] auto to_string() const -> std::string override; + + private: + // Variables + std::vector>> + m_strongly_connected_components; + }; + + // Constructor + explicit DetectStructCircularDependency( + std::shared_ptr struct_spec_dependency_graph + ) + : m_struct_spec_dependency_graph{std::move(struct_spec_dependency_graph)} {} + + // Methods implementing `Pass` + /** + * @return A void result on success, or a pointer to `DetectStructCircularDependency::Error` + * on failure. + */ + [[nodiscard]] auto run() + -> boost::outcome_v2::std_checked> override; + +private: + std::shared_ptr m_struct_spec_dependency_graph; +}; +} // namespace spider::tdl::pass::analysis + +#endif // SPIDER_TDL_PASS_ANALYSIS_DETECTSTRUCTCIRCULARDEPENDENCY_HPP diff --git a/src/spider/tdl/pass/analysis/DetectUndefinedStruct.cpp b/src/spider/tdl/pass/analysis/DetectUndefinedStruct.cpp new file mode 100644 index 00000000..9d0da850 --- /dev/null +++ b/src/spider/tdl/pass/analysis/DetectUndefinedStruct.cpp @@ -0,0 +1,74 @@ +#include "DetectUndefinedStruct.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace spider::tdl::pass::analysis { +auto DetectUndefinedStruct::Error::to_string() const -> std::string { + std::vector undefined_struct_error_messages; + undefined_struct_error_messages.reserve(m_undefined_struct.size()); + for (auto const* undefined_struct : m_undefined_struct) { + undefined_struct_error_messages.emplace_back( + fmt::format( + "Referencing to an undefined struct `{}` at {}", + undefined_struct->get_name(), + undefined_struct->get_source_location().serialize_to_str() + ) + ); + } + return fmt::format( + "Found {} undefined struct reference(s):\n{}", + m_undefined_struct.size(), + fmt::join(undefined_struct_error_messages, "\n") + ); +} + +auto DetectUndefinedStruct::run() + -> boost::outcome_v2::std_checked> { + std::vector undefined_structs; + + auto struct_visitor + = [&](parser::ast::Struct const* struct_node) -> ystdlib::error_handling::Result { + if (nullptr == m_translation_unit->get_struct_spec(struct_node->get_name())) { + undefined_structs.emplace_back(struct_node); + } + return ystdlib::error_handling::success(); + }; + + std::ignore = visit_struct_node_using_dfs(m_translation_unit, struct_visitor); + std::ignore = m_translation_unit->visit_struct_specs( + [&]( + parser::ast::StructSpec const* struct_spec + ) -> ystdlib::error_handling::Result { + return visit_struct_node_using_dfs(struct_spec, struct_visitor); + } + ); + + if (undefined_structs.empty()) { + return boost::outcome_v2::success(); + } + std::ranges::sort( + undefined_structs, + [](parser::ast::Struct const* lhs, parser::ast::Struct const* rhs) -> bool { + return lhs->get_source_location() < rhs->get_source_location(); + } + ); + return boost::outcome_v2::failure( + std::make_unique(std::move(undefined_structs)) + ); +} +} // namespace spider::tdl::pass::analysis diff --git a/src/spider/tdl/pass/analysis/DetectUndefinedStruct.hpp b/src/spider/tdl/pass/analysis/DetectUndefinedStruct.hpp new file mode 100644 index 00000000..f1d6e7e3 --- /dev/null +++ b/src/spider/tdl/pass/analysis/DetectUndefinedStruct.hpp @@ -0,0 +1,57 @@ +#ifndef SPIDER_TDL_PASS_ANALYSIS_DETECTUNDEFINEDSTRUCTREF_HPP +#define SPIDER_TDL_PASS_ANALYSIS_DETECTUNDEFINEDSTRUCTREF_HPP + +#include +#include +#include +#include + +#include + +#include +#include + +namespace spider::tdl::pass::analysis { +/** + * A pass that detects undefined struct references in a TDL translation unit. + * NOTE: The translation unit must outlive the pass instance. + */ +class DetectUndefinedStruct : public Pass { +public: + // Types + /** + * Represents an error including all undefined structs. + * NOTE: The translation unit must outlive the error object. + */ + class Error : public Pass::Error { + public: + // Constructor + explicit Error(std::vector undefined_struct) + : m_undefined_struct{std::move(undefined_struct)} {} + + // Methods implementing `Pass::Error` + [[nodiscard]] auto to_string() const -> std::string override; + + private: + // Variables + std::vector m_undefined_struct; + }; + + // Constructor + explicit DetectUndefinedStruct(parser::ast::TranslationUnit const* translation_unit) + : m_translation_unit{translation_unit} {} + + // Methods implementing `Pass` + /** + * @return A void result on success, or a pointer to `DetectUndefinedStruct::Error` on failure. + */ + [[nodiscard]] auto run() + -> boost::outcome_v2::std_checked> override; + +private: + // Variables + parser::ast::TranslationUnit const* m_translation_unit; +}; +} // namespace spider::tdl::pass::analysis + +#endif // SPIDER_TDL_PASS_ANALYSIS_DETECTUNDEFINEDSTRUCTREF_HPP diff --git a/src/spider/tdl/pass/analysis/StructSpecDependencyGraph.cpp b/src/spider/tdl/pass/analysis/StructSpecDependencyGraph.cpp index efad5908..dd3eccd1 100644 --- a/src/spider/tdl/pass/analysis/StructSpecDependencyGraph.cpp +++ b/src/spider/tdl/pass/analysis/StructSpecDependencyGraph.cpp @@ -16,6 +16,7 @@ #include #include +#include namespace spider::tdl::pass::analysis { namespace { @@ -277,40 +278,22 @@ auto collect_use_ids( absl::flat_hash_map const& struct_spec_ids ) -> std::vector { absl::flat_hash_set use_ids; - std::vector ast_dfs_stack; - ast_dfs_stack.emplace_back(def); - while (false == ast_dfs_stack.empty()) { - auto const* node{ast_dfs_stack.back()}; - ast_dfs_stack.pop_back(); - - if (node == nullptr) { - // NOTE: This check is required by clang-tidy. In practice, this should never happen. - continue; - } - - auto const* node_as_struct{dynamic_cast(node)}; - if (nullptr == node_as_struct) { - // Not a struct node, continue DFS by pushing all the child nodes to the stack. - std::ignore = node->visit_children( - [&](parser::ast::Node const& child) -> ystdlib::error_handling::Result { - ast_dfs_stack.emplace_back(&child); - return ystdlib::error_handling::success(); - } - ); - continue; - } - auto const struct_name{node_as_struct->get_name()}; - auto const it{struct_specs.find(struct_name)}; - if (struct_specs.cend() == it) { - // This is a dangling reference, which will be caught in other analysis pass. In this - // dependency graph, we just ignore it. - continue; - } - - use_ids.emplace(struct_spec_ids.at(it->second.get())); - } + std::ignore = visit_struct_node_using_dfs( + def, + [&](parser::ast::Struct const* struct_node) -> ystdlib::error_handling::Result { + auto const struct_name{struct_node->get_name()}; + auto const it{struct_specs.find(struct_name)}; + if (struct_specs.cend() == it) { + // This is a dangling reference, which will be caught in other analysis pass. In + // this dependency graph, we just ignore it. + return ystdlib::error_handling::success(); + } + use_ids.emplace(struct_spec_ids.at(it->second.get())); + return ystdlib::error_handling::success(); + } + ); return std::vector{use_ids.cbegin(), use_ids.cend()}; } } // namespace diff --git a/src/spider/tdl/pass/utils.hpp b/src/spider/tdl/pass/utils.hpp new file mode 100644 index 00000000..a0f3f4df --- /dev/null +++ b/src/spider/tdl/pass/utils.hpp @@ -0,0 +1,66 @@ +#ifndef SPIDER_TDL_PASS_UTILS_HPP +#define SPIDER_TDL_PASS_UTILS_HPP + +#include +#include +#include + +#include + +#include + +namespace spider::tdl::pass { +/** + * Visits all `Struct` nodes in the AST rooted at `root` in a depth-first manner, invoking the given + * `visitor` for each `Struct` node encountered. + * @tparam StructVisitor + * @param root The root to start traversal from. + * @param visitor + * @return A void result on success, or an error code indicating the failure: + * - Forwards `visitor`'s return values. + */ +template +requires std::is_invocable_r_v< + ystdlib::error_handling::Result, + StructVisitor, + parser::ast::Struct const* +> +[[nodiscard]] auto visit_struct_node_using_dfs(parser::ast::Node const* root, StructVisitor visitor) + -> ystdlib::error_handling::Result; + +template +requires std::is_invocable_r_v< + ystdlib::error_handling::Result, + StructVisitor, + parser::ast::Struct const* +> +auto visit_struct_node_using_dfs(parser::ast::Node const* root, StructVisitor visitor) + -> ystdlib::error_handling::Result { + std::vector ast_dfs_stack{root}; + while (false == ast_dfs_stack.empty()) { + auto const* node{ast_dfs_stack.back()}; + ast_dfs_stack.pop_back(); + + if (node == nullptr) { + // NOTE: This check is required by clang-tidy. In practice, this should never happen. + continue; + } + + auto const* node_as_struct{dynamic_cast(node)}; + if (nullptr != node_as_struct) { + YSTDLIB_ERROR_HANDLING_TRYV(visitor(node_as_struct)); + continue; + } + + std::ignore = node->visit_children( + [&](parser::ast::Node const& child) -> ystdlib::error_handling::Result { + ast_dfs_stack.emplace_back(&child); + return ystdlib::error_handling::success(); + } + ); + } + return ystdlib::error_handling::success(); +} +} // namespace spider::tdl::pass + +#endif // SPIDER_TDL_PASS_UTILS_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 44a9e266..1994ccf1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,8 +2,10 @@ set(SPIDER_TEST_SOURCES storage/test-DataStorage.cpp storage/test-MetadataStorage.cpp storage/StorageTestHelper.hpp + tdl/test-codegen-py.cpp tdl/test-parser.cpp tdl/test-parser-ast.cpp + tdl/test-pass-analysis-DetectUndefinedStruct.cpp tdl/test-pass-analysis-StructSpecDependencyGraph.cpp utils/CoreDataUtils.hpp utils/CoreTaskUtils.hpp diff --git a/tests/integration/client.py b/tests/integration/client.py index 05a3fa5c..619880d2 100644 --- a/tests/integration/client.py +++ b/tests/integration/client.py @@ -111,7 +111,7 @@ def is_head_task(task_id: uuid.UUID, dependencies: list[tuple[uuid.UUID, uuid.UU return not any(dependency[1] == task_id for dependency in dependencies) -g_storage_url = "jdbc:mariadb://localhost:3306/spider_test?user=root&password=password" +g_storage_url = "jdbc:mariadb://localhost:3306/spider-storage?user=spider&password=password" @pytest.fixture(scope="session") diff --git a/tests/storage/StorageTestHelper.hpp b/tests/storage/StorageTestHelper.hpp index 737f47c2..36d9f3fb 100644 --- a/tests/storage/StorageTestHelper.hpp +++ b/tests/storage/StorageTestHelper.hpp @@ -12,7 +12,7 @@ namespace spider::test { std::string const cMySqlStorageUrl - = "jdbc:mariadb://localhost:3306/spider_test?user=root&password=password"; + = "jdbc:mariadb://localhost:3306/spider-storage?user=spider&password=password"; using StorageFactoryTypeList = std::tuple; diff --git a/tests/tdl/test-codegen-py.cpp b/tests/tdl/test-codegen-py.cpp new file mode 100644 index 00000000..d229872e --- /dev/null +++ b/tests/tdl/test-codegen-py.cpp @@ -0,0 +1,195 @@ +// NOLINTBEGIN(cert-err58-cpp,cppcoreguidelines-avoid-do-while,readability-function-cognitive-complexity,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + +#include +#include +#include + +#include + +#include +#include + +namespace { +using spider::tdl::code_gen::python::PyGenerator; +using spider::tdl::parser::parse_translation_unit_from_istream; + +constexpr std::string_view cTestCase1{R"(// Start of a TDL file. This is line#1. +namespace test1 { + // Function with no parameters and no return type + fn empty_func(); + + // Function with parameters and return type + fn add(a: int32, b: int32) -> int64; + + // Function that returns an empty tuple + fn return_empty_tuple() -> Tuple<>; + + // Function that returns a Tuple of one element, and takes only one parameter + fn return_singleton_tuple(a: int64) -> Tuple; + + // Function that returns a tuple of containers + fn return_tuple_of_containers() -> Tuple, Map, Map>>>; +} + +struct Input { + field_0: int8, + field_1: int16, + field_2: int32, + field_3: int64, + field_4: float, + field_5: double, + field_6: bool, + field_7: List, + field_8: Map, double>, +}; + +struct Output { + // Notice that the field doesn't end with a comma + processed_input: Map +}; + +namespace test2 { + fn process_input(input: Input, task_id: int64) -> Output; + fn process_inputs(inputs: List, task_id: int64) -> Output; +} +)"}; + +constexpr std::string_view cTestCompress{R"( +namespace Compress { + fn compress( + job_id: int64, + task_id: int64, + tag_ids: List, + clp_io_config_json: List, + paths_to_compress_json: List, + clp_metadata_db_connection_config_json: List + ) -> List; +} +)"}; + +constexpr std::string_view cExpectedCompress{R"(# Auto-generated Python code from TDL + +from dataclasses import dataclass +import spider_py + + +class Compress: + @staticmethod + def compress( + job_id: spider_py.Int64, + task_id: spider_py.Int64, + tag_ids: list[spider_py.Int64], + clp_io_config_json: list[spider_py.Int8], + paths_to_compress_json: list[spider_py.Int8], + clp_metadata_db_connection_config_json: list[spider_py.Int8], + ) -> list[spider_py.Int8]: + pass + + +)"}; + +TEST_CASE("Python Codegen `cTestInput1`", "[tdl][codegen][python]") { + std::istringstream input_stream{std::string{cTestCase1}}; + auto parse_result{parse_translation_unit_from_istream(input_stream)}; + REQUIRE_FALSE(parse_result.has_error()); + + std::ostringstream output_stream; + auto struct_spec_dependency_graph{parse_result.value()->create_struct_spec_dependency_graph()}; + PyGenerator code_generator{ + std::move(parse_result.value()), + std::move(struct_spec_dependency_graph) + }; + auto const codegen_result{code_generator.generate(output_stream)}; + REQUIRE_FALSE(codegen_result.has_error()); + constexpr std::string_view cExpectedGeneratedCode{ + "# Auto-generated Python code from TDL\n" + "\n" + "from dataclasses import dataclass\n" + "import spider_py\n" + "\n" + "\n" + "@dataclass\n" + "class Output:\n" + " processed_input: dict[spider_py.Int64, Input]\n" + "\n" + "\n" + "@dataclass\n" + "class Input:\n" + " field_0: spider_py.Int8\n" + " field_1: spider_py.Int16\n" + " field_2: spider_py.Int32\n" + " field_3: spider_py.Int64\n" + " field_4: spider_py.Float\n" + " field_5: spider_py.Double\n" + " field_6: bool\n" + " field_7: list[spider_py.Int8]\n" + " field_8: dict[list[spider_py.Int8], spider_py.Double]\n" + "\n" + "\n" + "class test1:\n" + " @staticmethod\n" + " def empty_func():\n" + " pass\n" + "\n" + " @staticmethod\n" + " def add(\n" + " a: spider_py.Int32,\n" + " b: spider_py.Int32,\n" + " ) -> spider_py.Int64:\n" + " pass\n" + "\n" + " @staticmethod\n" + " def return_empty_tuple() -> ():\n" + " pass\n" + "\n" + " @staticmethod\n" + " def return_singleton_tuple(\n" + " a: spider_py.Int64,\n" + " ) -> (spider_py.Int32):\n" + " pass\n" + "\n" + " @staticmethod\n" + " def return_tuple_of_containers() -> (list[spider_py.Int8], " + "dict[list[spider_py.Int8]," + " dict[spider_py.Int64, list[spider_py.Int8]]]):\n" + " pass\n" + "\n" + "\n" + "class test2:\n" + " @staticmethod\n" + " def process_input(\n" + " input: Input,\n" + " task_id: spider_py.Int64,\n" + " ) -> Output:\n" + " pass\n" + "\n" + " @staticmethod\n" + " def process_inputs(\n" + " inputs: list[Input],\n" + " task_id: spider_py.Int64,\n" + " ) -> Output:\n" + " pass\n" + "\n" + "\n" + }; + REQUIRE(output_stream.str() == cExpectedGeneratedCode); +} + +TEST_CASE("Python Codegen `cTestCompress`", "[tdl][codegen][python]") { + std::istringstream input_stream{std::string{cTestCompress}}; + auto parse_result{parse_translation_unit_from_istream(input_stream)}; + REQUIRE_FALSE(parse_result.has_error()); + + std::ostringstream output_stream; + auto struct_spec_dependency_graph{parse_result.value()->create_struct_spec_dependency_graph()}; + PyGenerator code_generator{ + std::move(parse_result.value()), + std::move(struct_spec_dependency_graph) + }; + auto const codegen_result{code_generator.generate(output_stream)}; + REQUIRE_FALSE(codegen_result.has_error()); + REQUIRE(output_stream.str() == cExpectedCompress); +} +} // namespace + +// NOLINTEND(cert-err58-cpp,cppcoreguidelines-avoid-do-while,readability-function-cognitive-complexity,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) diff --git a/tests/tdl/test-parser.cpp b/tests/tdl/test-parser.cpp index ce9c5b00..b3f810ba 100644 --- a/tests/tdl/test-parser.cpp +++ b/tests/tdl/test-parser.cpp @@ -273,8 +273,8 @@ TEST_CASE("Parsing `cTestInput1`", "[tdl][parser]") { REQUIRE(serialize_result.value() == cExpectedSerializedAst); auto struct_spec_dependency_graph{translation_unit->create_struct_spec_dependency_graph()}; - REQUIRE(struct_spec_dependency_graph.get_num_struct_specs() == 2); - REQUIRE(struct_spec_dependency_graph.get_strongly_connected_components().empty()); + REQUIRE(struct_spec_dependency_graph->get_num_struct_specs() == 2); + REQUIRE(struct_spec_dependency_graph->get_strongly_connected_components().empty()); } TEST_CASE("Parser errors", "[tdl][parser]") { diff --git a/tests/tdl/test-pass-analysis-DetectUndefinedStruct.cpp b/tests/tdl/test-pass-analysis-DetectUndefinedStruct.cpp new file mode 100644 index 00000000..819dcaa0 --- /dev/null +++ b/tests/tdl/test-pass-analysis-DetectUndefinedStruct.cpp @@ -0,0 +1,87 @@ +// NOLINTBEGIN(cert-err58-cpp,cppcoreguidelines-avoid-do-while,readability-function-cognitive-complexity,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + +#include +#include + +#include + +#include +#include + +namespace { +using spider::tdl::parser::parse_translation_unit_from_istream; +using spider::tdl::pass::analysis::DetectUndefinedStruct; + +constexpr std::string_view cTestCase1{R"(// Start of a TDL file. This is line#1. +struct class0 { + use_0: classN, // Undefined struct reference `classN` + use_1: Map>>, +}; + +namespace ns1 { + fn func0(input: class0) -> class1; + + fn func1( + input: class2 // Undefined struct reference `class2` + ) -> class3; // Undefined struct reference `class3` + + fn func2(a: ClassA, b: ClassB) -> int32; // Undefined struct references `ClassA` and `ClassB` +} + +struct class1 { + use_0: class0, + use_1: Map>>, // Undefined struct reference `class2` +}; +)"}; + +constexpr std::string_view cTestCase2{R"(// Start of a TDL file. This is line#1. +struct class0 { + use_1: Map>>, +}; + +namespace ns1 { + fn func0(input: class0) -> class1; +} + +struct class1 { + use_0: class0, +}; +)"}; + +TEST_CASE("DetectUndefinedStruct Case 1", "[tdl][pass][analytics][DetectUndefinedStruct]") { + std::istringstream input_stream{std::string{cTestCase1}}; + auto const parse_result{parse_translation_unit_from_istream(input_stream)}; + REQUIRE_FALSE(parse_result.has_error()); + auto const& translation_unit{parse_result.value()}; + + auto detect_undefined_struct_pass{DetectUndefinedStruct{translation_unit.get()}}; + auto const run_result{detect_undefined_struct_pass.run()}; + REQUIRE(run_result.has_error()); + auto const* error{dynamic_cast(run_result.error().get())}; + REQUIRE(nullptr != error); + constexpr std::string_view cExpectedErrorMessage{ + "Found 7 undefined struct reference(s):\n" + "Referencing to an undefined struct `classN` at (3:11)\n" + "Referencing to an undefined struct `class2` at (4:37)\n" + "Referencing to an undefined struct `class2` at (11:15)\n" + "Referencing to an undefined struct `class3` at (12:9)\n" + "Referencing to an undefined struct `ClassA` at (14:16)\n" + "Referencing to an undefined struct `ClassB` at (14:27)\n" + "Referencing to an undefined struct `class2` at (19:37)" + }; + REQUIRE(cExpectedErrorMessage == error->to_string()); +} + +TEST_CASE("DetectUndefinedStruct Case 2", "[tdl][pass][analytics][DetectUndefinedStruct]") { + std::istringstream input_stream{std::string{cTestCase2}}; + auto const parse_result{parse_translation_unit_from_istream(input_stream)}; + REQUIRE_FALSE(parse_result.has_error()); + auto const& translation_unit{parse_result.value()}; + + auto detect_undefined_struct_pass{DetectUndefinedStruct{translation_unit.get()}}; + auto const run_result{detect_undefined_struct_pass.run()}; + REQUIRE_FALSE(run_result.has_error()); +} +} // namespace + +// NOLINTEND(cert-err58-cpp,cppcoreguidelines-avoid-do-while,readability-function-cognitive-complexity,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) diff --git a/tests/tdl/test-pass-analysis-StructSpecDependencyGraph.cpp b/tests/tdl/test-pass-analysis-StructSpecDependencyGraph.cpp index 3971893e..040cf142 100644 --- a/tests/tdl/test-pass-analysis-StructSpecDependencyGraph.cpp +++ b/tests/tdl/test-pass-analysis-StructSpecDependencyGraph.cpp @@ -12,10 +12,12 @@ #include #include +#include #include namespace { using spider::tdl::parser::parse_translation_unit_from_istream; +using spider::tdl::pass::analysis::DetectStructCircularDependency; using spider::tdl::pass::analysis::StructSpecDependencyGraph; constexpr std::string_view cTestCase1{R"(// Start of a TDL file. This is line#1. @@ -185,17 +187,17 @@ auto serialize_strongly_connected_components(StructSpecDependencyGraph& graph) return serialized_strongly_connected_components; } -TEST_CASE("Case 1", "[tdl][pass][analytics][StructSpecDependencyGraph]") { +TEST_CASE("StructSpecDependencyGraph Case 1", "[tdl][pass][analytics][StructSpecDependencyGraph]") { std::istringstream input_stream{std::string{cTestCase1}}; auto const parse_result{parse_translation_unit_from_istream(input_stream)}; REQUIRE_FALSE(parse_result.has_error()); auto const& translation_unit{parse_result.value()}; auto struct_spec_dependency_graph{translation_unit->create_struct_spec_dependency_graph()}; - REQUIRE(struct_spec_dependency_graph.get_num_struct_specs() == 7); + REQUIRE(struct_spec_dependency_graph->get_num_struct_specs() == 7); auto const serialized_strongly_connected_components{ - serialize_strongly_connected_components(struct_spec_dependency_graph) + serialize_strongly_connected_components(*struct_spec_dependency_graph) }; REQUIRE(serialized_strongly_connected_components.size() == 4); std::set const expected_serialized_strongly_connected_components{ @@ -208,21 +210,44 @@ TEST_CASE("Case 1", "[tdl][pass][analytics][StructSpecDependencyGraph]") { == expected_serialized_strongly_connected_components); REQUIRE_FALSE( - struct_spec_dependency_graph.get_struct_specs_in_topological_ordering().has_value() + struct_spec_dependency_graph->get_struct_specs_in_topological_ordering().has_value() ); + + DetectStructCircularDependency detect_circular_dependency_pass{struct_spec_dependency_graph}; + auto const run_result{detect_circular_dependency_pass.run()}; + REQUIRE(run_result.has_error()); + auto const* error{ + dynamic_cast(run_result.error().get()) + }; + REQUIRE(nullptr != error); + constexpr std::string_view cExpectedErrorMessage{ + "Found 4 circular dependency group(s):\n" + "Found a circular dependency group of 3 struct spec(s):\n" + " `class0` at (20:0)\n" + " `class1` at (25:0)\n" + " `class3` at (34:0)\n" + "Found a circular dependency group of 2 struct spec(s):\n" + " `class2` at (29:0)\n" + " `class4` at (38:0)\n" + "Found a circular dependency group of 1 struct spec(s):\n" + " `class5` at (42:0)\n" + "Found a circular dependency group of 1 struct spec(s):\n" + " `class6` at (46:0)" + }; + REQUIRE(cExpectedErrorMessage == error->to_string()); } -TEST_CASE("Case 2", "[tdl][pass][analytics][StructSpecDependencyGraph]") { +TEST_CASE("StructSpecDependencyGraphCase 2", "[tdl][pass][analytics][StructSpecDependencyGraph]") { std::istringstream input_stream{std::string{cTestCase2}}; auto const parse_result{parse_translation_unit_from_istream(input_stream)}; REQUIRE_FALSE(parse_result.has_error()); auto const& translation_unit{parse_result.value()}; auto struct_spec_dependency_graph{translation_unit->create_struct_spec_dependency_graph()}; - REQUIRE(struct_spec_dependency_graph.get_num_struct_specs() == 6); + REQUIRE(struct_spec_dependency_graph->get_num_struct_specs() == 6); auto const serialized_strongly_connected_components{ - serialize_strongly_connected_components(struct_spec_dependency_graph) + serialize_strongly_connected_components(*struct_spec_dependency_graph) }; REQUIRE(serialized_strongly_connected_components.size() == 1); std::set const expected_serialized_strongly_connected_components{ @@ -232,26 +257,48 @@ TEST_CASE("Case 2", "[tdl][pass][analytics][StructSpecDependencyGraph]") { == expected_serialized_strongly_connected_components); REQUIRE_FALSE( - struct_spec_dependency_graph.get_struct_specs_in_topological_ordering().has_value() + struct_spec_dependency_graph->get_struct_specs_in_topological_ordering().has_value() ); + + DetectStructCircularDependency detect_circular_dependency_pass{struct_spec_dependency_graph}; + auto const run_result{detect_circular_dependency_pass.run()}; + REQUIRE(run_result.has_error()); + auto const* error{ + dynamic_cast(run_result.error().get()) + }; + REQUIRE(nullptr != error); + constexpr std::string_view cExpectedErrorMessage{ + "Found 1 circular dependency group(s):\n" + "Found a circular dependency group of 5 struct spec(s):\n" + " `class0` at (12:0)\n" + " `class1` at (16:0)\n" + " `class2` at (20:0)\n" + " `class3` at (25:0)\n" + " `class4` at (29:0)" + }; + REQUIRE(cExpectedErrorMessage == error->to_string()); } -TEST_CASE("Case 3", "[tdl][pass][analytics][StructSpecDependencyGraph]") { +TEST_CASE("StructSpecDependencyGraph Case 3", "[tdl][pass][analytics][StructSpecDependencyGraph]") { std::istringstream input_stream{std::string{cTestCase3}}; auto const parse_result{parse_translation_unit_from_istream(input_stream)}; REQUIRE_FALSE(parse_result.has_error()); auto const& translation_unit{parse_result.value()}; auto struct_spec_dependency_graph{translation_unit->create_struct_spec_dependency_graph()}; - REQUIRE(struct_spec_dependency_graph.get_num_struct_specs() == 7); + REQUIRE(struct_spec_dependency_graph->get_num_struct_specs() == 7); auto const serialized_strongly_connected_components{ - serialize_strongly_connected_components(struct_spec_dependency_graph) + serialize_strongly_connected_components(*struct_spec_dependency_graph) }; REQUIRE(serialized_strongly_connected_components.empty()); + DetectStructCircularDependency detect_circular_dependency_pass{struct_spec_dependency_graph}; + auto const run_result{detect_circular_dependency_pass.run()}; + REQUIRE_FALSE(run_result.has_error()); + auto const optional_topological_ordering{ - struct_spec_dependency_graph.get_struct_specs_in_topological_ordering() + struct_spec_dependency_graph->get_struct_specs_in_topological_ordering() }; REQUIRE(optional_topological_ordering.has_value()); std::vector topological_ordering_struct_spec_names; @@ -259,7 +306,7 @@ TEST_CASE("Case 3", "[tdl][pass][analytics][StructSpecDependencyGraph]") { topological_ordering_struct_spec_names.reserve(optional_topological_ordering->size()); // NOLINTNEXTLINE(bugprone-unchecked-optional-access) for (auto const id : *optional_topological_ordering) { - auto const struct_spec{struct_spec_dependency_graph.get_struct_spec_from_id(id)}; + auto const struct_spec{struct_spec_dependency_graph->get_struct_spec_from_id(id)}; REQUIRE(nullptr != struct_spec); topological_ordering_struct_spec_names.emplace_back(struct_spec->get_name()); }