Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tools/hls-fuzzer/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ class BinaryExpression {
/// operand for 'op'.
static bool isLegalOperandType(Op op, const ScalarType &datatype);

using SubElements = std::tuple<Expression, Expression>;

private:
Expression lhs;
Op op;
Expand Down Expand Up @@ -405,6 +407,8 @@ class ConditionalExpression {
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const ConditionalExpression &ternaryExpression);

class ArrayParameter;

/// Expression representing reading and indexing into an array.
/// Only array parameters and one-dimensional arrays are currently supported.
class ArrayReadExpression {
Expand All @@ -426,6 +430,11 @@ class ArrayReadExpression {
/// element type of the array.
const ScalarType &getType() const { return dataType; }

using SubElements = std::tuple<ArrayParameter, Expression>;

constexpr static std::size_t ARRAY_PARAMETER = 0;
constexpr static std::size_t INDEX = 1;

private:
ScalarType dataType;
std::string arrayParameter;
Expand Down
219 changes: 115 additions & 104 deletions tools/hls-fuzzer/BasicCGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,79 +123,88 @@ std::optional<ast::Expression>
gen::BasicCGenerator::generateBinaryExpression(ast::BinaryExpression::Op op,
const OpaqueContext &context,
std::size_t depth) {
auto conclusion = typeSystem.checkBinaryExpressionOpaque(op, context);
if (!conclusion)
if (typeSystem.discardBinaryExpressionOpaque(op, context))
return std::nullopt;
auto [lhsCons, rhsCons] = *conclusion;

ast::Expression lhs = generateExpression(lhsCons, depth + 1);
ast::Expression rhs = generateExpression(rhsCons, depth + 1);

// Perform explicit casts to a legal operand type if neither of the
// expressions are legal for the given operation.
// This would e.g. cast 'double's that are meant to be applied to '&' to a
// random type that can be legally used with '&'.
if (!ast::BinaryExpression::isLegalOperandType(op, lhs.getType()) ||
!ast::BinaryExpression::isLegalOperandType(op, rhs.getType())) {

std::optional<ast::ScalarType> scalarType = generateScalarType(
context, /*toExclude=*/[&](const ast::ScalarType &value) {
return !ast::BinaryExpression::isLegalOperandType(op, value);
});
if (!scalarType)
return std::nullopt;

lhs = safeCastAsNeeded(*scalarType, std::move(lhs));
rhs = safeCastAsNeeded(*scalarType, std::move(rhs));
}

switch (op) {
case ast::BinaryExpression::ShiftLeft:
case ast::BinaryExpression::ShiftRight: {
ast::ScalarType datatype = lhs.getType();
// Restrict the right expression to be in range of the bitwidth.
rhs = ast::BinaryExpression{
std::move(rhs), ast::BinaryExpression::BitAnd,
ast::Constant{static_cast<uint32_t>(datatype.getBitwidth() - 1)}};

// If the left-hand side is a signed integer, make sure the value is at
// least 0.
// Performing a left-shift on a negative value in C is undefined behavior.
if (op == ast::BinaryExpression::ShiftLeft && datatype.isSigned())
lhs = generateMinExpression(std::move(lhs),
ast::Constant{static_cast<uint32_t>(0)});
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
case ast::BinaryExpression::Plus:
case ast::BinaryExpression::Minus:
case ast::BinaryExpression::Mul: {
ast::ScalarType lhsType = lhs.getType();
ast::ScalarType rhsType = rhs.getType();
if ((lhsType == ast::PrimitiveType::Int32 &&
lhsType.getBitwidth() > rhsType.getBitwidth()) ||
(rhsType == ast::PrimitiveType::Int32 &&
rhsType.getBitwidth() > lhsType.getBitwidth())) {
// Promote integers where one operand is an 'int32_t' to 'uint32_t' to
// avoid undefined behavior on overflow.
lhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(lhs));
rhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(rhs));
}
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
// case ast::BinaryExpression::Division:
break;
case ast::BinaryExpression::BitAnd:
case ast::BinaryExpression::BitOr:
case ast::BinaryExpression::BitXor:
case ast::BinaryExpression::Greater:
case ast::BinaryExpression::GreaterEqual:
case ast::BinaryExpression::Less:
case ast::BinaryExpression::LessEqual:
case ast::BinaryExpression::Equal:
case ast::BinaryExpression::NotEqual:
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
llvm_unreachable("all enum cases handled");
return generateWithDependencies<ast::BinaryExpression>(
context, typeSystem.getBinaryExpressionContextDependencies(op),
/*lhs=*/
[&](const OpaqueContext &context) -> ast::Expression {
return generateExpression(context, depth + 1);
},
/*rhs=*/
[&](const OpaqueContext &context) -> ast::Expression {
return generateExpression(context, depth + 1);
},
/*constructor=*/
[&](ast::Expression &&lhs,
ast::Expression &&rhs) -> std::optional<ast::BinaryExpression> {
// Perform explicit casts to a legal operand type if neither of the
// expressions are legal for the given operation.
// This would e.g. cast 'double's that are meant to be applied to '&' to
// a random type that can be legally used with '&'.
if (!ast::BinaryExpression::isLegalOperandType(op, lhs.getType()) ||
!ast::BinaryExpression::isLegalOperandType(op, rhs.getType())) {

std::optional<ast::ScalarType> scalarType = generateScalarType(
context, /*toExclude=*/[&](const ast::ScalarType &value) {
return !ast::BinaryExpression::isLegalOperandType(op, value);
});
if (!scalarType)
return std::nullopt;

lhs = safeCastAsNeeded(*scalarType, std::move(lhs));
rhs = safeCastAsNeeded(*scalarType, std::move(rhs));
}

switch (op) {
case ast::BinaryExpression::ShiftLeft:
case ast::BinaryExpression::ShiftRight: {
ast::ScalarType datatype = lhs.getType();
// Restrict the right expression to be in range of the bitwidth.
rhs = ast::BinaryExpression{
std::move(rhs), ast::BinaryExpression::BitAnd,
ast::Constant{static_cast<uint32_t>(datatype.getBitwidth() - 1)}};

// If the left-hand side is a signed integer, make sure the value is
// at least 0. Performing a left-shift on a negative value in C is
// undefined behavior.
if (op == ast::BinaryExpression::ShiftLeft && datatype.isSigned())
lhs = generateMinExpression(
std::move(lhs), ast::Constant{static_cast<uint32_t>(0)});
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
case ast::BinaryExpression::Plus:
case ast::BinaryExpression::Minus:
case ast::BinaryExpression::Mul: {
ast::ScalarType lhsType = lhs.getType();
ast::ScalarType rhsType = rhs.getType();
if ((lhsType == ast::PrimitiveType::Int32 &&
lhsType.getBitwidth() > rhsType.getBitwidth()) ||
(rhsType == ast::PrimitiveType::Int32 &&
rhsType.getBitwidth() > lhsType.getBitwidth())) {
// Promote integers where one operand is an 'int32_t' to 'uint32_t'
// to avoid undefined behavior on overflow.
lhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(lhs));
rhs = safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(rhs));
}
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
// case ast::BinaryExpression::Division:
break;
case ast::BinaryExpression::BitAnd:
case ast::BinaryExpression::BitOr:
case ast::BinaryExpression::BitXor:
case ast::BinaryExpression::Greater:
case ast::BinaryExpression::GreaterEqual:
case ast::BinaryExpression::Less:
case ast::BinaryExpression::LessEqual:
case ast::BinaryExpression::Equal:
case ast::BinaryExpression::NotEqual:
return ast::BinaryExpression{std::move(lhs), op, std::move(rhs)};
}
llvm_unreachable("all enum cases handled");
});
Comment on lines +132 to +207
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split

auto contextTransferFn = ...;
auto lhsGenFn = ...;
auto rhsGenFn = ...'
auto expressionGenFn = ...;

return generateWithDependencies<ast::BinaryExpression>(context, contextTransferFn, lhsGenFn, rhsGenFn, expressionGenFn);


Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #874 (comment)

I added argument comments instead.

}

std::optional<ast::Expression>
Expand Down Expand Up @@ -304,41 +313,43 @@ gen::BasicCGenerator::generateConstant(const OpaqueContext &context,
std::optional<ast::ArrayReadExpression>
gen::BasicCGenerator::generateArrayReadExpression(const OpaqueContext &context,
std::size_t depth) {
auto conclusion = typeSystem.checkArrayReadExpressionOpaque(context);
if (!conclusion)
if (typeSystem.discardArrayReadExpressionOpaque(context))
return std::nullopt;
auto [paramConc, indexConc] = *conclusion;

// Construct a safe indexing expression from an array parameter.
auto genWrappedArrayReadFromParam = [&, &indexConc = indexConc](
const ast::ArrayParameter &param) {
ast::ScalarType elementType = param.getElementType();
std::size_t mask = param.getDimension() - 1;
std::string name = param.getName().str();
// Generate an indexing expression.
// Has to be an integer.
ast::Expression index = safeCastAsNeeded(
ast::PrimitiveType::UInt32, generateExpression(indexConc, depth + 1));

// Bitmask the index to be in range of the array! We use this to avoid
// undefined behavior in our programs. In the future we could also add
// mechanisms (type systems, or whatever), that restrict expressions to
// safe in-range expressions.
//
// Note: We can use a bitmask here since array parameters that we generate
// are all powers-of-2. We do so since the modulo operator is currently
// unsupported in dynamatic.
return ast::ArrayReadExpression{
std::move(elementType), name,
ast::BinaryExpression{std::move(index), ast::BinaryExpression::BitAnd,
ast::Constant{static_cast<std::uint32_t>(mask)}}};
};

std::optional<ast::ArrayParameter> arrayParameter =
generateArrayParameter(paramConc);
if (!arrayParameter)
return std::nullopt;
return genWrappedArrayReadFromParam(*arrayParameter);
return generateWithDependencies<ast::ArrayReadExpression>(
context, typeSystem.getArrayReadExpressionContextDependencies(),
/*array parameter=*/
[&](const OpaqueContext &context) -> std::optional<ast::ArrayParameter> {
return generateArrayParameter(context);
},
/*index=*/
[&](const OpaqueContext &context) -> std::optional<ast::Expression> {
return generateExpression(context, depth + 1);
},
/*constructor=*/
[&](ast::ArrayParameter &&param, ast::Expression &&expression) {
ast::ScalarType elementType = param.getElementType();
std::size_t mask = param.getDimension() - 1;
std::string name = param.getName().str();
// Generate an indexing expression.
// Has to be an integer.
ast::Expression index =
safeCastAsNeeded(ast::PrimitiveType::UInt32, std::move(expression));

// Bitmask the index to be in range of the array! We use this to avoid
// undefined behavior in our programs. In the future we could also add
// mechanisms (type systems, or whatever), that restrict expressions to
// safe in-range expressions.
//
// Note: We can use a bitmask here since array parameters that we
// generate are all powers-of-2. We do so since the modulo operator is
// currently unsupported in dynamatic.
return ast::ArrayReadExpression{
std::move(elementType), name,
ast::BinaryExpression{
std::move(index), ast::BinaryExpression::BitAnd,
ast::Constant{static_cast<std::uint32_t>(mask)}}};
});
}

std::optional<ast::ArrayParameter>
Expand Down
Loading
Loading