Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 69 additions & 6 deletions include/circt/Dialect/Synth/Transforms/CutRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "llvm/Support/Allocator.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <limits>
#include <memory>
#include <optional>
#include <utility>
Expand Down Expand Up @@ -131,6 +132,9 @@ struct LogicNetworkGate {
/// inversion bit is encoded in each edge.
Signal edges[3];

/// Number of uses outside the logic network.
unsigned externalUseCount = 0;

LogicNetworkGate() : opAndKind(nullptr, Constant), edges{} {}
LogicNetworkGate(Operation *op, Kind kind,
llvm::ArrayRef<Signal> operands = {})
Expand Down Expand Up @@ -171,11 +175,15 @@ struct LogicNetworkGate {
return k == And2 || k == Xor2 || k == Maj3 || k == Identity || k == Choice;
}

/// Check if this should always be a cut input (PI or constant).
bool isAlwaysCutInput() const {
/// Check if this gate is a cut leaf (PI or constant).
bool isCutLeaf() const {
Kind k = getKind();
return k == PrimaryInput || k == Constant;
}

bool isPrimaryOutput() const { return externalUseCount != 0; }

unsigned getExternalUseCount() const { return externalUseCount; }
};

/// Flat logic network representation for efficient cut enumeration.
Expand Down Expand Up @@ -258,6 +266,16 @@ class LogicNetwork {
/// Get the total number of nodes in the network.
size_t size() const { return gates.size(); }

/// Check if a node is observed outside the logic network.
bool isPrimaryOutput(uint32_t index) const {
return gates[index].isPrimaryOutput();
}

/// Get the number of uses outside the logic network.
unsigned getExternalUseCount(uint32_t index) const {
return gates[index].getExternalUseCount();
}

/// Add a primary input to the network.
uint32_t addPrimaryInput(Value value);

Expand All @@ -279,6 +297,8 @@ class LogicNetwork {
void clear();

private:
void recordExternalUse(uint32_t index) { ++gates[index].externalUseCount; }

/// Map from MLIR Value to network index.
llvm::DenseMap<Value, uint32_t> valueToIndex;

Expand Down Expand Up @@ -349,27 +369,49 @@ class MatchedPattern {
private:
const CutRewritePattern *pattern = nullptr; ///< The matched library pattern
SmallVector<DelayType, 1>
arrivalTimes; ///< Arrival times of outputs from this pattern
double area = 0.0; ///< Area cost of this pattern
arrivalTimes; ///< Arrival times of outputs from this pattern
/// Saved match data we reuse during area-flow reselection.
MatchResult matchResult;
SmallVector<unsigned, 6> patternInputToCutInput;

public:
/// Default constructor creates an invalid matched pattern.
MatchedPattern() = default;

/// Constructor for a valid matched pattern.
MatchedPattern(const CutRewritePattern *pattern,
SmallVector<DelayType, 1> arrivalTimes, double area)
: pattern(pattern), arrivalTimes(std::move(arrivalTimes)), area(area) {}
SmallVector<DelayType, 1> arrivalTimes,
MatchResult matchResult,
ArrayRef<unsigned> patternInputToCutInput)
: pattern(pattern), arrivalTimes(std::move(arrivalTimes)),
matchResult(std::move(matchResult)),
patternInputToCutInput(patternInputToCutInput.begin(),
patternInputToCutInput.end()) {}

/// Get the arrival time of signals through this pattern.
DelayType getArrivalTime(unsigned outputIndex) const;
ArrayRef<DelayType> getArrivalTimes() const;
DelayType getWorstOutputArrivalTime() const;

/// Get the library pattern that was matched.
const CutRewritePattern *getPattern() const;

/// Get the area cost of using this pattern.
double getArea() const;

/// Get the per-input delays used when scoring this match.
ArrayRef<DelayType> getDelays() const;

/// Get the cached match payload used to rebuild this match.
const MatchResult &getMatchResult() const { return matchResult; }

/// Get the mapping from pattern input indices to cut input indices.
ArrayRef<unsigned> getInputPermutation() const {
return patternInputToCutInput;
}

/// Get the delay for a cut input after accounting for input permutation.
DelayType getDelayForCutInput(unsigned cutInputIndex) const;
};

/// Represents a cut in the combinational logic network.
Expand Down Expand Up @@ -529,6 +571,15 @@ class CutSet {
bool isFrozen = false; ///< Whether cut set is finalized

public:
/// Latest time this node is allowed to arrive.
DelayType requiredTime = std::numeric_limits<DelayType>::max();

/// Arrival time of the currently selected cut.
DelayType bestArrivalTime = 0;

/// Current area-flow score for the selected cut.
double areaFlow = 0.0;

/// Check if this cut set has a valid matched pattern.
bool isMatched() const { return bestCut; }

Expand All @@ -551,6 +602,9 @@ class CutSet {

/// Get read-only access to all cuts in this set.
ArrayRef<Cut *> getCuts() const;

/// Replace the currently selected cut during area recovery.
void setBestCut(Cut *cut) { bestCut = cut; }
};

/// Configuration options for the cut-based rewriting algorithm.
Expand Down Expand Up @@ -658,6 +712,15 @@ class CutEnumerator {

void dump() const;

/// Compute required times from the current timing-feasible seed mapping.
void computeRequiredTimes();

/// Re-select cuts using area-flow while preserving required times.
void reselectCutsForAreaFlow();

/// Re-select cuts using exact-area deref/ref while preserving required times.
void reselectCutsForExactArea();

/// Get cut sets (indexed by LogicNetwork index).
const llvm::DenseMap<uint32_t, CutSet *> &getCutSets() const {
return cutSets;
Expand Down
Loading
Loading