Skip to content
18 changes: 12 additions & 6 deletions include/spirv-tools/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class Pass;
struct DescriptorSetAndBinding;
} // namespace opt

enum class SSARewriteMode {
None,
All,
OpaqueOnly,
SpecialTypes,
};

// C++ interface for SPIR-V optimization functionalities. It wraps the context
// (including target environment and the corresponding SPIR-V grammar) and
// provides methods for registering optimization passes and optimizing.
Expand Down Expand Up @@ -125,6 +132,9 @@ class SPIRV_TOOLS_EXPORT Optimizer {
// interface are considered live and are not eliminated.
Optimizer& RegisterLegalizationPasses();
Optimizer& RegisterLegalizationPasses(bool preserve_interface);
Optimizer& RegisterLegalizationPasses(bool preserve_interface,
bool include_loop_unroll,
SSARewriteMode ssa_rewrite_mode);

// Register passes specified in the list of |flags|. Each flag must be a
// string of a form accepted by Optimizer::FlagHasValidForm().
Expand Down Expand Up @@ -645,11 +655,6 @@ Optimizer::PassToken CreateLoopPeelingPass();
// Works best after LICM and local multi store elimination pass.
Optimizer::PassToken CreateLoopUnswitchPass();

// Creates a pass to legalize multidimensional arrays for Vulkan.
// This pass will replace multidimensional arrays of resources with a single
// dimensional array. Combine-access-chains should be run before this pass.
Optimizer::PassToken CreateLegalizeMultidimArrayPass();

// Create global value numbering pass.
// This pass will look for instructions where the same value is computed on all
// paths leading to the instruction. Those instructions are deleted.
Expand Down Expand Up @@ -709,7 +714,8 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor = 0);
// operations on SSA IDs. This allows SSA optimizers to act on these variables.
// Only variables that are local to the function and of supported types are
// processed (see IsSSATargetVar for details).
Optimizer::PassToken CreateSSARewritePass();
Optimizer::PassToken CreateSSARewritePass(
SSARewriteMode mode = SSARewriteMode::All);

// Create pass to convert relaxed precision instructions to half precision.
// This pass converts as many relaxed float32 arithmetic operations to half as
Expand Down
23 changes: 23 additions & 0 deletions source/opt/local_single_store_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,18 @@ bool LocalSingleStoreElimPass::RewriteLoads(
else
stored_id = store_inst->GetSingleWordInOperand(kVariableInitIdInIdx);

const auto get_image_pointer_id = [this](uint32_t value_id) {
Instruction* value_inst = context()->get_def_use_mgr()->GetDef(value_id);
while (value_inst && value_inst->opcode() == spv::Op::OpCopyObject) {
value_id = value_inst->GetSingleWordInOperand(0);
value_inst = context()->get_def_use_mgr()->GetDef(value_id);
}
if (!value_inst || value_inst->opcode() != spv::Op::OpLoad) {
return uint32_t{0};
}
return value_inst->GetSingleWordInOperand(0);
};

*all_rewritten = true;
bool modified = false;
for (Instruction* use : uses) {
Expand All @@ -319,6 +331,17 @@ bool LocalSingleStoreElimPass::RewriteLoads(
context()->KillNamesAndDecorates(use->result_id());
context()->ReplaceAllUsesWith(use->result_id(), stored_id);
context()->KillInst(use);
} else if (use->opcode() == spv::Op::OpImageTexelPointer &&
dominator_analysis->Dominates(store_inst, use)) {
const uint32_t image_ptr_id = get_image_pointer_id(stored_id);
if (image_ptr_id == 0) {
*all_rewritten = false;
continue;
}
modified = true;
context()->ForgetUses(use);
use->SetInOperand(0, {image_ptr_id});
context()->AnalyzeUses(use);
} else {
*all_rewritten = false;
}
Expand Down
40 changes: 27 additions & 13 deletions source/opt/mem_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,27 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
}

bool MemPass::IsTargetType(const Instruction* typeInst) const {
if (IsBaseTargetType(typeInst)) return true;
switch (ssa_rewrite_mode_) {
case SSARewriteMode::None:
return false;
case SSARewriteMode::OpaqueOnly:
if (typeInst->IsOpaqueType()) return true;
break;
case SSARewriteMode::SpecialTypes:
if (typeInst->IsOpaqueType()) return true;
switch (typeInst->opcode()) {
case spv::Op::OpTypePointer:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
return true;
default:
break;
}
break;
case SSARewriteMode::All:
if (IsBaseTargetType(typeInst)) return true;
break;
}
if (typeInst->opcode() == spv::Op::OpTypeArray) {
if (!IsTargetType(
get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
Expand All @@ -72,8 +92,7 @@ bool MemPass::IsTargetType(const Instruction* typeInst) const {

bool MemPass::IsNonPtrAccessChain(const spv::Op opcode) const {
return opcode == spv::Op::OpAccessChain ||
opcode == spv::Op::OpInBoundsAccessChain ||
opcode == spv::Op::OpUntypedAccessChainKHR;
opcode == spv::Op::OpInBoundsAccessChain;
}

bool MemPass::IsPtr(uint32_t ptrId) {
Expand All @@ -89,14 +108,11 @@ bool MemPass::IsPtr(uint32_t ptrId) {
ptrInst = get_def_use_mgr()->GetDef(varId);
}
const spv::Op op = ptrInst->opcode();
if (op == spv::Op::OpVariable || op == spv::Op::OpUntypedVariableKHR ||
IsNonPtrAccessChain(op))
return true;
if (op == spv::Op::OpVariable || IsNonPtrAccessChain(op)) return true;
const uint32_t varTypeId = ptrInst->type_id();
if (varTypeId == 0) return false;
const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
return varTypeInst->opcode() == spv::Op::OpTypePointer ||
varTypeInst->opcode() == spv::Op::OpTypeUntypedPointerKHR;
return varTypeInst->opcode() == spv::Op::OpTypePointer;
}

Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
Expand All @@ -106,13 +122,11 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {

switch (ptrInst->opcode()) {
case spv::Op::OpVariable:
case spv::Op::OpUntypedVariableKHR:
case spv::Op::OpFunctionParameter:
varInst = ptrInst;
break;
case spv::Op::OpAccessChain:
case spv::Op::OpInBoundsAccessChain:
case spv::Op::OpUntypedAccessChainKHR:
case spv::Op::OpPtrAccessChain:
case spv::Op::OpInBoundsPtrAccessChain:
case spv::Op::OpImageTexelPointer:
Expand All @@ -125,8 +139,7 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
break;
}

if (varInst->opcode() == spv::Op::OpVariable ||
varInst->opcode() == spv::Op::OpUntypedVariableKHR) {
if (varInst->opcode() == spv::Op::OpVariable) {
*varId = varInst->result_id();
} else {
*varId = 0;
Expand Down Expand Up @@ -241,7 +254,8 @@ void MemPass::DCEInst(Instruction* inst,
}
}

MemPass::MemPass() {}
MemPass::MemPass(SSARewriteMode ssa_rewrite_mode)
: ssa_rewrite_mode_(ssa_rewrite_mode) {}

bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
Expand Down
7 changes: 5 additions & 2 deletions source/opt/mem_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <unordered_set>
#include <utility>

#include "spirv-tools/optimizer.hpp"
#include "source/opt/basic_block.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/dominator_analysis.h"
Expand Down Expand Up @@ -68,7 +69,7 @@ class MemPass : public Pass {
void CollectTargetVars(Function* func);

protected:
MemPass();
explicit MemPass(SSARewriteMode ssa_rewrite_mode = SSARewriteMode::All);

// Returns true if |typeInst| is a scalar type
// or a vector or matrix
Expand Down Expand Up @@ -133,7 +134,9 @@ class MemPass : public Pass {
// Cache of verified non-target vars
std::unordered_set<uint32_t> seen_non_target_vars_;

private:
private:
SSARewriteMode ssa_rewrite_mode_ = SSARewriteMode::All;

// Return true if all uses of |varId| are only through supported reference
// operations ie. loads and store. Also cache in supported_ref_vars_.
// TODO(dnovillo): This function is replicated in other passes and it's
Expand Down
Loading