Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions tools/hls-fuzzer/AST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,11 @@ llvm::raw_ostream &ast::operator<<(llvm::raw_ostream &os,
<< parameter.getDimension() << ']';
}

llvm::raw_ostream &ast::operator<<(llvm::raw_ostream &os,
const ReturnType &returnType) {
return os << returnType.variant;
}

llvm::raw_ostream &ast::operator<<(llvm::raw_ostream &os,
const Function &function) {
os << function.returnType << ' ' << function.name << '(';
Expand All @@ -391,10 +396,10 @@ llvm::raw_ostream &ast::operator<<(llvm::raw_ostream &os,
mlir::raw_indented_ostream indentedOstream(os);
indentedOstream.indent();
for (auto &iter : function.statements)
indentedOstream << iter;
indentedOstream << iter << '\n';
if (function.returnStatement)
indentedOstream << *function.returnStatement;
indentedOstream << *function.returnStatement << '\n';

os << "\n}\n";
os << "}\n";
return os;
}
124 changes: 122 additions & 2 deletions tools/hls-fuzzer/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ class ScalarType {
template <typename From>
friend struct llvm::simplify_type;

using SubElements = std::tuple<>;

private:
std::shared_ptr<const Variant> datatype;
};
Expand Down Expand Up @@ -234,16 +236,22 @@ struct Constant {
return PrimitiveType::Double;
});
}

using SubElements = std::tuple<>;
};

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Constant &constant);

class ScalarParameter;

/// AST-node representing a reference to a variable in C.
struct Variable {
const ScalarType datatype;
const std::string name;

const ScalarType &getType() const { return datatype; }

using SubElements = std::tuple<ScalarParameter>;
};

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Variable &variable);
Expand Down Expand Up @@ -324,6 +332,9 @@ class BinaryExpression {

using SubElements = std::tuple<Expression, Expression>;

constexpr static std::size_t LHS = 0;
constexpr static std::size_t RHS = 1;

private:
Expression lhs;
Op op;
Expand All @@ -344,6 +355,11 @@ class CastExpression {
/// Returns the type of this expression, i.e., the type being cast to.
const ScalarType &getType() const { return targetType; }

using SubElements = std::tuple<ScalarType, Expression>;

constexpr static std::size_t TARGET_TYPE = 0;
constexpr static std::size_t OPERAND = 1;

private:
ScalarType targetType;
Expression expression;
Expand Down Expand Up @@ -374,6 +390,8 @@ class UnaryExpression {

static bool isLegalOperandType(Op op, const ScalarType &type);

using SubElements = std::tuple<Expression>;

private:
Op op;
Expression expression;
Expand All @@ -398,6 +416,11 @@ class ConditionalExpression {

ScalarType getType() const;

using SubElements = std::tuple<Expression, Expression, Expression>;
constexpr static std::size_t CONDITION = 0;
constexpr static std::size_t TRUE_VAL = 1;
constexpr static std::size_t FALSE_VAL = 2;

private:
Expression condition;
Expression trueVal;
Expand Down Expand Up @@ -452,6 +475,8 @@ class ReturnStatement {

const Expression &getReturnValue() const { return returnValue; }

using SubElements = std::tuple<Expression>;

private:
Expression returnValue;
};
Expand All @@ -478,6 +503,11 @@ class ArrayAssignmentStatement {
/// Returns the value that will be assigned to the element.
const Expression &getValueExpression() const { return valueExpression; }

using SubElements = std::tuple<ArrayParameter, Expression, Expression>;
constexpr static std::size_t ARRAY = 0;
constexpr static std::size_t INDEX = 1;
constexpr static std::size_t VALUE = 2;

private:
std::string arrayParameter;
Expression indexingExpression;
Expand Down Expand Up @@ -507,6 +537,40 @@ class Statement {
std::shared_ptr<const Variant> statement;
};

/// Class representing a list of statements in a body.
class StatementList {
public:
StatementList() = default;

explicit StatementList(std::vector<Statement> statements)
: statements(std::move(statements)) {}

/// Returns the number of statements.
std::size_t size() const { return statements.size(); }

const Statement &operator[](std::size_t index) const {
return statements[index];
}

auto begin() { return statements.begin(); }

auto begin() const { return statements.begin(); }

auto end() { return statements.end(); }

auto end() const { return statements.end(); }

std::vector<Statement> takeVector() { return std::move(statements); }

// Recursive statement list representation.
// The definition is left recursive, meaning the statement is always the tail
// statement after the list.
using SubElements = std::tuple<StatementList, Statement>;

private:
std::vector<Statement> statements;
};

/// AST-Node representing a scalar function parameter in C.
class ScalarParameter {
public:
Expand All @@ -519,6 +583,8 @@ class ScalarParameter {

const ScalarType &getDataType() const { return dataType; }

using SubElements = std::tuple<ScalarType>;

private:
ScalarType dataType;
std::string name;
Expand All @@ -545,6 +611,8 @@ class ArrayParameter {

std::size_t getDimension() const { return dimension; }

using SubElements = std::tuple<ScalarType>;

private:
ScalarType dataType;
std::string name;
Expand All @@ -555,13 +623,50 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const ArrayParameter &parameter);

/// Tag type representing the 'void' type from C.
struct VoidType {};
struct VoidType {
friend bool operator==(const VoidType &lhs, const VoidType &rhs) {
return true;
}

friend bool operator!=(const VoidType &lhs, const VoidType &rhs) {
return !(lhs == rhs);
}
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const VoidType &) {
return os << "void";
}

using ReturnType = std::variant<VoidType, ScalarType>;
class ReturnType {
using Variant = std::variant<VoidType, ScalarType>;

public:
ReturnType() = default;

template <class T, std::enable_if_t<std::is_constructible_v<Variant, T> &&
!std::is_same_v<std::decay_t<Variant>, T>>
* = nullptr>
/*implicit*/ ReturnType(T &&arg) : variant(std::forward<T>(arg)) {}

friend bool operator==(const ReturnType &lhs, const ReturnType &rhs) {
return lhs.variant == rhs.variant;
}

friend bool operator!=(const ReturnType &lhs, const ReturnType &rhs) {
return !(lhs == rhs);
}

template <typename From>
friend struct llvm::simplify_type;

friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const ReturnType &returnType);

using SubElements = std::tuple<>;

private:
Variant variant;
};

/// AST-Node representing a function in C.
/// Functions are currently limited to just a return statement.
Expand All @@ -576,6 +681,11 @@ struct Function {
/// The return statement at the end of a function iff it does not have a void
/// return type.
const std::optional<ReturnStatement> returnStatement;

using SubElements = std::tuple<ReturnType, StatementList, ReturnStatement>;
constexpr static std::size_t RETURN_TYPE = 0;
constexpr static std::size_t STATEMENTS = 1;
constexpr static std::size_t RETURN_STATEMENT = 2;
};

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Function &function);
Expand Down Expand Up @@ -638,4 +748,14 @@ struct llvm::simplify_type<dynamatic::ast::ScalarType> {
}
};

template <>
struct llvm::simplify_type<dynamatic::ast::ReturnType> {
using SimpleType = const dynamatic::ast::ReturnType::Variant;

static SimpleType &
getSimplifiedValue(const dynamatic::ast::ReturnType &datatype) {
return datatype.variant;
}
};

#endif
Loading
Loading