Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 3 additions & 40 deletions src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1795,24 +1795,8 @@ void CodeGen_ARM::visit(const Store *op) {
if (target_vscale() > 0) {
const IntImm *stride = ramp ? ramp->stride.as<IntImm>() : nullptr;
if (stride && stride->value == 1) {
// Basically we can deal with vanilla codegen,
// but to avoid LLVM error, process with the multiple of natural_lanes
const int natural_lanes = natural_vector_size(op->value.type());
if (ramp->lanes % natural_lanes && !emit_atomic_stores) {
int aligned_lanes = align_up(ramp->lanes, natural_lanes);
// Use predicate to prevent overrun
Expr vpred;
if (is_predicated_store) {
vpred = Shuffle::make_concat({op->predicate, const_false(aligned_lanes - ramp->lanes)});
} else {
vpred = make_vector_predicate_1s_0s(ramp->lanes, aligned_lanes - ramp->lanes);
}
auto aligned_index = Ramp::make(ramp->base, stride, aligned_lanes);
Expr padding = make_zero(op->value.type().with_lanes(aligned_lanes - ramp->lanes));
Expr aligned_value = Shuffle::make_concat({op->value, padding});
codegen(Store::make(op->name, aligned_value, aligned_index, op->param, vpred, op->alignment));
return;
}
CodeGen_CPU::visit(op);
return;
} else if (op->index.type().is_vector()) {
// Scatter
Type elt = op->value.type().element_of();
Expand Down Expand Up @@ -1965,30 +1949,9 @@ void CodeGen_ARM::visit(const Load *op) {
}

if ((target_vscale() > 0)) {
if (stride && stride->value < 1) {
if (stride && stride->value <= 1) {
CodeGen_CPU::visit(op);
return;
} else if (stride && stride->value == 1) {
const int natural_lanes = natural_vector_size(op->type);
if (ramp->lanes % natural_lanes) {
// Load with lanes multiple of natural_lanes
int aligned_lanes = align_up(ramp->lanes, natural_lanes);
// Use predicate to prevent from overrun
Expr vpred;
if (is_predicated_load) {
vpred = Shuffle::make_concat({op->predicate, const_false(aligned_lanes - ramp->lanes)});
} else {
vpred = make_vector_predicate_1s_0s(ramp->lanes, aligned_lanes - ramp->lanes);
}
auto aligned_index = Ramp::make(ramp->base, stride, aligned_lanes);
auto aligned_type = op->type.with_lanes(aligned_lanes);
value = codegen(Load::make(aligned_type, op->name, aligned_index, op->image, op->param, vpred, op->alignment));
value = slice_vector(value, 0, ramp->lanes);
return;
} else {
CodeGen_CPU::visit(op);
return;
}
} else if (op->index.type().is_vector()) {
// General Gather Load

Expand Down
38 changes: 30 additions & 8 deletions test/correctness/simd_op_check_sve2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,12 +871,24 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
// In case of lanes shorter than native's, predicate pattern is generated by
// "whilelt" intrinsic.
// <vscale x 8 x i1> @llvm.aarch64.sve.whilelt.nxv8i1.i32(i32 0, i32 4)
if (factor == 0.5f) {
if (factor == 0.5f && bits >= 32) {
string constraint("vl" + to_string(total_lanes));
add("whilelt", {get_ptrue_instr_with_constraint(bits, constraint)}, total_lanes, scatter);
}
}
}

// Regression check for https://github.com/halide/Halide/pull/9120
if (has_sve()) {
constexpr float factor = 0.5f;
const int width = base_vec_bits * factor;
const int total_lanes = width / bits;
if (total_lanes / vscale < 2) continue; // bail out scalar and <vscale x 1 x ty>
AddTestFunctor add_absence(*this, bits, total_lanes, true, /* check_absense = */ true);

Expr load = in_im(x);
add_absence({{"uunpklo"}, {"uzp1"}}, load); // check those instrs do not exist
}
}
}

Expand Down Expand Up @@ -1038,6 +1050,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {

struct ArmTask {
vector<string> instrs;
bool check_absence;
};

struct Instruction {
Expand Down Expand Up @@ -1199,18 +1212,20 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
int default_bits,
int default_instr_lanes,
int default_vec_factor,
bool is_enabled = true /* false to skip testing */)
bool is_enabled = true, /* false to skip testing */
bool check_absence = false /* true to check the absence of the instruction pattern */)
: parent(p), default_bits(default_bits), default_instr_lanes(default_instr_lanes),
default_vec_factor(default_vec_factor), is_enabled(is_enabled) {};
default_vec_factor(default_vec_factor), is_enabled(is_enabled), check_absence(check_absence) {};

AddTestFunctor(SimdOpCheckArmSve &p,
int default_bits,
// default_instr_lanes is inferred from bits and vec_factor
int default_vec_factor,
bool is_enabled = true /* false to skip testing */)
bool is_enabled = true, /* false to skip testing */
bool check_absence = false /* true to check the absence of the instruction pattern */)
: parent(p), default_bits(default_bits),
default_instr_lanes(Instruction::get_instr_lanes(default_bits, default_vec_factor, p.target)),
default_vec_factor(default_vec_factor), is_enabled(is_enabled) {};
default_vec_factor(default_vec_factor), is_enabled(is_enabled), check_absence(check_absence) {};

// Constructs single Instruction with default parameters
void operator()(const string &opcode, Expr e) {
Expand Down Expand Up @@ -1320,14 +1335,15 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {

// Create Task and register
parent.tasks.emplace_back(Task{decorated_op_name, unique_name, vec_factor, e});
parent.arm_tasks.emplace(unique_name, ArmTask{std::move(instr_patterns)});
parent.arm_tasks.emplace(unique_name, ArmTask{std::move(instr_patterns), check_absence});
}

SimdOpCheckArmSve &parent;
int default_bits;
int default_instr_lanes;
int default_vec_factor;
bool is_enabled;
bool check_absence;
};

void compile_and_check(Func error, const string &op, const string &name, int vector_width, const std::vector<Argument> &arg_types, ostringstream &error_msg) override {
Expand All @@ -1353,7 +1369,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
assert(arm_task != arm_tasks.end());

std::ostringstream msg;
msg << op << " did not generate for target=" << target.to_string()
msg << op << " was not compiled as expected for target=" << target.to_string()
<< " vector_width=" << vector_width << ". Instead we got:\n";

string line;
Expand All @@ -1373,12 +1389,18 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
}
}

if (!patterns.empty()) {
if (!patterns.empty() && !arm_task->second.check_absence) {
error_msg << "Failed: " << msg.str() << "\n";
error_msg << "The following instruction patterns were not found:\n";
for (auto &p : patterns) {
error_msg << p << "\n";
}
} else if (patterns.empty() && arm_task->second.check_absence) {
error_msg << "Failed: " << msg.str() << "\n";
error_msg << "The following lines contain the instruction which shouldn't exist:\n";
for (auto &l : matched_lines) {
error_msg << l << "\n";
}
} else if (debug_mode == "1") {
for (auto &l : matched_lines) {
error_msg << " " << setw(20) << name << ", vf=" << setw(2) << vector_width << ", ";
Expand Down
Loading