Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tools/hls-fuzzer/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ class BinaryExpression {
/// operand for 'op'.
static bool isLegalOperandType(Op op, const ScalarType &datatype);

using SubElements = std::tuple<Expression, Expression>;

private:
Expression lhs;
Op op;
Expand Down Expand Up @@ -405,6 +407,8 @@ class ConditionalExpression {
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const ConditionalExpression &ternaryExpression);

class ArrayParameter;

/// Expression representing reading and indexing into an array.
/// Only array parameters and one-dimensional arrays are currently supported.
class ArrayReadExpression {
Expand All @@ -426,6 +430,11 @@ class ArrayReadExpression {
/// element type of the array.
const ScalarType &getType() const { return dataType; }

using SubElements = std::tuple<ArrayParameter, Expression>;

constexpr static std::size_t ARRAY_PARAMETER = 0;
constexpr static std::size_t INDEX = 1;

private:
ScalarType dataType;
std::string arrayParameter;
Expand Down
219 changes: 115 additions & 104 deletions tools/hls-fuzzer/BasicCGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,79 +123,88 @@ std::optional<ast::Expression>
gen::BasicCGenerator::generateBinaryExpression(ast::BinaryExpression::Op op,
const OpaqueContext &context,
std::size_t depth) {
auto conclusion = typeSystem.checkBinaryExpressionOpaque(op, context);
if (!conclusion)
if (typeSystem.discardBinaryExpressionOpaque(op, context))
return std::nullopt;
auto [lhsCons, rhsCons] = *conclusion;

ast::Expression lhs = generateExpression(lhsCons, depth + 1);
ast::Expression rhs = generateExpression(rhsCons, depth + 1);

// Perform explicit casts to a legal operand type if neither of the
// expressions are legal for the given operation.
// This would e.g. cast 'double's that are meant to be applied to '&' to a
// random type that can be legally used with '&'.
if (!ast::BinaryExpression::isLegalOperandType(op, lhs.getType()) ||
!ast::BinaryExpression::isLegalOperandType(op, rhs.getType())) {

std::optional<ast::ScalarType> scalarType = generateScalarType(
context, /*toExclude=*/[&](const ast::ScalarType &value) {
return !ast::BinaryExpression::isLegalOperandType(op, value);
});
if (!scalarType)
return std::nullopt;

lhs = safeCastAsNeeded(*scalarType, std::move(lhs));
rhs = safeCastAsNeeded(*scalarType, std::move(rhs));
}

switch (op) {
case ast::BinaryExpression::ShiftLeft:
case ast::BinaryExpression::ShiftRight: {
ast::ScalarType datatype = lhs.getType();
// Restrict the right expression to be in range of the bitwidth.
rhs = ast::BinaryExpression{
std::move(rhs), ast::BinaryExpression::BitAnd,
ast::Constant{static_cast<uint32_t>(datatype.getBitwidth() - 1)}};

// If the left-hand side is a signed integer, make sure the value is at
// least 0.
// Performing a left-shift on a negative value in C is undefined behavior.
if (op == ast::BinaryExpression::ShiftLeft && datatype.isSigned())
lhs = generateMinExpression(std::move(lhs),
ast::Constant{static_cast<uint32_t>(0)});
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
case ast::BinaryExpression::Plus:
case ast::BinaryExpression::Minus:
case ast::BinaryExpression::Mul: {
ast::ScalarType lhsType = lhs.getType();
ast::ScalarType rhsType = rhs.getType();
if ((lhsType == ast::PrimitiveType::Int32 &&
lhsType.getBitwidth() > rhsType.getBitwidth()) ||
(rhsType == ast::PrimitiveType::Int32 &&
rhsType.getBitwidth() > lhsType.getBitwidth())) {
// Promote integers where one operand is an 'int32_t' to 'uint32_t' to
// avoid undefined behavior on overflow.
lhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(lhs));
rhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(rhs));
}
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
// case ast::BinaryExpression::Division:
break;
case ast::BinaryExpression::BitAnd:
case ast::BinaryExpression::BitOr:
case ast::BinaryExpression::BitXor:
case ast::BinaryExpression::Greater:
case ast::BinaryExpression::GreaterEqual:
case ast::BinaryExpression::Less:
case ast::BinaryExpression::LessEqual:
case ast::BinaryExpression::Equal:
case ast::BinaryExpression::NotEqual:
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
llvm_unreachable("all enum cases handled");
return generateWithDependencies<ast::BinaryExpression>(
context, typeSystem.getBinaryExpressionContextDependencies(op),
/*lhs=*/
[&](const OpaqueContext &context) -> ast::Expression {
return generateExpression(context, depth + 1);
},
/*rhs=*/
[&](const OpaqueContext &context) -> ast::Expression {
return generateExpression(context, depth + 1);
},
/*constructor=*/
[&](ast::Expression &&lhs,
ast::Expression &&rhs) -> std::optional<ast::BinaryExpression> {
// Perform explicit casts to a legal operand type if neither of the
// expressions are legal for the given operation.
// This would e.g. cast 'double's that are meant to be applied to '&' to
// a random type that can be legally used with '&'.
if (!ast::BinaryExpression::isLegalOperandType(op, lhs.getType()) ||
!ast::BinaryExpression::isLegalOperandType(op, rhs.getType())) {

std::optional<ast::ScalarType> scalarType = generateScalarType(
context, /*toExclude=*/[&](const ast::ScalarType &value) {
return !ast::BinaryExpression::isLegalOperandType(op, value);
});
if (!scalarType)
return std::nullopt;

lhs = safeCastAsNeeded(*scalarType, std::move(lhs));
rhs = safeCastAsNeeded(*scalarType, std::move(rhs));
}

switch (op) {
case ast::BinaryExpression::ShiftLeft:
case ast::BinaryExpression::ShiftRight: {
ast::ScalarType datatype = lhs.getType();
// Restrict the right expression to be in range of the bitwidth.
rhs = ast::BinaryExpression{
std::move(rhs), ast::BinaryExpression::BitAnd,
ast::Constant{static_cast<uint32_t>(datatype.getBitwidth() - 1)}};

// If the left-hand side is a signed integer, make sure the value is
// at least 0. Performing a left-shift on a negative value in C is
// undefined behavior.
if (op == ast::BinaryExpression::ShiftLeft && datatype.isSigned())
lhs = generateMinExpression(
std::move(lhs), ast::Constant{static_cast<uint32_t>(0)});
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
case ast::BinaryExpression::Plus:
case ast::BinaryExpression::Minus:
case ast::BinaryExpression::Mul: {
ast::ScalarType lhsType = lhs.getType();
ast::ScalarType rhsType = rhs.getType();
if ((lhsType == ast::PrimitiveType::Int32 &&
lhsType.getBitwidth() > rhsType.getBitwidth()) ||
(rhsType == ast::PrimitiveType::Int32 &&
rhsType.getBitwidth() > lhsType.getBitwidth())) {
// Promote integers where one operand is an 'int32_t' to 'uint32_t'
// to avoid undefined behavior on overflow.
lhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(lhs));
rhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(rhs));
}
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
// case ast::BinaryExpression::Division:
break;
case ast::BinaryExpression::BitAnd:
case ast::BinaryExpression::BitOr:
case ast::BinaryExpression::BitXor:
case ast::BinaryExpression::Greater:
case ast::BinaryExpression::GreaterEqual:
case ast::BinaryExpression::Less:
case ast::BinaryExpression::LessEqual:
case ast::BinaryExpression::Equal:
case ast::BinaryExpression::NotEqual:
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
llvm_unreachable("all enum cases handled");
});
Comment on lines +132 to +207
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split

auto contextTransferFn = ...;
auto lhsGenFn = ...;
auto rhsGenFn = ...'
auto expressionGenFn = ...;

return generateWithDependencies<ast::BinaryExpression>(context, contextTransferFn, lhsGenFn, rhsGenFn, expressionGenFn);


Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #874 (comment)

I added argument comments instead.

}

std::optional<ast::Expression>
Expand Down Expand Up @@ -304,41 +313,43 @@ gen::BasicCGenerator::generateConstant(const OpaqueContext &context,
std::optional<ast::ArrayReadExpression>
gen::BasicCGenerator::generateArrayReadExpression(const OpaqueContext &context,
std::size_t depth) {
auto conclusion = typeSystem.checkArrayReadExpressionOpaque(context);
if (!conclusion)
if (typeSystem.discardArrayReadExpressionOpaque(context))
return std::nullopt;
auto [paramConc, indexConc] = *conclusion;

// Construct a safe indexing expression from an array parameter.
auto genWrappedArrayReadFromParam = [&, &indexConc = indexConc](
const ast::ArrayParameter &param) {
ast::ScalarType elementType = param.getElementType();
std::size_t mask = param.getDimension() - 1;
std::string name = param.getName().str();
// Generate an indexing expression.
// Has to be an integer.
ast::Expression index = safeCastAsNeeded(
ast::PrimitiveType::UInt32, generateExpression(indexConc, depth + 1));

// Bitmask the index to be in range of the array! We use this to avoid
// undefined behavior in our programs. In the future we could also add
// mechanisms (type systems, or whatever), that restrict expressions to
// safe in-range expressions.
//
// Note: We can use a bitmask here since array parameters that we generate
// are all powers-of-2. We do so since the modulo operator is currently
// unsupported in dynamatic.
return ast::ArrayReadExpression{
std::move(elementType), name,
ast::BinaryExpression{std::move(index), ast::BinaryExpression::BitAnd,
ast::Constant{static_cast<std::uint32_t>(mask)}}};
};

std::optional<ast::ArrayParameter> arrayParameter =
generateArrayParameter(paramConc);
if (!arrayParameter)
return std::nullopt;
return genWrappedArrayReadFromParam(*arrayParameter);
return generateWithDependencies<ast::ArrayReadExpression>(
context, typeSystem.getArrayReadExpressionContextDependencies(),
/*array parameter=*/
[&](const OpaqueContext &context) -> std::optional<ast::ArrayParameter> {
return generateArrayParameter(context);
},
/*index=*/
[&](const OpaqueContext &context) -> std::optional<ast::Expression> {
return generateExpression(context, depth + 1);
},
/*constructor=*/
[&](ast::ArrayParameter &&param, ast::Expression &&expression) {
ast::ScalarType elementType = param.getElementType();
std::size_t mask = param.getDimension() - 1;
std::string name = param.getName().str();
// Generate an indexing expression.
// Has to be an integer.
ast::Expression index =
safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(expression));

// Bitmask the index to be in range of the array! We use this to avoid
// undefined behavior in our programs. In the future we could also add
// mechanisms (type systems, or whatever), that restrict expressions to
// safe in-range expressions.
//
// Note: We can use a bitmask here since array parameters that we
// generate are all powers-of-2. We do so since the modulo operator is
// currently unsupported in dynamatic.
return ast::ArrayReadExpression{
std::move(elementType), name,
ast::BinaryExpression{
std::move(index), ast::BinaryExpression::BitAnd,
ast::Constant{static_cast<std::uint32_t>(mask)}}};
});
}

std::optional<ast::ArrayParameter>
Expand Down
Loading
Loading