diff --git a/tools/hls-fuzzer/AST.h b/tools/hls-fuzzer/AST.h index 2b17e6ec8..449f8e31a 100644 --- a/tools/hls-fuzzer/AST.h +++ b/tools/hls-fuzzer/AST.h @@ -322,6 +322,8 @@ class BinaryExpression { /// operand for 'op'. static bool isLegalOperandType(Op op, const ScalarType &datatype); + using SubElements = std::tuple; + private: Expression lhs; Op op; @@ -405,6 +407,8 @@ class ConditionalExpression { llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const ConditionalExpression &ternaryExpression); +class ArrayParameter; + /// Expression representing reading and indexing into an array. /// Only array parameters and one-dimensional arrays are currently supported. class ArrayReadExpression { @@ -426,6 +430,11 @@ class ArrayReadExpression { /// element type of the array. const ScalarType &getType() const { return dataType; } + using SubElements = std::tuple; + + constexpr static std::size_t ARRAY_PARAMETER = 0; + constexpr static std::size_t INDEX = 1; + private: ScalarType dataType; std::string arrayParameter; diff --git a/tools/hls-fuzzer/BasicCGenerator.cpp b/tools/hls-fuzzer/BasicCGenerator.cpp index df349d955..416a91fca 100644 --- a/tools/hls-fuzzer/BasicCGenerator.cpp +++ b/tools/hls-fuzzer/BasicCGenerator.cpp @@ -123,79 +123,88 @@ std::optional gen::BasicCGenerator::generateBinaryExpression(ast::BinaryExpression::Op op, const OpaqueContext &context, std::size_t depth) { - auto conclusion = typeSystem.checkBinaryExpressionOpaque(op, context); - if (!conclusion) + if (typeSystem.discardBinaryExpressionOpaque(op, context)) return std::nullopt; - auto [lhsCons, rhsCons] = *conclusion; - - ast::Expression lhs = generateExpression(lhsCons, depth + 1); - ast::Expression rhs = generateExpression(rhsCons, depth + 1); - - // Perform explicit casts to a legal operand type if neither of the - // expressions are legal for the given operation. - // This would e.g. cast 'double's that are meant to be applied to '&' to a - // random type that can be legally used with '&'. - if (!ast::BinaryExpression::isLegalOperandType(op, lhs.getType()) || - !ast::BinaryExpression::isLegalOperandType(op, rhs.getType())) { - - std::optional scalarType = generateScalarType( - context, /*toExclude=*/[&](const ast::ScalarType &value) { - return !ast::BinaryExpression::isLegalOperandType(op, value); - }); - if (!scalarType) - return std::nullopt; - - lhs = safeCastAsNeeded(*scalarType, std::move(lhs)); - rhs = safeCastAsNeeded(*scalarType, std::move(rhs)); - } - switch (op) { - case ast::BinaryExpression::ShiftLeft: - case ast::BinaryExpression::ShiftRight: { - ast::ScalarType datatype = lhs.getType(); - // Restrict the right expression to be in range of the bitwidth. - rhs = ast::BinaryExpression{ - std::move(rhs), ast::BinaryExpression::BitAnd, - ast::Constant{static_cast(datatype.getBitwidth() - 1)}}; - - // If the left-hand side is a signed integer, make sure the value is at - // least 0. - // Performing a left-shift on a negative value in C is undefined behavior. - if (op == ast::BinaryExpression::ShiftLeft && datatype.isSigned()) - lhs = generateMinExpression(std::move(lhs), - ast::Constant{static_cast(0)}); - return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)}; - } - case ast::BinaryExpression::Plus: - case ast::BinaryExpression::Minus: - case ast::BinaryExpression::Mul: { - ast::ScalarType lhsType = lhs.getType(); - ast::ScalarType rhsType = rhs.getType(); - if ((lhsType == ast::PrimitiveType::Int32 && - lhsType.getBitwidth() > rhsType.getBitwidth()) || - (rhsType == ast::PrimitiveType::Int32 && - rhsType.getBitwidth() > lhsType.getBitwidth())) { - // Promote integers where one operand is an 'int32_t' to 'uint32_t' to - // avoid undefined behavior on overflow. - lhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(lhs)); - rhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(rhs)); - } - return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)}; - } - // case ast::BinaryExpression::Division: - break; - case ast::BinaryExpression::BitAnd: - case ast::BinaryExpression::BitOr: - case ast::BinaryExpression::BitXor: - case ast::BinaryExpression::Greater: - case ast::BinaryExpression::GreaterEqual: - case ast::BinaryExpression::Less: - case ast::BinaryExpression::LessEqual: - case ast::BinaryExpression::Equal: - case ast::BinaryExpression::NotEqual: - return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)}; - } - llvm_unreachable("all enum cases handled"); + return generateWithDependencies( + context, typeSystem.getBinaryExpressionContextDependencies(op), + /*lhs=*/ + [&](const OpaqueContext &context) -> ast::Expression { + return generateExpression(context, depth + 1); + }, + /*rhs=*/ + [&](const OpaqueContext &context) -> ast::Expression { + return generateExpression(context, depth + 1); + }, + /*constructor=*/ + [&](ast::Expression &&lhs, + ast::Expression &&rhs) -> std::optional { + // Perform explicit casts to a legal operand type if neither of the + // expressions are legal for the given operation. + // This would e.g. cast 'double's that are meant to be applied to '&' to + // a random type that can be legally used with '&'. + if (!ast::BinaryExpression::isLegalOperandType(op, lhs.getType()) || + !ast::BinaryExpression::isLegalOperandType(op, rhs.getType())) { + + std::optional scalarType = generateScalarType( + context, /*toExclude=*/[&](const ast::ScalarType &value) { + return !ast::BinaryExpression::isLegalOperandType(op, value); + }); + if (!scalarType) + return std::nullopt; + + lhs = safeCastAsNeeded(*scalarType, std::move(lhs)); + rhs = safeCastAsNeeded(*scalarType, std::move(rhs)); + } + + switch (op) { + case ast::BinaryExpression::ShiftLeft: + case ast::BinaryExpression::ShiftRight: { + ast::ScalarType datatype = lhs.getType(); + // Restrict the right expression to be in range of the bitwidth. + rhs = ast::BinaryExpression{ + std::move(rhs), ast::BinaryExpression::BitAnd, + ast::Constant{static_cast(datatype.getBitwidth() - 1)}}; + + // If the left-hand side is a signed integer, make sure the value is + // at least 0. Performing a left-shift on a negative value in C is + // undefined behavior. + if (op == ast::BinaryExpression::ShiftLeft && datatype.isSigned()) + lhs = generateMinExpression( + std::move(lhs), ast::Constant{static_cast(0)}); + return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)}; + } + case ast::BinaryExpression::Plus: + case ast::BinaryExpression::Minus: + case ast::BinaryExpression::Mul: { + ast::ScalarType lhsType = lhs.getType(); + ast::ScalarType rhsType = rhs.getType(); + if ((lhsType == ast::PrimitiveType::Int32 && + lhsType.getBitwidth() > rhsType.getBitwidth()) || + (rhsType == ast::PrimitiveType::Int32 && + rhsType.getBitwidth() > lhsType.getBitwidth())) { + // Promote integers where one operand is an 'int32_t' to 'uint32_t' + // to avoid undefined behavior on overflow. + lhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(lhs)); + rhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(rhs)); + } + return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)}; + } + // case ast::BinaryExpression::Division: + break; + case ast::BinaryExpression::BitAnd: + case ast::BinaryExpression::BitOr: + case ast::BinaryExpression::BitXor: + case ast::BinaryExpression::Greater: + case ast::BinaryExpression::GreaterEqual: + case ast::BinaryExpression::Less: + case ast::BinaryExpression::LessEqual: + case ast::BinaryExpression::Equal: + case ast::BinaryExpression::NotEqual: + return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)}; + } + llvm_unreachable("all enum cases handled"); + }); } std::optional @@ -304,41 +313,43 @@ gen::BasicCGenerator::generateConstant(const OpaqueContext &context, std::optional gen::BasicCGenerator::generateArrayReadExpression(const OpaqueContext &context, std::size_t depth) { - auto conclusion = typeSystem.checkArrayReadExpressionOpaque(context); - if (!conclusion) + if (typeSystem.discardArrayReadExpressionOpaque(context)) return std::nullopt; - auto [paramConc, indexConc] = *conclusion; - - // Construct a safe indexing expression from an array parameter. - auto genWrappedArrayReadFromParam = [&, &indexConc = indexConc]( - const ast::ArrayParameter ¶m) { - ast::ScalarType elementType = param.getElementType(); - std::size_t mask = param.getDimension() - 1; - std::string name = param.getName().str(); - // Generate an indexing expression. - // Has to be an integer. - ast::Expression index = safeCastAsNeeded( - ast::PrimitiveType::UInt32, generateExpression(indexConc, depth + 1)); - - // Bitmask the index to be in range of the array! We use this to avoid - // undefined behavior in our programs. In the future we could also add - // mechanisms (type systems, or whatever), that restrict expressions to - // safe in-range expressions. - // - // Note: We can use a bitmask here since array parameters that we generate - // are all powers-of-2. We do so since the modulo operator is currently - // unsupported in dynamatic. - return ast::ArrayReadExpression{ - std::move(elementType), name, - ast::BinaryExpression{std::move(index), ast::BinaryExpression::BitAnd, - ast::Constant{static_cast(mask)}}}; - }; - std::optional arrayParameter = - generateArrayParameter(paramConc); - if (!arrayParameter) - return std::nullopt; - return genWrappedArrayReadFromParam(*arrayParameter); + return generateWithDependencies( + context, typeSystem.getArrayReadExpressionContextDependencies(), + /*array parameter=*/ + [&](const OpaqueContext &context) -> std::optional { + return generateArrayParameter(context); + }, + /*index=*/ + [&](const OpaqueContext &context) -> std::optional { + return generateExpression(context, depth + 1); + }, + /*constructor=*/ + [&](ast::ArrayParameter &¶m, ast::Expression &&expression) { + ast::ScalarType elementType = param.getElementType(); + std::size_t mask = param.getDimension() - 1; + std::string name = param.getName().str(); + // Generate an indexing expression. + // Has to be an integer. + ast::Expression index = + safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(expression)); + + // Bitmask the index to be in range of the array! We use this to avoid + // undefined behavior in our programs. In the future we could also add + // mechanisms (type systems, or whatever), that restrict expressions to + // safe in-range expressions. + // + // Note: We can use a bitmask here since array parameters that we + // generate are all powers-of-2. We do so since the modulo operator is + // currently unsupported in dynamatic. + return ast::ArrayReadExpression{ + std::move(elementType), name, + ast::BinaryExpression{ + std::move(index), ast::BinaryExpression::BitAnd, + ast::Constant{static_cast(mask)}}}; + }); } std::optional diff --git a/tools/hls-fuzzer/BasicCGenerator.h b/tools/hls-fuzzer/BasicCGenerator.h index 4ddb539b0..8a3dd44cf 100644 --- a/tools/hls-fuzzer/BasicCGenerator.h +++ b/tools/hls-fuzzer/BasicCGenerator.h @@ -141,6 +141,152 @@ class BasicCGenerator { std::size_t varCounter = 0; AbstractTypeSystem &typeSystem; OpaqueContext entryContext; + + /// Returns a tuple of 'std::integral_constant's for every element in 'is'. + template + constexpr static auto getIndicesTuple(std::index_sequence) { + return std::tuple{std::integral_constant{}...}; + } + + template + struct GenerateWithDependencies; + + template + struct GenerateWithDependencies> { + std::optional + operator()(const OpaqueContext &parentContext, + const TransferFnArray &transferFunctions, + llvm::function_ref< + std::optional(OpaqueContext)>... generators, + llvm::function_ref(SubElements &&...)> + constructor) const { + typename OpaqueTransferFn::SubElementsTuple subElements; + + // TODO: For now subelement generators cannot yet return an output + // context. We assume output context == input context. + typename OpaqueTransferFn::ContextTuple contexts; + std::get(contexts) = parentContext; + + // Calculate a topological order between all dependencies. + // To do so we use a worklist of elements whose dependencies are all + // satisfied and an edge list that for every node 'i', contains all + // outgoing edges. + // This is opposite from 'OpaqueTransferFn' which returns the incoming + // edges. + + // Note: We use 'std::array' here everywhere since the bounds are known + // and small. + using NodeList = std::array; + std::size_t workListSize = 0; + NodeList worklist{}; + + // For a given node 'i', contains the number of outgoing edges from that + // node. + NodeList forwardEdgeCount{}; + // For a given node 'i', contains the destinations of each outgoing edge + // from that node. + std::array forwardEdgeList{}; + // For a given node 'i', contains the number of incoming edges into 'i'. + NodeList incomingEdgeCount{}; + for (auto &&[index, iter] : + llvm::enumerate(llvm::ArrayRef(transferFunctions).drop_back())) { + if (iter.getInputDependencies().empty() || + iter.getInputDependencies() == llvm::ArrayRef{PARENT_DEPENDENCY}) { + // No dependency (besides the parent context which is satisfied). + worklist[workListSize++] = index; + continue; + } + + // Build the outgoing edge list but do keep track of the number of + // incoming edges. + for (auto fromIndex : iter.getInputDependencies()) { + if (fromIndex == PARENT_DEPENDENCY) + continue; + + forwardEdgeList[fromIndex][forwardEdgeCount[fromIndex]++] = index; + ++incomingEdgeCount[index]; + } + } + + std::size_t topoOrderSize = 0; + NodeList topoOrder{}; + while (workListSize > 0) { + std::size_t index = worklist[--workListSize]; + topoOrder[topoOrderSize++] = index; + // "Remove" all outgoing edges from 'index'. + // If a node has no more incoming edges, then it can be scheduled and + // added to the worklist. + for (auto &&m : llvm::ArrayRef(forwardEdgeList[index]) + .take_front(forwardEdgeCount[index])) + if (--incomingEdgeCount[m] == 0) + worklist[workListSize++] = m; + } + + assert(topoOrderSize == sizeof...(SubElements) && + "transfer function dependency graph contains cycles"); + + // Finally, generate the subelements in topological order. + for (std::size_t iter : topoOrder) { + // We need to use fold-expressions over compile time constants to be + // able to index into 'contexts' and 'subElements'. + // The conditional-expressions are just if-conditions that perform a + // given assignment if 'iter' matches that current 'index'. + bool success = std::apply( + [&](auto &&...indices) { + return ([&](auto indexT) { + if (iter != indexT) + return true; + + constexpr std::size_t index = decltype(indexT){}; + + auto &context = std::get(contexts); + // First generate the context for the subelement. + context = transferFunctions[iter](subElements, contexts); + // Now generate the subelement. + std::get(subElements) = + std::get(std::make_tuple(generators...))(*context); + // Check whether we were successful. + return std::get(subElements).has_value(); + }(indices) && + ...); + }, + getIndicesTuple(std::index_sequence_for{})); + + // Discard this AST node if we failed to generate a subelement. + if (!success) + return std::nullopt; + } + // Lastly, generate the output context. + std::get(contexts) = + transferFunctions[sizeof...(SubElements)](subElements, contexts); + + // And call the constructor with all subelements. + // It should be safe to dereference all optionals since they have been + // guaranteed to have been generated. + return std::apply( + [&](auto &&...values) { return constructor(std::move(*values)...); }, + std::move(subElements)); + } + }; + + /// Callable object used to generate an 'ASTNode' from its subelements. + /// The signature of the object can be thought of as: + /// + /// (const OpaqueContext &parentContext, + /// const DependencyArray &dependencies, + /// llvm::function_ref< + /// std::optional(OpaqueContext)>... generators, + /// llvm::function_ref(SubElements &&...)> + /// constructor) -> std::optional + /// where 'SubElements' are the subelements of 'ASTNode' specified in + /// 'TypeSystemTraits::SubElements'. + /// + /// 'parentContext' is the input context, 'generators' are callbacks to + /// generate every corresponding subelement of 'ASTNode' and 'constructor' + /// the final callback to construct 'ASTNode' from the subelements. + template + constexpr static auto generateWithDependencies = + GenerateWithDependencies{}; }; } // namespace dynamatic::gen diff --git a/tools/hls-fuzzer/TypeSystem.h b/tools/hls-fuzzer/TypeSystem.h index f40481e92..5786efd2d 100644 --- a/tools/hls-fuzzer/TypeSystem.h +++ b/tools/hls-fuzzer/TypeSystem.h @@ -4,14 +4,16 @@ #include "Randomly.h" #include "TypeSystemTraits.h" +#include "llvm/ADT/FunctionExtras.h" + #include namespace dynamatic::gen { /// Opaque wrapper which type-erases a context used during type checking. -/// It allows users of 'AbstractTypeSystem' to pass contexts returned by -/// 'check*' methods, to other 'check*' methods without needing to know the -/// real context type used by the underlying type system. +/// It allows users of 'AbstractTypeSystem' to pass contexts around between +/// methods without needing to know the real context type used by the underlying +/// type system. /// /// We call the type opaque since it does not implement any behavior based /// on the contained context beyond being able to pass it around. @@ -31,10 +33,224 @@ class OpaqueContext { return *std::any_cast(&container); } + // Enable noop casts to 'OpaqueContext'. + template <> + const OpaqueContext &cast() const { + return *this; + } + private: std::any container; }; +/// Sentinel value representing a dependency on the parent context. +constexpr std::size_t PARENT_DEPENDENCY = -1; + +/// Class responsible for telling the generator how to calculate the input +/// 'TypingContext' for a given subelement of 'ASTNode'. +/// The subelement whose input-context we are calculating for is given by its +/// position within 'DependencyTuple'. See that type definition for more +/// information. +/// +/// The class is called 'Dependency' as it allows specifying a dependency on +/// previously calculated contexts + previously generated subelements using +/// 'inputIndices'. +/// The indices in 'inputIndices' refer to the index of the given subelement +/// this instance depends on within 'TypeSystemTraits::SubElements'. +/// The special value 'PARENT_DEPENDENCY' represents depending on the +/// input-context of 'ASTNode'. +/// It is the user's responsibility to not create cyclic dependencies. +template +class TransferFn { + + template + struct CalcCompFn { + // Recursive case. + using type = typename CalcCompFn< + decltype(std::tuple_cat( + std::declval(), + std::declval, + // Add both the context and the ASTNode to the arguments. + std::tuple< + const TypingContext &, + const std::tuple_element_t< + std::min(current, std::tuple_size_v< + typename ASTNode::SubElements> - + 1), + typename ASTNode::SubElements> &>>>())), + remaining...>::type; + }; + + // Terminating end-case + template + struct CalcCompFn, current> { + using type = TypingContext(Args...); + }; + + using ContextComputationFn = + typename CalcCompFn, inputIndices..., 0>::type; + +public: + /// Constructs a 'Dependency' from a function. + /// The signature of the function is dependent on 'inputIndices'. + /// Specifically, for every element of 'inputIndices' and in the order as + /// given in 'inputIndices', the arguments are: + /// * The parent 'TypingContext' if the value is 'PARENT_DEPENDENCY' + /// * The output 'TypingContext' of the 'i'th subelement of 'ASTNode' followed + /// by the subelement's AST node itself. + /// + /// Example: + /// Dependency( + /// [](const Context& rhsContext, const ast::Expression& rhs, + /// const Context& parentContext) -> Context { + /// ... + /// } + /// ) + /// + /// The function should always return a 'TypingContext'. All parameters are + /// passed as const-references. + explicit TransferFn(std::function computationFn) + : computationFn(std::move(computationFn)) {} + + /// Convenience constructor from a constant 'TypingContext' without any + /// dependencies. + explicit TransferFn(TypingContext context) + : TransferFn( + [context = std::move(context)](auto &&...) { return context; }) {} + + template + TypingContext operator()(Args &&...args) const { + return computationFn(std::forward(args)...); + } + +private: + static_assert(((inputIndices < + std::tuple_size_v || + inputIndices == PARENT_DEPENDENCY) && + ...), + "input indices must refer to subelements or the parent"); + + std::function computationFn; +}; + +/// Opaque-wrapper over 'TransferFn' that can be constructed from any instance +/// of 'TransferFn' with the same 'ASTNode'. +/// Users should construct 'TransferFn' instances instead. +/// +/// Mainly used as a return type in 'AbstractTypeSystem' where templates cannot +/// or shouldn't be used. +template +class OpaqueTransferFn { + template + struct OpaqueContextTupleImpl; + + template + struct OpaqueContextTupleImpl> { + using type = std::tuple< + std::optional>..., + std::optional>; + }; + + template + struct NonTerminalsTupleImpl; + + template + struct NonTerminalsTupleImpl> { + using type = std::tuple...>; + }; + +public: + /// Tuple of optionals of all subelements of this ASTNode. + /// This is used to have one consistent API with which to call an + /// 'OpaqueDependency' to calculate a context. + /// Elements are optional, since they may not yet have been constructed. + using SubElementsTuple = + typename NonTerminalsTupleImpl::type; + + /// Tuple of optionals of all contexts of this ASTNode. + /// This is used to have one consistent API with which to call an + /// 'OpaqueDependency' to calculate a context. + /// Elements are optional, since they may not yet have been calculated. + using ContextTuple = + typename OpaqueContextTupleImpl::type; + + /// Constructs an 'OpaqueDependency' from a 'Dependency'. + template + /*implicit*/ OpaqueTransferFn( + TransferFn &&dep) + : dep(std::move(dep)), + computationFn(+[](const std::any &dep, + const SubElementsTuple &subElements, + const ContextTuple &contexts) -> OpaqueContext { + // Construct a tuple of all arguments that 'dep' should be called + // with. + // This mainly uses 'inputIndices' to index into 'subElements' and + // 'contexts'. + // The logic here simply unwraps the optionals: It assumes that the + // required contexts and subelements have already been generated. + auto argTuple = std::tuple_cat([&](auto &&integral) { + constexpr std::size_t index = decltype(integral){}; + if constexpr (index == PARENT_DEPENDENCY) { + // Parent context. + return std::forward_as_tuple( + std::get - 1>(contexts) + ->template cast()); + } else { + // Subelement context + ASTNode. + return std::forward_as_tuple( + std::get(contexts)->template cast(), + *std::get(subElements)); + } + }(std::integral_constant{})...); + + return OpaqueContext(std::apply( + *std::any_cast< + TransferFn>(&dep), + std::move(argTuple))); + }) { + + static std::array storage{ + inputIndices...}; + this->inputIndices = storage; + } + + /// Returns the indices of the subelements (or parent) that this dependency + /// depends on. + llvm::ArrayRef getInputDependencies() const { + return inputIndices; + } + + /// Calculates the context from the currently calculated subelements and + /// contexts. Internal API that should only be used by the generator. + OpaqueContext operator()(const SubElementsTuple &subElements, + const ContextTuple &contexts) const { + return computationFn(dep, subElements, contexts); + } + +private: + std::any dep; + OpaqueContext (*computationFn)(const std::any &dep, + const SubElementsTuple &nonTerminals, + const ContextTuple &tuple); + llvm::ArrayRef inputIndices; +}; + +/// Array of transfer functions returned by 'AbstractTypeSystem' for every +/// 'ASTNode'. +/// The array contains as many elements as there are subelements in 'ASTNode' +/// plus one. +/// The corresponding index in the array corresponds to the 'OpaqueTransferFn' +/// instance used to calculate the input context for that subelement. +/// The special last element in the array corresponds to calculating the output +/// 'context' for the 'ASTNode'. +template +using TransferFnArray = + std::array, + std::tuple_size_v + 1>; + /// Abstract base class for all type systems. Users of a type system such as /// the C generator use this interface in conjunction with 'OpaqueContext' to be /// able to pass on contexts for generating AST elements without needing to know @@ -50,25 +266,43 @@ class OpaqueContext { /// /// The 'TypeSystem' base class below should be used instead to automate this by /// overriding all the methods in 'AbstractTypeSystem' that box and unbox -/// 'OpaqueContext's and dispatch to corresponding (non-opaque) 'check*' methods +/// 'OpaqueContext's and dispatch to corresponding (non-opaque) methods /// in the derived class. /// It also offers common and convenient default implementations of 'check*' -/// methods. +/// and 'discard*' methods. class AbstractTypeSystem { +protected: + /// Returns an instance of 'TransferFn' which simply forwards the context from + /// the parent to the subelement. + template + static auto copyFromParent() { + return TransferFn( + [](const OpaqueContext &context) { return context; }); + } + public: virtual ~AbstractTypeSystem(); virtual ConclusionOf checkFunctionOpaque(const OpaqueContext &context) = 0; - virtual std::optional> - checkBinaryExpressionOpaque(ast::BinaryExpression::Op op, - const OpaqueContext &context) = 0; + /// Returns true if the generator should discard this binary expression based + /// on the given input context. + virtual bool discardBinaryExpressionOpaque(ast::BinaryExpression::Op op, + const OpaqueContext &context) = 0; virtual std::optional> checkUnaryExpressionOpaque(ast::UnaryExpression::Op op, const OpaqueContext &context) = 0; + virtual TransferFnArray + getBinaryExpressionContextDependencies(ast::BinaryExpression::Op op) { + // Default implementation: Simply propagates the context to the subelements. + return {/*lhs=*/copyFromParent(), + /*rhs=*/copyFromParent(), + /*output=*/copyFromParent()}; + } + virtual std::optional> checkVariableOpaque(const OpaqueContext &context) = 0; @@ -93,8 +327,15 @@ class AbstractTypeSystem { checkScalarParameterOpaque(const ast::ScalarParameter &, const OpaqueContext &context) = 0; - virtual std::optional> - checkArrayReadExpressionOpaque(const OpaqueContext &context) = 0; + virtual bool + discardArrayReadExpressionOpaque(const OpaqueContext &context) = 0; + + virtual TransferFnArray + getArrayReadExpressionContextDependencies() { + return {/*array parameter=*/copyFromParent(), + /*index=*/copyFromParent(), + /*output=*/copyFromParent()}; + } virtual std::optional> checkArrayParameterOpaque(const ast::ArrayParameter &, @@ -115,20 +356,34 @@ class AbstractTypeSystem { /// used when generating sub-elements of an AST-node or 2) rejecting AST-nodes /// entirely based on the current type context. /// -/// All type checking is performed under a given context specified as the -/// 'TypingContext' template parameter. -/// For every AST construct a corresponding 'check*' method exists. -/// The input to this method is always the context used to type check the given -/// AST construct. -/// Based on the input context the 'check*' method can then derive new contexts -/// for its subelements or discard the AST-node entirely. -/// The return type is the so-called conclusion and is different for every -/// AST construct. It is specified using the 'TypeSystemTraits'. +/// There are currently two APIs to achieve this: +/// 1) The transfer functions API +/// 2) the 'check*' API. +/// The latter is considered deprecated and implements a subset of +/// functionality of the transfer functions API. +/// +/// Regardless of API, all type checking is performed under a given context +/// specified as the 'TypingContext' template parameter. +/// Every AST node is initially generated using an input context +/// passed into the 'check*' method or 'discard*' method of the AST node which +/// may discard the AST node. +/// Otherwise, new contexts for the subelements of the AST node can be derived. /// -/// E.g. the conclusion of a binary expression are the contexts that should be -/// used to type check the left and right operands. -/// Most 'check*' methods support discarding the AST node entirely, in which -/// case the conclusion type is wrapped in an optional. +/// The transfer functions API allows specifying how input contexts for AST +/// elements should be calculated. +/// Specifically, an instance of 'TransferFn' can specify that it depends on the +/// context and AST node of a sibling subelement in addition to, or instead of +/// the parent input context. +/// Example: +/// Given the C expression 'a[i]', an input context can be derived for +/// generating 'i' using knowledge gained from the output context and AST node +/// 'a'. +/// The generator uses this knowledge to generate the AST node of 'a' before +/// 'i'. +/// 'check*' methods in contrast only implement deriving subelement contexts +/// from the parent input context. +/// They return a tuple of 'TypingContext's for every subelement of 'ASTNode', +/// the so-called conclusion type. /// /// Note: We call it contexts rather than constraints to match literature, and /// as it more generally informs an AST-node generation about the type-system @@ -177,6 +432,9 @@ class TypeSystem : public AbstractTypeSystem { template using ConclusionOf = ConclusionOf; + template + using TransferFn = TransferFn; + /// Shorthand for derived classes to be able to call the default /// implementation of methods. using Super = TypeSystem; @@ -190,10 +448,9 @@ class TypeSystem : public AbstractTypeSystem { return {context, context}; } - static ConclusionOf - checkBinaryExpression(ast::BinaryExpression::Op, - const TypingContext &context) { - return {context, context}; + static bool discardBinaryExpression(ast::BinaryExpression::Op, + const TypingContext &) { + return false; } static ConclusionOf @@ -262,9 +519,8 @@ class TypeSystem : public AbstractTypeSystem { return context; } - static ConclusionOf - checkArrayReadExpression(const TypingContext &context) { - return {context, context}; + static bool discardArrayReadExpression(const TypingContext &) { + return false; } std::optional> @@ -293,11 +549,9 @@ class TypeSystem : public AbstractTypeSystem { return convert(self().checkFunction(context.cast())); } - std::optional> - checkBinaryExpressionOpaque(ast::BinaryExpression::Op op, - const OpaqueContext &context) final { - return convert( - self().checkBinaryExpression(op, context.cast())); + bool discardBinaryExpressionOpaque(ast::BinaryExpression::Op op, + const OpaqueContext &context) final { + return self().discardBinaryExpression(op, context.cast()); } std::optional> @@ -349,11 +603,8 @@ class TypeSystem : public AbstractTypeSystem { self().checkScalarParameter(node, context.cast())); } - std::optional< - dynamatic::ConclusionOf> - checkArrayReadExpressionOpaque(const OpaqueContext &context) final { - return convert( - self().checkArrayReadExpression(context.cast())); + bool discardArrayReadExpressionOpaque(const OpaqueContext &context) final { + return self().discardArrayReadExpression(context.cast()); } std::optional> @@ -441,9 +692,9 @@ template class DisallowByDefaultTypeSystem : public TypeSystem { public: - static std::optional> - checkBinaryExpression(ast::BinaryExpression::Op, const TypingContext &) { - return std::nullopt; + static bool discardBinaryExpression(ast::BinaryExpression::Op, + const TypingContext &) { + return true; } static std::optional> @@ -486,10 +737,7 @@ class DisallowByDefaultTypeSystem : public TypeSystem { return std::nullopt; } - static std::optional> - checkArrayReadExpression(const TypingContext &) { - return std::nullopt; - } + static bool discardArrayReadExpression(const TypingContext &) { return true; } std::optional> checkArrayParameter(const ast::ArrayParameter &, const TypingContext &) { diff --git a/tools/hls-fuzzer/TypeSystemTraits.h b/tools/hls-fuzzer/TypeSystemTraits.h index 772a893eb..e251b97ea 100644 --- a/tools/hls-fuzzer/TypeSystemTraits.h +++ b/tools/hls-fuzzer/TypeSystemTraits.h @@ -76,15 +76,6 @@ struct TypeSystemTraits : TypeSystemTraitsDefaults { using Conclusions = TypingContext; }; -template <> -struct TypeSystemTraits : TypeSystemTraitsDefaults { - - /// Type constraints for the left-hand operand followed by the right-hand - /// operand. - template - using Conclusions = std::tuple; -}; - template <> struct TypeSystemTraits : TypeSystemTraitsDefaults { @@ -137,14 +128,6 @@ struct TypeSystemTraits { using Conclusions = TypingContext; }; -template <> -struct TypeSystemTraits { - - template - using Conclusions = - std::tuple; -}; - template <> struct TypeSystemTraits { diff --git a/tools/hls-fuzzer/targets/BitwidthTypeSystem.cpp b/tools/hls-fuzzer/targets/BitwidthTypeSystem.cpp index 651517932..42370d7da 100644 --- a/tools/hls-fuzzer/targets/BitwidthTypeSystem.cpp +++ b/tools/hls-fuzzer/targets/BitwidthTypeSystem.cpp @@ -48,46 +48,75 @@ auto dynamatic::gen::BitwidthTypeSystem::checkConstant( constant.value); } -auto dynamatic::gen::BitwidthTypeSystem::checkBinaryExpression( - ast::BinaryExpression::Op op, const BitwidthTypingContext &context) const - -> std::optional> { +bool dynamatic::gen::BitwidthTypeSystem::discardBinaryExpression( + ast::BinaryExpression::Op op, const BitwidthTypingContext &context) const { switch (op) { - case ast::BinaryExpression::BitAnd: { - // Bitand is distributive: Sub-expressions can assume they are truncated - // as well. - std::optional req = context.bitwidthRequirementOrNone(); - if (!req) - return ConclusionOf{ResultIsTruncated{}, - ResultIsTruncated{}}; - - // Otherwise, one operand is constrained to of the given maximum bitwidth - // while the other can assume it is being truncated. - // The choice of whether the left or right-hand-side is constrained is - // arbitrary. - return ConclusionOf{ - ResultIsTruncated{}, getInterestingBitWidthInRange(*req)}; - } + case ast::BinaryExpression::BitAnd: + case ast::BinaryExpression::BitOr: + case ast::BinaryExpression::BitXor: + // Always allowed. + return false; + + case ast::BinaryExpression::ShiftRight: case ast::BinaryExpression::ShiftLeft: - // TODO: Left shift is distributive for the shifted operand but not the - // shift-amount. - // Under a fixed bitwidth, we can also choose bitwidths for both - // operands such that it fits within a fixed bitwidth. - return std::nullopt; + // TODO: Implement logic for these. + return true; case ast::BinaryExpression::Plus: case ast::BinaryExpression::Mul: case ast::BinaryExpression::Minus: - if (context.resultIsTruncated()) - return ConclusionOf{ResultIsTruncated{}, - ResultIsTruncated{}}; + // Only allowed if truncated. + return context.resultIsTruncated(); - // TODO: We can choose bitwidths for the left and right operands of these - // expressions here to fit a maximum bitwidth. - return std::nullopt; + case ast::BinaryExpression::Greater: + case ast::BinaryExpression::GreaterEqual: + case ast::BinaryExpression::Less: + case ast::BinaryExpression::LessEqual: + case ast::BinaryExpression::Equal: + case ast::BinaryExpression::NotEqual: + if (globalMaxBitwidth == 1) { + LLVM_DEBUG({ + llvm::dbgs() + << "Discarding NotEqualExpression as the maximum global " + "bitwidth == 1, which requires the comparison to be done " + "on 0-bit integers (which does not exist in C)"; + }); + return true; + } + return false; + } + llvm_unreachable("all enum cases handled"); +} - case ast::BinaryExpression::ShiftRight: - // TODO: Figure out constraints here. - return std::nullopt; +auto dynamatic::gen::BitwidthTypeSystem::getBinaryExpressionContextDependencies( + ast::BinaryExpression::Op op) -> TransferFnArray { + switch (op) { + case ast::BinaryExpression::BitAnd: + return { + /*lhs=*/TransferFn(ResultIsTruncated{}), + /*rhs=*/ + TransferFn( + [&](const BitwidthTypingContext &context) -> BitwidthTypingContext { + // Bitand is distributive: Sub-expressions can assume they are + // truncated as well. + std::optional req = + context.bitwidthRequirementOrNone(); + if (!req) + return ResultIsTruncated{}; + + return getInterestingBitWidthInRange(*req); + }), + /*output=*/copyFromParent(), + }; + + case ast::BinaryExpression::Plus: + case ast::BinaryExpression::Minus: + case ast::BinaryExpression::Mul: + return { + /*lhs=*/TransferFn(ResultIsTruncated{}), + /*rhs=*/TransferFn(ResultIsTruncated{}), + /*output=*/copyFromParent(), + }; case ast::BinaryExpression::Greater: case ast::BinaryExpression::GreaterEqual: case ast::BinaryExpression::Less: @@ -108,24 +137,20 @@ auto dynamatic::gen::BitwidthTypeSystem::checkBinaryExpression( // TODO: The sign-extension of the inputs is dependent on whether the type // of the operands are signed or not. We could track this // theoretically. - if (globalMaxBitwidth == 1) { - LLVM_DEBUG({ - llvm::dbgs() - << "Discarding NotEqualExpression as the maximum global " - "bitwidth == 1, which requires the comparison to be done " - "on 0-bit integers (which does not exist in C)"; - }); - return std::nullopt; - } - - return ConclusionOf{ - {getInterestingBitWidthInRange(globalMaxBitwidth - 1)}, - {getInterestingBitWidthInRange(globalMaxBitwidth - 1)}}; + return { + /*lhs=*/TransferFn( + getInterestingBitWidthInRange(globalMaxBitwidth - 1)), + /*rhs=*/ + TransferFn( + getInterestingBitWidthInRange(globalMaxBitwidth - 1)), + /*parent=*/copyFromParent(), + }; case ast::BinaryExpression::BitOr: case ast::BinaryExpression::BitXor: - // Distribute regarding truncation. - return ConclusionOf{context, context}; + case ast::BinaryExpression::ShiftLeft: + case ast::BinaryExpression::ShiftRight: + return TypeSystem::getBinaryExpressionContextDependencies(op); } llvm_unreachable("all enum cases handled"); } @@ -148,6 +173,25 @@ auto dynamatic::gen::BitwidthTypeSystem::checkFunction( }; } +auto dynamatic::gen::BitwidthTypeSystem:: + getArrayReadExpressionContextDependencies() + -> TransferFnArray { + return { + /*array parameter=*/copyFromParent(), + /*index=*/ + TransferFn( + [&](const BitwidthTypingContext &, + const ast::ArrayParameter ¶meter) { + assert(llvm::isPowerOf2_64(parameter.getDimension()) && + "implementation depends on dimensions being powers of 2"); + return BitwidthTypingContext{std::min( + llvm::Log2_64(parameter.getDimension()), globalMaxBitwidth)}; + }), + /*output=*/copyFromParent(), + }; +} + dynamatic::gen::BitwidthTypingContext dynamatic::gen::BitwidthTypeSystem::getInterestingBitWidthInRange( uint8_t bitWidth) const { diff --git a/tools/hls-fuzzer/targets/BitwidthTypeSystem.h b/tools/hls-fuzzer/targets/BitwidthTypeSystem.h index dbc99a55e..6675cd654 100644 --- a/tools/hls-fuzzer/targets/BitwidthTypeSystem.h +++ b/tools/hls-fuzzer/targets/BitwidthTypeSystem.h @@ -88,9 +88,8 @@ class BitwidthTypeSystem checkConstant(const ast::Constant &constant, const BitwidthTypingContext &context) const; - std::optional> - checkBinaryExpression(ast::BinaryExpression::Op op, - const BitwidthTypingContext &context) const; + bool discardBinaryExpression(ast::BinaryExpression::Op op, + const BitwidthTypingContext &context) const; static std::optional> checkUnaryExpression(ast::UnaryExpression::Op, @@ -99,12 +98,18 @@ class BitwidthTypeSystem return std::nullopt; } + TransferFnArray + getBinaryExpressionContextDependencies(ast::BinaryExpression::Op op) final; + ConclusionOf checkConditionalExpression(const BitwidthTypingContext &context) const; static ConclusionOf checkFunction(const BitwidthTypingContext &context); + TransferFnArray + getArrayReadExpressionContextDependencies() override; + private: /// Returns either 'bitWidth' or with a low probability, a value in the range /// [1, bitWidth]. diff --git a/tools/hls-fuzzer/targets/DynamaticTypeSystem.cpp b/tools/hls-fuzzer/targets/DynamaticTypeSystem.cpp index 75bda1d00..ca6bc670b 100644 --- a/tools/hls-fuzzer/targets/DynamaticTypeSystem.cpp +++ b/tools/hls-fuzzer/targets/DynamaticTypeSystem.cpp @@ -19,9 +19,8 @@ auto dynamatic::gen::DynamaticTypeSystem::checkScalarType( llvm_unreachable("all enum cases handled"); } -auto dynamatic::gen::DynamaticTypeSystem::checkBinaryExpression( - ast::BinaryExpression::Op op, DynamaticTypingContext context) - -> std::optional> { +bool dynamatic::gen::DynamaticTypeSystem::discardBinaryExpression( + ast::BinaryExpression::Op op, DynamaticTypingContext context) { switch (op) { case ast::BinaryExpression::BitAnd: case ast::BinaryExpression::BitOr: @@ -29,14 +28,8 @@ auto dynamatic::gen::DynamaticTypeSystem::checkBinaryExpression( case ast::BinaryExpression::ShiftLeft: case ast::BinaryExpression::ShiftRight: // Bit expressions always yield integer types. - if (context.constraint == DynamaticTypingContext::FloatRequired) - return std::nullopt; + return context.constraint == DynamaticTypingContext::FloatRequired; - // Operands must be integer types. - return ConclusionOf{ - {DynamaticTypingContext::IntegerRequired}, - {DynamaticTypingContext::IntegerRequired}, - }; case ast::BinaryExpression::Greater: case ast::BinaryExpression::GreaterEqual: case ast::BinaryExpression::Less: @@ -44,18 +37,37 @@ auto dynamatic::gen::DynamaticTypeSystem::checkBinaryExpression( case ast::BinaryExpression::Equal: case ast::BinaryExpression::NotEqual: // Equality operations always yield 'int'. - if (context.constraint == DynamaticTypingContext::FloatRequired) - return std::nullopt; - [[fallthrough]]; - + return context.constraint == DynamaticTypingContext::FloatRequired; case ast::BinaryExpression::Plus: case ast::BinaryExpression::Minus: case ast::BinaryExpression::Mul: - return Super::checkBinaryExpression(op, context); + return false; } llvm_unreachable("all enum values handled"); } +dynamatic::gen::TransferFnArray +dynamatic::gen::DynamaticTypeSystem::getBinaryExpressionContextDependencies( + ast::BinaryExpression::Op op) { + switch (op) { + case ast::BinaryExpression::BitAnd: + case ast::BinaryExpression::BitOr: + case ast::BinaryExpression::BitXor: + case ast::BinaryExpression::ShiftLeft: + case ast::BinaryExpression::ShiftRight: + return {/*lhs=*/TransferFn(DynamaticTypingContext{ + DynamaticTypingContext::IntegerRequired}), + /*rhs=*/ + TransferFn(DynamaticTypingContext{ + DynamaticTypingContext::IntegerRequired}), + /*output=*/ + TransferFn(DynamaticTypingContext{ + DynamaticTypingContext::IntegerRequired})}; + default: + return Super::getBinaryExpressionContextDependencies(op); + } +} + auto dynamatic::gen::DynamaticTypeSystem::checkUnaryExpression( ast::UnaryExpression::Op op, DynamaticTypingContext context) const -> std::optional> { diff --git a/tools/hls-fuzzer/targets/DynamaticTypeSystem.h b/tools/hls-fuzzer/targets/DynamaticTypeSystem.h index 486400f0c..551baf089 100644 --- a/tools/hls-fuzzer/targets/DynamaticTypeSystem.h +++ b/tools/hls-fuzzer/targets/DynamaticTypeSystem.h @@ -34,14 +34,16 @@ class DynamaticTypeSystem /// Discard 'op' based on the mode in 'context' and forward constraint to /// the operands as required. - static std::optional> - checkBinaryExpression(ast::BinaryExpression::Op op, - DynamaticTypingContext context); + static bool discardBinaryExpression(ast::BinaryExpression::Op op, + DynamaticTypingContext context); std::optional> checkUnaryExpression(ast::UnaryExpression::Op op, DynamaticTypingContext context) const; + TransferFnArray + getBinaryExpressionContextDependencies(ast::BinaryExpression::Op op) final; + ConclusionOf checkConditionalExpression(DynamaticTypingContext context) const { // Condition can be either a floating point type or integer type. @@ -53,14 +55,13 @@ class DynamaticTypeSystem }; } - static std::optional> - checkArrayReadExpression(DynamaticTypingContext context) { - return ConclusionOf{ - // Forward the context to the array parameter as is. - context, - // Indexing expression must be an integer. - DynamaticTypingContext{DynamaticTypingContext::IntegerRequired}, - }; + TransferFnArray + getArrayReadExpressionContextDependencies() final { + return {/*array parameter=*/copyFromParent(), + /*index=*/ + TransferFn(DynamaticTypingContext{ + DynamaticTypingContext::IntegerRequired}), + /*output=*/copyFromParent()}; } static std::optional> diff --git a/unittests/tools/hls-fuzzer/TEST_SUITE.cpp b/unittests/tools/hls-fuzzer/TEST_SUITE.cpp index 6f9dc51c9..2261d7854 100644 --- a/unittests/tools/hls-fuzzer/TEST_SUITE.cpp +++ b/unittests/tools/hls-fuzzer/TEST_SUITE.cpp @@ -26,17 +26,24 @@ REGISTER_TYPED_TEST_SUITE_P(TypeSystemTest, OutputCheck); namespace { // Bool representing whether a parameter is required. -class PlusOfTwoParamOnlyTypeSystem +class PlusOfTwoParamOnlyTypeSystem final : public gen::DisallowByDefaultTypeSystem { public: - static std::optional> - checkBinaryExpression(ast::BinaryExpression::Op op, bool mustBeParameter) { - // Saw a binop, parameter is now required. - if (!mustBeParameter && op == ast::BinaryExpression::Plus) - return ConclusionOf{true, true}; + using DisallowByDefaultTypeSystem::DisallowByDefaultTypeSystem; - return std::nullopt; + static bool discardBinaryExpression(ast::BinaryExpression::Op op, + bool mustBeParameter) { + return mustBeParameter || op != ast::BinaryExpression::Plus; + } + + gen::TransferFnArray + getBinaryExpressionContextDependencies(ast::BinaryExpression::Op) override { + return { + TransferFn(true), + TransferFn(true), + TransferFn(true), + }; } static std::optional> @@ -73,15 +80,23 @@ class PlusOfTwoParamOnlyTypeSystem // Bool representing whether an array read expression is required. // Otherwise, a 0 constant must be generated. -class ReturnArrayConstantOnlyTypeSystem +class ReturnArrayConstantOnlyTypeSystem final : public gen::DisallowByDefaultTypeSystem< /*createArrayRead=*/bool, ReturnArrayConstantOnlyTypeSystem> { public: - static std::optional> - checkArrayReadExpression(bool createArrayRead) { - if (!createArrayRead) - return std::nullopt; - return ConclusionOf{false, false}; + using DisallowByDefaultTypeSystem::DisallowByDefaultTypeSystem; + + static bool discardArrayReadExpression(bool createArrayRead) { + return !createArrayRead; + } + + gen::TransferFnArray + getArrayReadExpressionContextDependencies() override { + return { + TransferFn(false), + TransferFn(false), + copyFromParent(), + }; } std::optional> @@ -112,8 +127,8 @@ class ReturnArrayConstantOnlyTypeSystem } constexpr static std::string_view result = - R"(double test(double var0[16]) { - return var0[((uint32_t)((0)) & (15u))]; + R"(double test(double var0[1]) { + return var0[((uint32_t)((0)) & (0u))]; } )";