diff --git a/tools/hls-fuzzer/AST.cpp b/tools/hls-fuzzer/AST.cpp index 3820c5302..be499e76e 100644 --- a/tools/hls-fuzzer/AST.cpp +++ b/tools/hls-fuzzer/AST.cpp @@ -379,6 +379,11 @@ llvm::raw_ostream &ast::operator<<(llvm::raw_ostream &os, << parameter.getDimension() << ']'; } +llvm::raw_ostream &ast::operator<<(llvm::raw_ostream &os, + const ReturnType &returnType) { + return os << returnType.variant; +} + llvm::raw_ostream &ast::operator<<(llvm::raw_ostream &os, const Function &function) { os << function.returnType << ' ' << function.name << '('; @@ -391,10 +396,10 @@ llvm::raw_ostream &ast::operator<<(llvm::raw_ostream &os, mlir::raw_indented_ostream indentedOstream(os); indentedOstream.indent(); for (auto &iter : function.statements) - indentedOstream << iter; + indentedOstream << iter << '\n'; if (function.returnStatement) - indentedOstream << *function.returnStatement; + indentedOstream << *function.returnStatement << '\n'; - os << "\n}\n"; + os << "}\n"; return os; } diff --git a/tools/hls-fuzzer/AST.h b/tools/hls-fuzzer/AST.h index 449f8e31a..8b0be2814 100644 --- a/tools/hls-fuzzer/AST.h +++ b/tools/hls-fuzzer/AST.h @@ -192,6 +192,8 @@ class ScalarType { template friend struct llvm::simplify_type; + using SubElements = std::tuple<>; + private: std::shared_ptr datatype; }; @@ -234,16 +236,22 @@ struct Constant { return PrimitiveType::Double; }); } + + using SubElements = std::tuple<>; }; llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Constant &constant); +class ScalarParameter; + /// AST-node representing a reference to a variable in C. struct Variable { const ScalarType datatype; const std::string name; const ScalarType &getType() const { return datatype; } + + using SubElements = std::tuple; }; llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Variable &variable); @@ -324,6 +332,9 @@ class BinaryExpression { using SubElements = std::tuple; + constexpr static std::size_t LHS = 0; + constexpr static std::size_t RHS = 1; + private: Expression lhs; Op op; @@ -344,6 +355,11 @@ class CastExpression { /// Returns the type of this expression, i.e., the type being cast to. const ScalarType &getType() const { return targetType; } + using SubElements = std::tuple; + + constexpr static std::size_t TARGET_TYPE = 0; + constexpr static std::size_t OPERAND = 1; + private: ScalarType targetType; Expression expression; @@ -374,6 +390,8 @@ class UnaryExpression { static bool isLegalOperandType(Op op, const ScalarType &type); + using SubElements = std::tuple; + private: Op op; Expression expression; @@ -398,6 +416,11 @@ class ConditionalExpression { ScalarType getType() const; + using SubElements = std::tuple; + constexpr static std::size_t CONDITION = 0; + constexpr static std::size_t TRUE_VAL = 1; + constexpr static std::size_t FALSE_VAL = 2; + private: Expression condition; Expression trueVal; @@ -452,6 +475,8 @@ class ReturnStatement { const Expression &getReturnValue() const { return returnValue; } + using SubElements = std::tuple; + private: Expression returnValue; }; @@ -478,6 +503,11 @@ class ArrayAssignmentStatement { /// Returns the value that will be assigned to the element. const Expression &getValueExpression() const { return valueExpression; } + using SubElements = std::tuple; + constexpr static std::size_t ARRAY = 0; + constexpr static std::size_t INDEX = 1; + constexpr static std::size_t VALUE = 2; + private: std::string arrayParameter; Expression indexingExpression; @@ -507,6 +537,40 @@ class Statement { std::shared_ptr statement; }; +/// Class representing a list of statements in a body. +class StatementList { +public: + StatementList() = default; + + explicit StatementList(std::vector statements) + : statements(std::move(statements)) {} + + /// Returns the number of statements. + std::size_t size() const { return statements.size(); } + + const Statement &operator[](std::size_t index) const { + return statements[index]; + } + + auto begin() { return statements.begin(); } + + auto begin() const { return statements.begin(); } + + auto end() { return statements.end(); } + + auto end() const { return statements.end(); } + + std::vector takeVector() { return std::move(statements); } + + // Recursive statement list representation. + // The definition is left recursive, meaning the statement is always the tail + // statement after the list. + using SubElements = std::tuple; + +private: + std::vector statements; +}; + /// AST-Node representing a scalar function parameter in C. class ScalarParameter { public: @@ -519,6 +583,8 @@ class ScalarParameter { const ScalarType &getDataType() const { return dataType; } + using SubElements = std::tuple; + private: ScalarType dataType; std::string name; @@ -545,6 +611,8 @@ class ArrayParameter { std::size_t getDimension() const { return dimension; } + using SubElements = std::tuple; + private: ScalarType dataType; std::string name; @@ -555,13 +623,50 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const ArrayParameter ¶meter); /// Tag type representing the 'void' type from C. -struct VoidType {}; +struct VoidType { + friend bool operator==(const VoidType &lhs, const VoidType &rhs) { + return true; + } + + friend bool operator!=(const VoidType &lhs, const VoidType &rhs) { + return !(lhs == rhs); + } +}; inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const VoidType &) { return os << "void"; } -using ReturnType = std::variant; +class ReturnType { + using Variant = std::variant; + +public: + ReturnType() = default; + + template && + !std::is_same_v, T>> + * = nullptr> + /*implicit*/ ReturnType(T &&arg) : variant(std::forward(arg)) {} + + friend bool operator==(const ReturnType &lhs, const ReturnType &rhs) { + return lhs.variant == rhs.variant; + } + + friend bool operator!=(const ReturnType &lhs, const ReturnType &rhs) { + return !(lhs == rhs); + } + + template + friend struct llvm::simplify_type; + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const ReturnType &returnType); + + using SubElements = std::tuple<>; + +private: + Variant variant; +}; /// AST-Node representing a function in C. /// Functions are currently limited to just a return statement. @@ -576,6 +681,11 @@ struct Function { /// The return statement at the end of a function iff it does not have a void /// return type. const std::optional returnStatement; + + using SubElements = std::tuple; + constexpr static std::size_t RETURN_TYPE = 0; + constexpr static std::size_t STATEMENTS = 1; + constexpr static std::size_t RETURN_STATEMENT = 2; }; llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &function); @@ -628,6 +738,7 @@ auto enumRange() { // Enable 'dyn_cast' and friends on 'ScalarType' by delegating to 'dyn_cast' on // the variant. +// E.g. 'ast::PrimitiveType* prim = dyn_cast(scalarType);' template <> struct llvm::simplify_type { using SimpleType = const dynamatic::ast::ScalarType::Variant; @@ -638,4 +749,14 @@ struct llvm::simplify_type { } }; +template <> +struct llvm::simplify_type { + using SimpleType = const dynamatic::ast::ReturnType::Variant; + + static SimpleType & + getSimplifiedValue(const dynamatic::ast::ReturnType &datatype) { + return datatype.variant; + } +}; + #endif diff --git a/tools/hls-fuzzer/BasicCGenerator.cpp b/tools/hls-fuzzer/BasicCGenerator.cpp index 416a91fca..2e7024f5b 100644 --- a/tools/hls-fuzzer/BasicCGenerator.cpp +++ b/tools/hls-fuzzer/BasicCGenerator.cpp @@ -49,19 +49,23 @@ static ast::Expression safeCastAsNeeded(const ast::ScalarType &to, outputPrim->getMinValue()); } -auto gen::BasicCGenerator::generateFreshScalarParameter( - ast::ScalarType datatype, const OpaqueContext &context) - -> PendingParameter { - scalarParameters.push_back( - {{std::move(datatype), generateFreshVarName()}, context}); - return PendingParameter(*this, scalarParameters.back().first); -} - ast::ReturnStatement gen::BasicCGenerator::generateReturnStatement(const OpaqueContext &context) { - ast::Expression expression = generateExpression(context, 0); - return ast::ReturnStatement{safeCastAsNeeded( - llvm::cast(returnType), std::move(expression))}; + return *generateWithDependencies( + context, typeSystem.getReturnStatementTransferFns(), + /*return value=*/ + [&](const OpaqueContext &context) { + ast::Expression expression = generateExpression(context, 0); + if (maybeReturnType && llvm::isa(*maybeReturnType)) + expression = + safeCastAsNeeded(llvm::cast(*maybeReturnType), + std::move(expression)); + return expression; + }, + /*constructor=*/ + [&](ast::Expression &&expression) { + return ast::ReturnStatement{std::move(expression)}; + }); } constexpr std::size_t MAX_DEPTH = 4; @@ -127,7 +131,7 @@ gen::BasicCGenerator::generateBinaryExpression(ast::BinaryExpression::Op op, return std::nullopt; return generateWithDependencies( - context, typeSystem.getBinaryExpressionContextDependencies(op), + context, typeSystem.getBinaryExpressionTransferFns(op), /*lhs=*/ [&](const OpaqueContext &context) -> ast::Expression { return generateExpression(context, depth + 1); @@ -211,102 +215,127 @@ std::optional gen::BasicCGenerator::generateUnaryExpression(ast::UnaryExpression::Op op, const OpaqueContext &context, std::size_t depth) { - auto conclusion = typeSystem.checkUnaryExpressionOpaque(op, context); - if (!conclusion) + if (typeSystem.discardUnaryExpressionOpaque(op, context)) return std::nullopt; - ast::Expression operand = generateExpression(*conclusion, depth + 1); - - // Perform explicit casts to a legal operand type if the operand type is not - // legal for the given operation. - // This would e.g. cast 'double's that are meant to be applied to '~' to a - // random type that can be legally used with '~'. - if (!ast::UnaryExpression::isLegalOperandType(op, operand.getType())) { - std::optional scalarType = generateScalarType( - context, /*toExclude=*/[&](const ast::ScalarType &value) { - return !ast::UnaryExpression::isLegalOperandType(op, value); - }); - if (!scalarType) - return std::nullopt; - - operand = safeCastAsNeeded(*scalarType, std::move(operand)); - } + return generateWithDependencies( + context, typeSystem.getUnaryExpressionTransferFns(op), + /*operand=*/ + [&](const OpaqueContext &context) { + return generateExpression(context, depth + 1); + }, + /*constructor=*/ + [&](ast::Expression &&operand) -> std::optional { + // Perform explicit casts to a legal operand type if the operand type is + // not legal for the given operation. This would e.g. cast 'double's + // that are meant to be applied to '~' to a random type that can be + // legally used with '~'. + if (!ast::UnaryExpression::isLegalOperandType(op, operand.getType())) { + std::optional scalarType = generateScalarType( + context, /*toExclude=*/[&](const ast::ScalarType &value) { + return !ast::UnaryExpression::isLegalOperandType(op, value); + }); + if (!scalarType) + return std::nullopt; - return ast::UnaryExpression{op, std::move(operand)}; + operand = safeCastAsNeeded(*scalarType, std::move(operand)); + } + + return ast::UnaryExpression{op, std::move(operand)}; + }); } std::optional gen::BasicCGenerator::generateConditionalExpression( const OpaqueContext &context, std::size_t depth) { - auto subcontext = typeSystem.checkConditionalExpressionOpaque(context); - if (!subcontext) + if (typeSystem.discardConditionalExpressionOpaque(context)) return std::nullopt; - auto &&[cond, trueExpr, falseExpr] = *subcontext; - return ast::ConditionalExpression{generateExpression(cond, depth + 1), - generateExpression(trueExpr, depth + 1), - generateExpression(falseExpr, depth + 1)}; + return generateWithDependencies( + context, typeSystem.getConditionalExpressionTransferFns(), + /*condition=*/ + [&](const OpaqueContext &context) { + return generateExpression(context, depth + 1); + }, + /*true value=*/ + [&](const OpaqueContext &context) { + return generateExpression(context, depth + 1); + }, + /*false value=*/ + [&](const OpaqueContext &context) { + return generateExpression(context, depth + 1); + }, + /*constructor=*/ + [&](ast::Expression &&cond, ast::Expression &&trueExpr, + ast::Expression &&falseExpr) { + return ast::ConditionalExpression{std::move(cond), std::move(trueExpr), + std::move(falseExpr)}; + }); } std::optional gen::BasicCGenerator::generateCastExpression(const OpaqueContext &context, std::size_t depth) { - auto subcontext = typeSystem.checkCastExpressionOpaque(context); - if (!subcontext) + if (typeSystem.discardCastExpressionOpaque(context)) return std::nullopt; - auto &&[typeCon, exprCon] = *subcontext; - ast::Expression expression = generateExpression(exprCon, depth + 1); - ast::ScalarType expressionType = expression.getType(); - - // Keep it interesting by not performing noop-casts! - std::optional datatype = - generateScalarType(typeCon, /*toExclude=*/[&](auto &&value) { - return value == expressionType; + return generateWithDependencies( + context, typeSystem.getCastExpressionTransferFns(), + /*data type=*/ + [&](const OpaqueContext &context) { return generateScalarType(context); }, + /*operand=*/ + [&](const OpaqueContext &context) { + return generateExpression(context, depth + 1); + }, + /*constructor=*/ + [](ast::ScalarType &&datatype, ast::Expression &&expression) { + return ast::CastExpression{std::move(datatype), std::move(expression)}; }); - if (!datatype) - return std::nullopt; - - return ast::CastExpression{std::move(*datatype), std::move(expression)}; } -std::optional -gen::BasicCGenerator::generateConstant(const OpaqueContext &context, - std::size_t) const { - auto candidates = ast::PrimitiveType::ALL_PRIMITIVES; - random.shuffle(candidates); +ast::Constant gen::BasicCGenerator::getConstantForType( + const ast::ScalarType &scalarType) const { + return llvm::TypeSwitch(scalarType) + .Case([&](const ast::PrimitiveType *primitive) { + switch (primitive->getType()) { + case ast::PrimitiveType::Int8: + return ast::Constant{random.getInterestingInteger()}; + case ast::PrimitiveType::UInt8: + return ast::Constant{random.getInterestingInteger()}; - for (ast::PrimitiveType::Type iter : candidates) { - std::optional constant = [&] { - switch (iter) { - case ast::PrimitiveType::Int8: - return ast::Constant{random.getInterestingInteger()}; - case ast::PrimitiveType::UInt8: - return ast::Constant{random.getInterestingInteger()}; + case ast::PrimitiveType::Int16: + return ast::Constant{random.getInterestingInteger()}; - case ast::PrimitiveType::Int16: - return ast::Constant{random.getInterestingInteger()}; + case ast::PrimitiveType::UInt16: + return ast::Constant{random.getInterestingInteger()}; - case ast::PrimitiveType::UInt16: - return ast::Constant{random.getInterestingInteger()}; + case ast::PrimitiveType::Int32: + return ast::Constant{random.getInterestingInteger()}; - case ast::PrimitiveType::Int32: - return ast::Constant{random.getInterestingInteger()}; + case ast::PrimitiveType::UInt32: + return ast::Constant{random.getInterestingInteger()}; - case ast::PrimitiveType::UInt32: - return ast::Constant{random.getInterestingInteger()}; + case ast::PrimitiveType::Float: + return ast::Constant{random.getInterestingFloat()}; - case ast::PrimitiveType::Float: - return ast::Constant{random.getInterestingFloat()}; + case ast::PrimitiveType::Double: + return ast::Constant{random.getInterestingDouble()}; + } + llvm_unreachable("all enum cases handled"); + }); +} - case ast::PrimitiveType::Double: - return ast::Constant{random.getInterestingDouble()}; - } - llvm_unreachable("all enum cases handled"); - }(); - if (constant = typeSystem.checkConstantOpaque(*constant, context); constant) +std::optional +gen::BasicCGenerator::generateConstant(const OpaqueContext &context, + std::size_t) const { + auto candidates = ast::PrimitiveType::ALL_PRIMITIVES; + random.shuffle(candidates); + + for (ast::PrimitiveType::Type iter : candidates) + if (std::optional constant = + typeSystem.discardConstantOpaque(getConstantForType(iter), context)) return constant; - } + return std::nullopt; } @@ -317,7 +346,7 @@ gen::BasicCGenerator::generateArrayReadExpression(const OpaqueContext &context, return std::nullopt; return generateWithDependencies( - context, typeSystem.getArrayReadExpressionContextDependencies(), + context, typeSystem.getArrayReadExpressionTransferFns(), /*array parameter=*/ [&](const OpaqueContext &context) -> std::optional { return generateArrayParameter(context); @@ -360,68 +389,88 @@ gen::BasicCGenerator::generateArrayParameter(const OpaqueContext &context, if (!random.getRatherLowProbabilityBool()) { // Randomly shuffle the parameter ordering and find the first parameter // that passes type checking. - std::vector copy(arrayParameters.size()); - llvm::copy(llvm::make_first_range(arrayParameters), copy.begin()); + std::vector copy = arrayParameters; random.shuffle(copy); for (const ast::ArrayParameter &candidateParam : copy) - if (typeSystem.checkArrayParameterOpaque(candidateParam, context)) + if (!typeSystem.discardExistingArrayParameterOpaque(candidateParam, + context)) return candidateParam; } - std::optional elementType = generateScalarType(context); - if (!elementType) + if (typeSystem.discardFreshArrayParameterOpaque(context)) return std::nullopt; - arrayParameters.push_back( - {{std::move(*elementType), generateFreshVarName(), - // Generate a power-of-2 dimension to make the modulo operator fast and - // easy to implement. - // We choose an arbitrary upper-bound of 32 for the dimension for now. - static_cast(1 << random.getInteger(0, 5))}, - context}); - if (!typeSystem.checkArrayParameterOpaque(arrayParameters.back().first, - context)) { - arrayParameters.pop_back(); - varCounter--; - return std::nullopt; - } - return arrayParameters.back().first; + return generateWithDependencies( + context, typeSystem.getArrayParameterTransferFns(), + /*element type=*/ + [&](const OpaqueContext &context) { return generateScalarType(context); }, + /*constructor=*/ + [&](ast::ScalarType &&elementType) { + return arrayParameters.emplace_back( + std::move(elementType), generateFreshVarName(), + // Generate a power-of-2 dimension to make the modulo operator + // fast and easy to implement. We choose an arbitrary upper-bound + // of 32 for the dimension for now. + static_cast(1 << random.getInteger(0, 5))); + }); } std::optional gen::BasicCGenerator::generateScalarParameter(const OpaqueContext &context, std::size_t) { - auto conclusion = typeSystem.checkVariableOpaque(context); - if (!conclusion) + if (typeSystem.discardVariableOpaque(context)) return std::nullopt; - // With a low chance, skip picking an existing parameter and try to generate - // a new one. - if (!random.getRatherLowProbabilityBool()) { - // Randomly shuffle the parameter ordering and find the first parameter - // that passes type checking. - std::vector copy(scalarParameters.size()); - llvm::copy(llvm::make_first_range(scalarParameters), copy.begin()); - random.shuffle(copy); + return generateWithDependencies( + context, typeSystem.getVariableTransferFns(), + /*parameter=*/ + [&](const OpaqueContext &context) -> std::optional { + std::array()>, 2> + generators; + generators[0] = [&]() -> std::optional { + // Randomly shuffle the parameter ordering and find the first + // parameter that passes type checking. + std::vector copy = scalarParameters; + random.shuffle(copy); + + for (const ast::ScalarParameter &iter : copy) + if (!typeSystem.discardExistingScalarParameterOpaque(iter, context)) + return iter; + + return std::nullopt; + }; + generators[1] = [&]() -> std::optional { + if (typeSystem.discardFreshScalarParameterOpaque(context)) + return std::nullopt; - for (ast::ScalarParameter &iter : copy) - if (typeSystem.checkScalarParameterOpaque(iter, *conclusion)) - return ast::Variable{iter.getDataType(), iter.getName().str()}; - } + return generateWithDependencies( + context, typeSystem.getScalarParameterTransferFns(), + /*datatype=*/ + [&](const OpaqueContext &context) { + return generateScalarType(context); + }, + /*constructor=*/ + [&](ast::ScalarType &&datatype) { + return scalarParameters.emplace_back(std::move(datatype), + generateFreshVarName()); + }); + }; - std::optional datatype = generateScalarType(*conclusion); - if (!datatype) - return std::nullopt; + if (random.getRatherLowProbabilityBool()) + std::swap(generators[0], generators[1]); - PendingParameter pendingParam = - generateFreshScalarParameter(*datatype, context); - if (typeSystem.checkScalarParameterOpaque(pendingParam.getParameter(), - *conclusion)) { - ast::ScalarParameter parameter = pendingParam.commit(); - return ast::Variable{parameter.getDataType(), parameter.getName().str()}; - } - return std::nullopt; + for (auto &iter : generators) + if (std::optional result = iter()) + return result; + + return std::nullopt; + }, + /*constructor=*/ + [&](ast::ScalarParameter &¶meter) { + return ast::Variable{parameter.getDataType(), + parameter.getName().str()}; + }); } std::optional gen::BasicCGenerator::generateScalarType( @@ -434,7 +483,7 @@ std::optional gen::BasicCGenerator::generateScalarType( if (toExclude && toExclude(iter)) continue; - if (typeSystem.checkScalarTypeOpaque(iter, context)) + if (!typeSystem.discardScalarTypeOpaque(iter, context)) return iter; } @@ -451,7 +500,7 @@ gen::BasicCGenerator::generateReturnType(const OpaqueContext &context) const { candidates.back() = ast::VoidType{}; random.shuffle(candidates); for (const ast::ReturnType &iter : candidates) - if (typeSystem.checkReturnTypeOpaque(iter, context)) + if (!typeSystem.discardReturnTypeOpaque(iter, context)) return iter; llvm::report_fatal_error( @@ -460,23 +509,29 @@ gen::BasicCGenerator::generateReturnType(const OpaqueContext &context) const { constexpr std::size_t MAX_STATEMENTS = 10; -std::vector -gen::BasicCGenerator::generateStatementList(const OpaqueContext &context) { - std::vector result; - // TODO: Type systems should have better control over the number of - // statements and in what order they are generated. - // Right now they are always generated last statement to first. - std::size_t numStatements = random.getInteger(0, MAX_STATEMENTS); - result.reserve(numStatements); - for (std::size_t i = 0; i < numStatements; i++) { - std::optional maybeStat = generateStatement(context); - if (!maybeStat) - break; - - result.push_back(std::move(*maybeStat)); - } - std::reverse(result.begin(), result.end()); - return result; +ast::StatementList +gen::BasicCGenerator::generateStatementList(const OpaqueContext &context, + size_t depth) { + if (depth > MAX_STATEMENTS) + return ast::StatementList(); + + return generateWithDependencies( + context, typeSystem.getStatementListTransferFns(), + /*statement list=*/ + [&](const OpaqueContext &context) { + return generateStatementList(context, depth + 1); + }, + /*statement=*/ + [&](const OpaqueContext &context) { + return generateStatement(context); + }, + /*constructor=*/ + [&](ast::StatementList &&statements, ast::Statement &&statement) { + std::vector result = statements.takeVector(); + result.push_back(std::move(statement)); + return ast::StatementList(std::move(result)); + }) + .value_or(ast::StatementList()); } std::optional @@ -487,48 +542,68 @@ gen::BasicCGenerator::generateStatement(const OpaqueContext &context) { std::optional gen::BasicCGenerator::generateArrayAssignmentStatement( const OpaqueContext &context) { - auto conclusion = typeSystem.checkArrayAssignmentStatementOpaque(context); - if (!conclusion) + if (typeSystem.discardArrayAssignmentStatementOpaque(context)) return std::nullopt; - auto &&[param, index, value] = *conclusion; - std::optional parameter = generateArrayParameter(param); - if (!parameter) - return std::nullopt; - - ast::Expression castAsNeeded = safeCastAsNeeded( - /*to=*/ast::PrimitiveType::UInt32, - generateExpression(/*context=*/index, /*depth=*/0)); - castAsNeeded = ast::BinaryExpression{ - std::move(castAsNeeded), ast::BinaryExpression::BitAnd, - ast::Constant{static_cast(parameter->getDimension() - 1)}}; - return ast::ArrayAssignmentStatement{ - parameter->getName().str(), - castAsNeeded, - generateExpression(value, 0), - }; + return generateWithDependencies( + context, typeSystem.getArrayAssignmentStatementTransferFns(), + /*array parameter=*/ + [&](const OpaqueContext &context) { + return generateArrayParameter(context); + }, + /*index=*/ + [&](const OpaqueContext &context) { + return safeCastAsNeeded( + /*to=*/ast::PrimitiveType::UInt32, + generateExpression(context, /*depth=*/0)); + }, + /*value=*/ + [&](const OpaqueContext &context) { + return generateExpression(context, 0); + }, + /*constructor=*/ + [&](ast::ArrayParameter &¶m, ast::Expression &&index, + ast::Expression &&value) { + index = ast::BinaryExpression{std::move(index), + ast::BinaryExpression::BitAnd, + ast::Constant{static_cast( + param.getDimension() - 1)}}; + return ast::ArrayAssignmentStatement{ + param.getName().str(), + std::move(index), + std::move(value), + }; + }); } ast::Function gen::BasicCGenerator::generate(std::string_view functionName) { - auto conclusion = typeSystem.checkFunctionOpaque(entryContext); - returnType = generateReturnType(conclusion.returnType); - std::optional returnStatement; - if (!std::holds_alternative(returnType)) - returnStatement = generateReturnStatement(conclusion.returnStatement); - - std::vector statementList = - generateStatementList(conclusion.returnStatement); - - auto scalarRange = llvm::make_first_range(scalarParameters); - auto arrayRange = llvm::make_first_range(arrayParameters); - return ast::Function{ - returnType, - std::string(functionName), - std::vector(scalarRange.begin(), scalarRange.end()), - std::vector(arrayRange.begin(), arrayRange.end()), - statementList, - std::move(returnStatement), - }; + return *generateWithDependencies( + entryContext, typeSystem.getFunctionTransferFns(), + /*return type=*/ + [&](const OpaqueContext &context) { + return maybeReturnType = generateReturnType(context); + }, + /*statement list=*/ + [&](const OpaqueContext &context) { + return generateStatementList(context, 0); + }, + /*return statement=*/ + [&](const OpaqueContext &context) { + return generateReturnStatement(context); + }, + /*constructor=*/ + [&](ast::ReturnType &&returnType, ast::StatementList &&statements, + ast::ReturnStatement &&returnStatement) { + std::optional maybeReturnStatement = std::move(returnStatement); + if (returnType == ast::VoidType{}) + maybeReturnStatement.reset(); + + return ast::Function{ + std::move(returnType), std::string(functionName), + scalarParameters, arrayParameters, + statements.takeVector(), std::move(maybeReturnStatement), + }; + }); } std::string @@ -538,30 +613,19 @@ gen::BasicCGenerator::generateTestBench(const ast::Function &kernel) const { ss << "\nint main() {\n"; mlir::raw_indented_ostream os(ss); os.indent(); - for (const auto &[parameter, context] : scalarParameters) { - std::optional constant; - while (!constant) { - constant = generateConstant(context); - } - + for (const ast::ScalarParameter ¶meter : scalarParameters) { os << parameter.getDataType() << ' ' << parameter.getName() << " = " - << *constant << ";\n"; + << getConstantForType(parameter.getDataType()) << ";\n"; } - for (const auto &[parameter, context] : arrayParameters) { + for (const ast::ArrayParameter ¶meter : arrayParameters) { os << parameter.getElementType() << ' ' << parameter.getName() << "[" << parameter.getDimension() << "] = {"; - llvm::interleaveComma( - llvm::seq(0, parameter.getDimension()), os, - [&, &context = context, ¶meter = parameter](auto &&) { - std::optional constant; - while (!constant) { - constant = generateConstant(context); - } - // C++ does not allow implicit casts in array constructors, so we - // must cast the constant explicitly. - os << safeCastAsNeeded(parameter.getElementType(), *constant); - }); + llvm::interleaveComma(llvm::seq(0, parameter.getDimension()), + os, [&, ¶meter = parameter](auto &&) { + os << getConstantForType( + parameter.getElementType()); + }); os << "};\n"; } diff --git a/tools/hls-fuzzer/BasicCGenerator.h b/tools/hls-fuzzer/BasicCGenerator.h index 8a3dd44cf..02567239b 100644 --- a/tools/hls-fuzzer/BasicCGenerator.h +++ b/tools/hls-fuzzer/BasicCGenerator.h @@ -77,9 +77,6 @@ class BasicCGenerator { std::optional parameter; }; - PendingParameter generateFreshScalarParameter(ast::ScalarType datatype, - const OpaqueContext &context); - ast::ReturnStatement generateReturnStatement(const OpaqueContext &constraints); @@ -101,6 +98,8 @@ class BasicCGenerator { std::optional generateCastExpression(const OpaqueContext &constraint, std::size_t depth); + ast::Constant getConstantForType(const ast::ScalarType &scalarType) const; + std::optional generateConstant(const OpaqueContext &constraint, std::size_t depth = 0) const; @@ -126,8 +125,8 @@ class BasicCGenerator { ast::ReturnType generateReturnType(const OpaqueContext &context) const; - std::vector - generateStatementList(const OpaqueContext &context); + ast::StatementList generateStatementList(const OpaqueContext &context, + size_t depth); std::optional generateStatement(const OpaqueContext &context); @@ -135,9 +134,9 @@ class BasicCGenerator { generateArrayAssignmentStatement(const OpaqueContext &context); Randomly &random; - ast::ReturnType returnType; - std::vector> scalarParameters; - std::vector> arrayParameters; + std::optional maybeReturnType; + std::vector scalarParameters; + std::vector arrayParameters; std::size_t varCounter = 0; AbstractTypeSystem &typeSystem; OpaqueContext entryContext; diff --git a/tools/hls-fuzzer/TypeSystem.h b/tools/hls-fuzzer/TypeSystem.h index 5786efd2d..8609184bb 100644 --- a/tools/hls-fuzzer/TypeSystem.h +++ b/tools/hls-fuzzer/TypeSystem.h @@ -1,9 +1,10 @@ #ifndef DYNAMATIC_HLS_FUZZER_TYPE_SYSTEM_GUIDED_GENERATOR #define DYNAMATIC_HLS_FUZZER_TYPE_SYSTEM_GUIDED_GENERATOR +#include "AST.h" #include "Randomly.h" -#include "TypeSystemTraits.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/FunctionExtras.h" #include @@ -84,6 +85,18 @@ class TransferFn { remaining...>::type; }; + // Special case required to still allow parent dependencies when 'ASTNode' + // does not have any subelements. + template + struct CalcCompFn { + // Recursive case. + using type = typename CalcCompFn< + decltype(std::tuple_cat( + std::declval(), + std::declval>())), + 0>::type; + }; + // Terminating end-case template struct CalcCompFn, current> { @@ -272,78 +285,170 @@ using TransferFnArray = /// and 'discard*' methods. class AbstractTypeSystem { protected: - /// Returns an instance of 'TransferFn' which simply forwards the context from - /// the parent to the subelement. + /// Returns an instance of 'TransferFn' which simply forwards the context + /// from the parent to the subelement. template static auto copyFromParent() { - return TransferFn( - [](const OpaqueContext &context) { return context; }); + return copyFrom(); + } + + /// Returns an instance of 'TransferFn' which forwards the context + /// from the given index to the subelement. + template + static auto copyFrom() { + return TransferFn( + [](const OpaqueContext &context, auto &&...) { return context; }); } public: virtual ~AbstractTypeSystem(); - virtual ConclusionOf - checkFunctionOpaque(const OpaqueContext &context) = 0; + virtual TransferFnArray getFunctionTransferFns() { + return { + /*return type=*/copyFromParent(), + /*statement list=*/copyFromParent(), + /*return statement=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } + + virtual TransferFnArray + getReturnStatementTransferFns() { + return { + /*return value=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } - /// Returns true if the generator should discard this binary expression based - /// on the given input context. + /// Returns true if the generator should discard this binary expression + /// based on the given input context. virtual bool discardBinaryExpressionOpaque(ast::BinaryExpression::Op op, const OpaqueContext &context) = 0; - virtual std::optional> - checkUnaryExpressionOpaque(ast::UnaryExpression::Op op, - const OpaqueContext &context) = 0; - virtual TransferFnArray - getBinaryExpressionContextDependencies(ast::BinaryExpression::Op op) { + getBinaryExpressionTransferFns(ast::BinaryExpression::Op op) { // Default implementation: Simply propagates the context to the subelements. return {/*lhs=*/copyFromParent(), /*rhs=*/copyFromParent(), /*output=*/copyFromParent()}; } - virtual std::optional> - checkVariableOpaque(const OpaqueContext &context) = 0; + virtual bool discardUnaryExpressionOpaque(ast::UnaryExpression::Op op, + const OpaqueContext &context) = 0; - virtual std::optional> - checkCastExpressionOpaque(const OpaqueContext &context) = 0; + virtual TransferFnArray + getUnaryExpressionTransferFns(ast::UnaryExpression::Op op) { + return { + /*operand=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } - virtual std::optional> - checkConditionalExpressionOpaque(const OpaqueContext &context) = 0; + virtual bool discardVariableOpaque(const OpaqueContext &context) = 0; - virtual std::optional> - checkScalarTypeOpaque(const ast::ScalarType &, - const OpaqueContext &context) = 0; + virtual TransferFnArray getVariableTransferFns() { + return { + /*parameter=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } + + virtual bool discardCastExpressionOpaque(const OpaqueContext &context) = 0; - virtual std::optional> - checkReturnTypeOpaque(const ast::ReturnType &, + virtual TransferFnArray getCastExpressionTransferFns() { + return { + /*target type=*/copyFromParent(), + /*operand=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } + + virtual bool + discardConditionalExpressionOpaque(const OpaqueContext &context) = 0; + + virtual TransferFnArray + getConditionalExpressionTransferFns() { + // Default implementation: Simply propagates the context to the + // subelements. + return { + /*condition=*/copyFromParent(), + /*true value=*/copyFromParent(), + /*false value=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } + + virtual bool discardScalarTypeOpaque(const ast::ScalarType &scalarType, + const OpaqueContext &context) = 0; + + virtual bool discardReturnTypeOpaque(const ast::ReturnType &, + const OpaqueContext &context) = 0; + + virtual std::optional + discardConstantOpaque(const ast::Constant &, const OpaqueContext &context) = 0; - virtual std::optional> - checkConstantOpaque(const ast::Constant &, const OpaqueContext &context) = 0; + virtual bool + discardExistingScalarParameterOpaque(const ast::ScalarParameter &, + const OpaqueContext &context) = 0; + + virtual bool + discardFreshScalarParameterOpaque(const OpaqueContext &context) = 0; - virtual std::optional> - checkScalarParameterOpaque(const ast::ScalarParameter &, - const OpaqueContext &context) = 0; + virtual TransferFnArray + getScalarParameterTransferFns() { + return { + /*data type=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } virtual bool discardArrayReadExpressionOpaque(const OpaqueContext &context) = 0; virtual TransferFnArray - getArrayReadExpressionContextDependencies() { + getArrayReadExpressionTransferFns() { return {/*array parameter=*/copyFromParent(), /*index=*/copyFromParent(), /*output=*/copyFromParent()}; } - virtual std::optional> - checkArrayParameterOpaque(const ast::ArrayParameter &, - const OpaqueContext &context) = 0; + virtual bool + discardExistingArrayParameterOpaque(const ast::ArrayParameter &, + const OpaqueContext &context) = 0; + + virtual bool + discardFreshArrayParameterOpaque(const OpaqueContext &context) = 0; + + virtual TransferFnArray getArrayParameterTransferFns() { + return { + /*element type=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } + + virtual bool + discardArrayAssignmentStatementOpaque(const OpaqueContext &context) = 0; + + virtual TransferFnArray + getArrayAssignmentStatementTransferFns() { + return TransferFnArray{ + /*array parameter=*/copyFromParent(), + /*index=*/copyFromParent(), + /*value=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } + + virtual bool discardStatementListOpaque(const OpaqueContext &context) = 0; - virtual std::optional< - ConclusionOf> - checkArrayAssignmentStatementOpaque(const OpaqueContext &context) = 0; + virtual TransferFnArray getStatementListTransferFns() { + return TransferFnArray{ + /*statement list=*/copyFromParent(), + /*statement=*/copyFromParent(), + /*output=*/copyFromParent(), + }; + } }; /// CRTP-Base class for all implementations of a type system. @@ -356,20 +461,13 @@ class AbstractTypeSystem { /// used when generating sub-elements of an AST-node or 2) rejecting AST-nodes /// entirely based on the current type context. /// -/// There are currently two APIs to achieve this: -/// 1) The transfer functions API -/// 2) the 'check*' API. -/// The latter is considered deprecated and implements a subset of -/// functionality of the transfer functions API. -/// -/// Regardless of API, all type checking is performed under a given context -/// specified as the 'TypingContext' template parameter. -/// Every AST node is initially generated using an input context -/// passed into the 'check*' method or 'discard*' method of the AST node which -/// may discard the AST node. -/// Otherwise, new contexts for the subelements of the AST node can be derived. +/// All type checking is performed under a given context specified as the +/// 'TypingContext' template parameter. Every AST node is initially generated +/// using an input context passed into the 'discard*' method of the AST node +/// which may discard the AST node. Otherwise, new contexts for the subelements +/// of the AST node can be derived. /// -/// The transfer functions API allows specifying how input contexts for AST +/// The transfer functions allow specifying how input contexts for AST /// elements should be calculated. /// Specifically, an instance of 'TransferFn' can specify that it depends on the /// context and AST node of a sibling subelement in addition to, or instead of @@ -380,20 +478,14 @@ class AbstractTypeSystem { /// 'a'. /// The generator uses this knowledge to generate the AST node of 'a' before /// 'i'. -/// 'check*' methods in contrast only implement deriving subelement contexts -/// from the parent input context. -/// They return a tuple of 'TypingContext's for every subelement of 'ASTNode', -/// the so-called conclusion type. /// /// Note: We call it contexts rather than constraints to match literature, and /// as it more generally informs an AST-node generation about the type-system /// state rather than necessarily putting requirements on an AST-node -/// generation. In the future, it'll likely be possible to also output contexts -/// from sub-expressions to parent-expressions. -/// An example of such a context would e.g. be the set of all variables used. +/// generation. /// -/// The logic that should be implemented in the 'check*' methods can be thought -/// of as inversions of the usual type checking rules seen in literature. +/// The logic that should be implemented can be thought of as inversions of the +/// usual type checking rules seen in literature. /// E.g. assuming a type system where the context is a two-state variable that /// requires the expression to either be an integer type or a floating point /// type, then a typing rule for conditional expressions might look as follows: @@ -406,17 +498,11 @@ class AbstractTypeSystem { /// ({integer} |- cond) -> ({A} |- lhs) -> ({A} |- rhs) -> ({A} |- cond ? lhs : /// rhs) /// -/// The corresponding 'checkConditionalExpression' method instead implements: +/// The corresponding 'getConditionalExpressionTransferFns' method +/// instead implements: /// ({A} |- cond ? lhs : rhs) -> ({integer} |- cond) -> ({A} |- lhs) -> ({A} |- -/// rhs) where 'A' is the context passed into the function and the three clauses -/// correspond to the conclusion type of conditional expressions. -/// -/// Check methods for terminals are slightly special: They take as input an -/// already generated terminal node and are always discardable. -/// All check methods have a default implementation that forwards the current -/// constraint to all sub-elements. -/// See the 'TypeSystemTraits' specializations to find the documentation for -/// various AST-Node's conclusion types. +/// rhs) where 'A' is the input context and the three clauses correspond to the +/// input contexts of the sub elements. /// /// The current implementation how a type system is used in the base generator /// has a few constraints: @@ -426,12 +512,7 @@ class AbstractTypeSystem { /// return type. template class TypeSystem : public AbstractTypeSystem { - public: - /// The conclusion type of 'ASTNode' with the given context. - template - using ConclusionOf = ConclusionOf; - template using TransferFn = TransferFn; @@ -443,240 +524,157 @@ class TypeSystem : public AbstractTypeSystem { // since we use CRTP-techniques to call these. They may be but are not // required to be static. - static ConclusionOf - checkFunction(const TypingContext &context) { - return {context, context}; - } - static bool discardBinaryExpression(ast::BinaryExpression::Op, const TypingContext &) { return false; } - static ConclusionOf - checkUnaryExpression(ast::UnaryExpression::Op, const TypingContext &context) { - return {context}; + static bool discardUnaryExpression(ast::UnaryExpression::Op, + const TypingContext &) { + return false; } - static ConclusionOf - checkVariable(const TypingContext &context) { - return {context}; - } + static bool discardVariable(const TypingContext &) { return false; } - static ConclusionOf - checkCastExpression(const TypingContext &context) { - return {context, context}; - } + static bool discardCastExpression(const TypingContext &) { return false; } - static ConclusionOf - checkConditionalExpression(const TypingContext &context) { - return {context, context, context}; + static bool discardConditionalExpression(const TypingContext &) { + return false; } - static ConclusionOf checkScalarType(const ast::ScalarType &, - const TypingContext &) { - return {}; + static bool discardScalarType(const ast::ScalarType &, + const TypingContext &) { + return false; } - std::optional> - checkReturnType(const ast::ReturnType &returnType, - const TypingContext &context) { + bool discardReturnType(const ast::ReturnType &returnType, + const TypingContext &context) { // Default implementation dispatches to 'checkScalarType'. - return llvm::TypeSwitch>>( - returnType) - .Case([](const ast::VoidType *) { - return ConclusionOf{}; - }) - .Case([&](const ast::ScalarType *scalar) - -> std::optional> { - if (std::optional optional = self().checkScalarType(*scalar, context); - !optional) - return std::nullopt; - - return ConclusionOf{}; + return llvm::TypeSwitch(returnType) + .Case([](const ast::VoidType *) { return false; }) + .Case([&](const ast::ScalarType *scalar) { + return self().discardScalarType(*scalar, context); }); } - std::optional> - checkConstant(const ast::Constant &constant, const TypingContext &context) { - if (std::optional optional = - self().checkScalarType(constant.getType(), context); - !optional) + std::optional discardConstant(const ast::Constant &constant, + const TypingContext &context) { + if (self().discardScalarType(constant.getType(), context)) return std::nullopt; return constant; } - std::optional> - checkScalarParameter(const ast::ScalarParameter ¶meter, - const TypingContext &context) { - if (std::optional optional = - self().checkScalarType(parameter.getDataType(), context); - !optional) - return std::nullopt; + bool discardExistingScalarParameter(const ast::ScalarParameter ¶meter, + const TypingContext &context) { + return self().discardScalarType(parameter.getDataType(), context); + } - return context; + static bool discardFreshScalarParameter(const TypingContext &) { + return false; } static bool discardArrayReadExpression(const TypingContext &) { return false; } - std::optional> - checkArrayParameter(const ast::ArrayParameter ¶meter, - const TypingContext &context) { - if (std::optional optional = - self().checkScalarType(parameter.getElementType(), context); - !optional) - return std::nullopt; + bool discardExistingArrayParameter(const ast::ArrayParameter ¶meter, + const TypingContext &context) { + return self().discardScalarType(parameter.getElementType(), context); + } - return context; + static bool discardFreshArrayParameter(const TypingContext &) { + return false; } - static ConclusionOf - checkArrayAssignmentStatement(const TypingContext &context) { - return {context, context, context}; + static bool discardArrayAssignmentStatement(const TypingContext &) { + return false; } + static bool discardStatementList(const TypingContext &) { return false; } + // Implementations of the virtual methods in 'AbstractTypeSystem'. // These are automatically implemented to unbox the 'TypingContext's out of // the opaque contexts, calling the corresponding non-opaque 'check*' method // and boxing the result into an opaque context again. - dynamatic::ConclusionOf - checkFunctionOpaque(const OpaqueContext &context) final { - return convert(self().checkFunction(context.cast())); - } - bool discardBinaryExpressionOpaque(ast::BinaryExpression::Op op, const OpaqueContext &context) final { return self().discardBinaryExpression(op, context.cast()); } - std::optional> - checkUnaryExpressionOpaque(ast::UnaryExpression::Op op, - const OpaqueContext &context) final { - return convert( - self().checkUnaryExpression(op, context.cast())); + bool discardUnaryExpressionOpaque(ast::UnaryExpression::Op op, + const OpaqueContext &context) final { + return self().discardUnaryExpression(op, context.cast()); } - std::optional> - checkVariableOpaque(const OpaqueContext &context) final { - return convert(self().checkVariable(context.cast())); + bool discardVariableOpaque(const OpaqueContext &context) final { + return self().discardVariable(context.cast()); } - std::optional> - checkCastExpressionOpaque(const OpaqueContext &context) final { - return convert(self().checkCastExpression(context.cast())); + bool discardCastExpressionOpaque(const OpaqueContext &context) final { + return self().discardCastExpression(context.cast()); } - std::optional< - dynamatic::ConclusionOf> - checkConditionalExpressionOpaque(const OpaqueContext &context) final { - return convert( - self().checkConditionalExpression(context.cast())); + bool discardConditionalExpressionOpaque(const OpaqueContext &context) final { + return self().discardConditionalExpression(context.cast()); } - std::optional> - checkScalarTypeOpaque(const ast::ScalarType &node, - const OpaqueContext &context) final { - return convert(self().checkScalarType(node, context.cast())); + bool discardScalarTypeOpaque(const ast::ScalarType &node, + const OpaqueContext &context) final { + return self().discardScalarType(node, context.cast()); + } + + bool discardReturnTypeOpaque(const ast::ReturnType &node, + const OpaqueContext &context) final { + return self().discardReturnType(node, context.cast()); } - std::optional> - checkReturnTypeOpaque(const ast::ReturnType &node, + std::optional + discardConstantOpaque(const ast::Constant &node, const OpaqueContext &context) final { - return convert(self().checkReturnType(node, context.cast())); + return self().discardConstant(node, context.cast()); } - std::optional> - checkConstantOpaque(const ast::Constant &node, - const OpaqueContext &context) final { - return convert(self().checkConstant(node, context.cast())); + bool + discardExistingScalarParameterOpaque(const ast::ScalarParameter &node, + const OpaqueContext &context) final { + return self().discardExistingScalarParameter(node, + context.cast()); } - std::optional> - checkScalarParameterOpaque(const ast::ScalarParameter &node, - const OpaqueContext &context) final { - return convert( - self().checkScalarParameter(node, context.cast())); + bool discardFreshScalarParameterOpaque(const OpaqueContext &context) final { + return self().discardFreshScalarParameter(context.cast()); } bool discardArrayReadExpressionOpaque(const OpaqueContext &context) final { return self().discardArrayReadExpression(context.cast()); } - std::optional> - checkArrayParameterOpaque(const ast::ArrayParameter &node, - const OpaqueContext &context) final { - return convert( - self().checkArrayParameter(node, context.cast())); + bool discardExistingArrayParameterOpaque(const ast::ArrayParameter &node, + const OpaqueContext &context) final { + return self().discardExistingArrayParameter(node, + context.cast()); } - std::optional< - dynamatic::ConclusionOf> - checkArrayAssignmentStatementOpaque(const OpaqueContext &context) final { - return convert( - self().checkArrayAssignmentStatement(context.cast())); + bool discardFreshArrayParameterOpaque(const OpaqueContext &context) final { + return self().discardFreshArrayParameter(context.cast()); } -private: - Self &self() { return static_cast(*this); } - - const Self &self() const { return static_cast(*this); } - - static OpaqueContext convert(const TypingContext &context) { - return OpaqueContext(context); - } - - static OpaqueContext convert(TypingContext &&context) { - return OpaqueContext(context); + bool + discardArrayAssignmentStatementOpaque(const OpaqueContext &context) final { + return self().discardArrayAssignmentStatement( + context.cast()); } - template - static auto convert(const T &value) { - return value; + bool discardStatementListOpaque(const OpaqueContext &context) final { + return self().discardStatementList(context.cast()); } - template - static auto convert(std::optional &&value) { - using Ret = decltype(convert(std::move(*value))); - if (!value) - return std::optional{}; - - return std::optional(convert(std::move(*value))); - } - - /// Converts all instances of 'TypingContext' of a tuple-like struct into - /// 'OpaqueContext'. - /// Tuple-like structs are structs that specialize 'std::tuple_size' and - /// implement 'get' methods. - template