diff --git a/include/cudnn_frontend/cudnn_interface.h b/include/cudnn_frontend/cudnn_interface.h index c16d6eb0..7abec228 100644 --- a/include/cudnn_frontend/cudnn_interface.h +++ b/include/cudnn_frontend/cudnn_interface.h @@ -3,6 +3,7 @@ #include #include #include +#include #include "../cudnn_frontend_Tensor.h" #include "../cudnn_frontend_Operation.h" @@ -18,30 +19,15 @@ namespace cudnn_frontend { namespace detail { -inline void -assign_uid(graph::Tensor_attributes* const tensor, - int64_t& potential_uid, - std::unordered_set const& used_uids) { - // get_next_potential_uid - while (used_uids.find(potential_uid) != used_uids.end()) { - ++potential_uid; - } - - tensor->set_uid(potential_uid); - ++potential_uid; // increment, as used its used now -} - // TODO: Always returns OK. Can the status and error message be accessed from tensor descriptor? inline error_t create_cudnn_tensor( std::shared_ptr const& props, - std::unordered_map>& tensors, - int64_t& potential_uid, - std::unordered_set const& used_uids) { - // Assign tensor a uid - if (props->has_uid() == false) { - assign_uid(props.get(), potential_uid, used_uids); - } + std::unordered_map>& tensors) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + props->has_uid() == false, + error_code_t::ATTRIBUTE_NOT_SET, + "Tensor '" + props->get_name() + "' UID not assigned before creating backend tensor."); // Check whether backend tensor already created auto tensor_uid = props->get_uid(); @@ -81,7 +67,7 @@ create_cudnn_tensor( } if (auto ragged_offset_props = props->get_ragged_offset()) { - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(ragged_offset_props, tensors, potential_uid, used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(ragged_offset_props, tensors)); tensor_builder.setRaggedOffset(tensors.at(ragged_offset_props->get_uid())); } if (props->has_ragged_offset_multiplier()) { diff --git a/include/cudnn_frontend/graph_helpers.h b/include/cudnn_frontend/graph_helpers.h index d2ce9e33..879927d5 100644 --- a/include/cudnn_frontend/graph_helpers.h +++ b/include/cudnn_frontend/graph_helpers.h @@ -529,7 +529,7 @@ log_dump_tensor_content(int64_t uid, default: data_str = to_hex(host_buf.data(), num_elements, elem_size); } - CUDNN_FE_LOG_LABEL_ENDL("Tensor Dump Uid: " << uid << " Name: " << name << " Data: " << data_str); + CUDNN_FE_LOG_LABEL_ENDL("Tensor Dump uid: " << uid << " Name: " << name << " Data: " << data_str); return {error_code_t::OK, ""}; } @@ -562,7 +562,7 @@ log_variant_pack_memory_type(int64_t uid, void* ptr) { }; // clang-format off - CUDNN_FE_LOG_LABEL_ENDL("Variant Pack" << std::setw(0) << " Uid: " << std::setw(20) << uid + CUDNN_FE_LOG_LABEL_ENDL("Variant Pack" << std::setw(0) << " uid: " << std::setw(20) << uid << std::setw(0) << " MemoryType: " << std::setw(12) << memory_type_to_string(attributes.type) << std::setw(0) << " Device: " << std::setw(4) << attributes.device << std::setw(0) << " UnifiedPtr: " << std::setw(20) << ptr_to_string(ptr) @@ -584,4 +584,4 @@ class cudnnGraphNotSupportedException : public std::runtime_error { } }; -} // namespace cudnn_frontend \ No newline at end of file +} // namespace cudnn_frontend diff --git a/include/cudnn_frontend/graph_interface.h b/include/cudnn_frontend/graph_interface.h index 49168578..ca581e13 100644 --- a/include/cudnn_frontend/graph_interface.h +++ b/include/cudnn_frontend/graph_interface.h @@ -46,10 +46,21 @@ namespace cudnn_frontend::graph { class Graph : public ICudnn, public INode { private: + static constexpr char const *GRAPH_JSON_VERSION = "2.0"; + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + static error_t + check_graph_json_version(json const &j) { + RETURN_CUDNN_FRONTEND_ERROR_IF(j.value("json_version", std::string{}) != GRAPH_JSON_VERSION, + error_code_t::UNSUPPORTED_GRAPH_FORMAT, + "Unsupported graph JSON version. Expected " + std::string(GRAPH_JSON_VERSION)); + return {error_code_t::OK, ""}; + } +#endif + std::unordered_set> full_graph_inputs; - std::unordered_set used_uids; int64_t fe_workspace_size = 0; - uint64_t graph_uid; + uint64_t gid; std::unordered_set> deserialized_tensor_properties; std::unordered_map deserialized_pass_by_value; @@ -64,37 +75,76 @@ class Graph : public ICudnn, public INode { std::vector, char>> tensors_to_dump; error_t - get_pre_assigned_uids(std::unordered_set &used_uids) { - for (auto const &input : full_graph_inputs) { - if (input->has_uid()) { - auto uid = input->get_uid(); - auto iter = used_uids.find(uid); - RETURN_CUDNN_FRONTEND_ERROR_IF(iter != used_uids.end(), - error_code_t::INVALID_VALUE, - "uid " + std::to_string(uid) + " for tensor named " + input->get_name() + - " has been already assigned to another tensor."); - used_uids.insert(uid); + collect_assigned_uid(std::shared_ptr const &tensor, + std::unordered_set &used_uids, + std::unordered_set &visited_tensors) const { + if (tensor == nullptr || !visited_tensors.insert(tensor.get()).second) { + return {error_code_t::OK, ""}; + } + + if (tensor->has_uid()) { + auto uid = tensor->get_uid(); + RETURN_CUDNN_FRONTEND_ERROR_IF(used_uids.find(uid) != used_uids.end(), + error_code_t::INVALID_VALUE, + "uid " + std::to_string(uid) + " for tensor named " + tensor->get_name() + + " has been already assigned to another tensor."); + used_uids.insert(uid); + } + + CHECK_CUDNN_FRONTEND_ERROR(collect_assigned_uid(tensor->get_ragged_offset(), used_uids, visited_tensors)); + return {error_code_t::OK, ""}; + } + + error_t + assign_tensor_uid(std::shared_ptr const &tensor, + std::unordered_set &used_uids, + std::unordered_set &visited_tensors) const { + if (tensor == nullptr || !visited_tensors.insert(tensor.get()).second) { + return {error_code_t::OK, ""}; + } + + if (tensor->has_uid() == false) { + Tensor_attributes::uid_t uid = 1; + while (used_uids.find(uid) != used_uids.end()) { + ++uid; } + tensor->set_uid(uid); + used_uids.insert(uid); + CUDNN_FE_LOG_LABEL_ENDL("INFO: Assigned UID " << tensor->get_uid() << " to tensor named '" + << tensor->get_name() << "'."); + } + + CHECK_CUDNN_FRONTEND_ERROR(assign_tensor_uid(tensor->get_ragged_offset(), used_uids, visited_tensors)); + return {error_code_t::OK, ""}; + } + + error_t + assign_uids() const { + std::unordered_set used_uids; + std::unordered_set visited_tensors; + std::vector> tensors; + + CHECK_CUDNN_FRONTEND_ERROR(collect_tensor_attributes_subtree(tensors)); + for (auto const &input : full_graph_inputs) { + tensors.push_back(input); } for (auto const &output : full_graph_outputs) { - if (output->has_uid()) { - auto uid = output->get_uid(); - auto iter = used_uids.find(uid); - RETURN_CUDNN_FRONTEND_ERROR_IF(iter != used_uids.end(), - error_code_t::INVALID_VALUE, - "uid " + std::to_string(uid) + " for tensor named " + - output->get_name() + - " has been already assigned to another tensor."); - used_uids.insert(uid); - } + tensors.push_back(output); + } + + for (auto const &tensor : tensors) { + CHECK_CUDNN_FRONTEND_ERROR(collect_assigned_uid(tensor, used_uids, visited_tensors)); + } + visited_tensors.clear(); + for (auto const &tensor : tensors) { + CHECK_CUDNN_FRONTEND_ERROR(assign_tensor_uid(tensor, used_uids, visited_tensors)); } return {error_code_t::OK, ""}; } error_t - log_tensors_to_dump_(cudnnHandle_t handle, - std::unordered_map const &tensor_uid_to_pointer_map) const { + log_tensor_dumps(cudnnHandle_t handle, std::unordered_map const &tensor_uid_to_pointer_map) const { if (!isLoggingTensorDumpEnabled()) { return {error_code_t::OK, ""}; } @@ -120,6 +170,23 @@ class Graph : public ICudnn, public INode { return {error_code_t::OK, ""}; } + error_t + log_tensor_dumps(cudnnHandle_t handle, std::vector const &tensor_uids, void *const *tensor_ptrs) const { + if (!isLoggingTensorDumpEnabled()) { + return {error_code_t::OK, ""}; + } + + std::unordered_map tensor_uid_to_pointer_map; + tensor_uid_to_pointer_map.reserve(tensor_uids.size()); + for (size_t i = 0; i < tensor_uids.size(); i++) { + if (tensor_ptrs[i] != nullptr) { + tensor_uid_to_pointer_map.emplace(tensor_uids[i], tensor_ptrs[i]); + } + } + + return log_tensor_dumps(handle, tensor_uid_to_pointer_map); + } + error_t pre_validate_node() const override final { RETURN_CUDNN_FRONTEND_ERROR_IF( @@ -168,9 +235,8 @@ class Graph : public ICudnn, public INode { } virtual error_t - create_cudnn_tensors_node(std::unordered_map> &, - int64_t &, - std::unordered_set const &) const override final { + create_cudnn_tensors_node( + std::unordered_map> &) const override final { return {error_code_t::OK, ""}; } @@ -272,7 +338,7 @@ class Graph : public ICudnn, public INode { #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB json j; serialize(j); - j.erase("graph_uid"); + j.erase("gid"); if (remove_shape) { for (auto &tensor : j["tensors"]) { tensor["dim"].clear(); @@ -602,8 +668,8 @@ class Graph : public ICudnn, public INode { public: Graph() : INode(detail::Context{}) { - static std::atomic next_graph_uid{1}; - graph_uid = next_graph_uid.fetch_add(1, std::memory_order_relaxed); + static std::atomic next_gid{1}; + gid = next_gid.fetch_add(1, std::memory_order_relaxed); } error_t @@ -902,6 +968,7 @@ class Graph : public ICudnn, public INode { error_t validate() { CUDNN_FE_LOG_BANNER(" VALIDATING GRAPH "); + CHECK_CUDNN_FRONTEND_ERROR(assign_uids()); CUDNN_FE_LOG(*this << std::endl;); // First validate all inputs that the user set. @@ -916,11 +983,6 @@ class Graph : public ICudnn, public INode { CHECK_CUDNN_FRONTEND_ERROR(output->validate()); } - // Get all the pre assigned uids - CHECK_CUDNN_FRONTEND_ERROR(get_pre_assigned_uids(used_uids)); - // Clear state - used_uids.clear(); - CUDNN_FE_LOG_BANNER(" VALIDATED ALL OK "); return {error_code_t::OK, ""}; @@ -946,14 +1008,11 @@ class Graph : public ICudnn, public INode { // expand composite nodes CHECK_CUDNN_FRONTEND_ERROR(expand_subtree()); - - // Get all the pre assigned uids - CHECK_CUDNN_FRONTEND_ERROR(get_pre_assigned_uids(used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(assign_uids()); CUDNN_FE_LOG_BANNER(" 2/4 CREATE TENSORS "); - Tensor_attributes::uid_t start_uid = 1; - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensors_subtree(uid_to_tensors, start_uid, used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensors_subtree(uid_to_tensors)); tensors_to_dump.clear(); CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_to_dump_subtree(tensors_to_dump)); @@ -1424,7 +1483,8 @@ class Graph : public ICudnn, public INode { // 4. Run auxiliary kernels (e.g. SDPA reduction accumulator init) CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, cached_workspace_modifications)); - CUDNN_FE_LOG_LABEL_ENDL("INFO: Executing graph_uid " << graph_uid); + CUDNN_FE_LOG_LABEL_ENDL("INFO: Executing gid " << gid); + CHECK_CUDNN_FRONTEND_ERROR(log_tensor_dumps(handle, varpack_template.all_uids, ptrs)); // 5. Dispatch void *engine_workspace = static_cast(workspace) + fe_workspace_size; @@ -1567,6 +1627,7 @@ class Graph : public ICudnn, public INode { serialize(std::vector &data) const { CUDNN_FE_LOG_BANNER(" SERIALIZE PLAN "); #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + CHECK_CUDNN_FRONTEND_ERROR(assign_uids()); json j; serialize(j); @@ -1632,6 +1693,7 @@ class Graph : public ICudnn, public INode { #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB json j = json::from_ubjson(data); + CHECK_CUDNN_FRONTEND_ERROR(check_graph_json_version(j)); // Clear deserialize-owned containers so a re-deserialize on the same Graph // does not feed prepare_variant_pack_template() with stale entries from a @@ -1641,17 +1703,12 @@ class Graph : public ICudnn, public INode { deserialized_workspace_modifications.clear(); tensors_to_dump.clear(); - if (j.contains("graph_uid") && !j["graph_uid"].is_null()) { - graph_uid = j["graph_uid"].get(); - } + gid = j["gid"].get(); - if (j.contains("tensors")) { - auto tensor_map = j["tensors"].get>(); - for (const auto &tensor_info : tensor_map) { - auto tensor_attributes = std::make_shared(); - from_json(tensor_info.second, *tensor_attributes); - deserialized_tensor_properties.insert(tensor_attributes); - } + for (const auto &tensor_info : j.at("tensors")) { + auto tensor_attributes = std::make_shared(); + from_json(tensor_info, *tensor_attributes); + deserialized_tensor_properties.insert(tensor_attributes); } auto serialized_plan = j["cudnn_backend_data"]; @@ -2324,18 +2381,25 @@ class Graph : public ICudnn, public INode { #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json &j) const override final { + auto status = assign_uids(); + if (status.is_bad()) { + throw std::runtime_error(status.get_message()); + } + // Different from serialization of other INodes. // Go over each subnode and serialize them. json full_json; - full_json["context"]["name"] = context.get_name(); + full_json["gid"] = gid; full_json["context"]["compute_data_type"] = context.get_compute_data_type(); full_json["context"]["intermediate_data_type"] = context.get_intermediate_data_type(); full_json["context"]["io_data_type"] = context.get_io_data_type(); full_json["context"]["sm_count"] = context.get_target_sm_count(); full_json["context"]["is_dynamic_shape_enabled"] = context.get_dynamic_shape_enabled(); full_json["context"]["is_override_shape_enabled"] = context.get_override_shape_enabled(); - full_json["graph_uid"] = graph_uid; + if (!context.get_name().empty()) { + full_json["context"]["name"] = context.get_name(); + } full_json.update(R"( {"tag": "GRAPH"})"_json); full_json["nodes"]; @@ -2345,112 +2409,57 @@ class Graph : public ICudnn, public INode { full_json["nodes"].push_back(j_sub_node); } - j["context"] = full_json["context"]; - j["graph_uid"] = full_json["graph_uid"]; + j["gid"] = full_json["gid"]; + j["context"] = full_json["context"]; - j["json_version"] = "1.0"; + j["json_version"] = GRAPH_JSON_VERSION; j["cudnn_backend_version"] = detail::get_backend_version_string(); j["cudnn_frontend_version"] = CUDNN_FRONTEND_VERSION; - j["nodes"]; - j["tensors"]; - std::unordered_set tensors; + j["nodes"] = json::array(); + j["tensors"] = json::array(); + std::map tensors; + auto add_tensor = [&](json &refs, std::string const &port_name, json const &tensor_info) { + if (tensor_info.is_null()) { + return; + } + auto tensor_uid = tensor_info.at("uid").get(); + refs[port_name] = tensor_uid; + tensors.emplace(tensor_uid, tensor_info); + }; for (const auto &sub_node : full_json["nodes"]) { - // Create a short version of the node auto short_node = sub_node; - short_node["inputs"] = {}; - short_node["outputs"] = {}; + short_node["inputs"] = json::object(); + short_node["outputs"] = json::object(); auto node_name = sub_node["tag"].get(); auto i = 0; - // Process node inputs for (const auto &input : sub_node["inputs"]) { std::string port_name; json tensor_info; if (node_name == "CONCATENATE") { - // Extract port_name and tensor_name port_name = std::to_string(i); tensor_info = input; i++; } else { - // Extract port_name and tensor_name port_name = input[0].get(); tensor_info = input[1]; } - if (tensor_info.is_null()) { - continue; - } - - // Determine the key to use for this tensor - std::string tensor_key; - json tensor_ref; - bool uid_assigned = tensor_info.contains("uid_assigned") && tensor_info["uid_assigned"].get(); - - if (uid_assigned && tensor_info.contains("uid") && tensor_info["uid"].is_number_integer()) { - // Use numeric UID if it was explicitly assigned - int64_t tensor_uid = tensor_info["uid"].get(); - tensor_key = std::to_string(tensor_uid); - tensor_ref = json(tensor_uid); - } else if (tensor_info.contains("name")) { - // Fall back to tensor name if UID not assigned - tensor_key = tensor_info["name"].get(); - tensor_ref = tensor_key; - } else { - continue; - } - - // Update short_node inputs - short_node["inputs"][port_name] = tensor_ref; - - // Check if the tensor is already in the tensors map - if (tensors.find(tensor_key) == tensors.end()) { - // If not, add it to the j["tensors"] - j["tensors"][tensor_key] = tensor_info; - } + add_tensor(short_node["inputs"], port_name, tensor_info); } - // Process node outputs for (const auto &output : sub_node["outputs"]) { - // Extract port_name and tensor_name auto port_name = output[0].get(); auto tensor_info = output[1]; - - if (tensor_info.is_null()) { - continue; - } - - // Determine the key to use for this tensor - std::string tensor_key; - json tensor_ref; - bool uid_assigned = tensor_info.contains("uid_assigned") && tensor_info["uid_assigned"].get(); - - if (uid_assigned && tensor_info.contains("uid") && tensor_info["uid"].is_number_integer()) { - // Use numeric UID if it was explicitly assigned - int64_t tensor_uid = tensor_info["uid"].get(); - tensor_key = std::to_string(tensor_uid); - tensor_ref = json(tensor_uid); - } else if (tensor_info.contains("name")) { - // Fall back to tensor name if UID not assigned - tensor_key = tensor_info["name"].get(); - tensor_ref = tensor_key; - } else { - continue; - } - - // Update short_node outputs - short_node["outputs"][port_name] = tensor_ref; - - // Check if the tensor is already in the tensors map - if (tensors.find(tensor_key) == tensors.end()) { - // If not, add it to the j["tensors"] - j["tensors"][tensor_key] = tensor_info; - } + add_tensor(short_node["outputs"], port_name, tensor_info); } - // Add the short_node to j["nodes"] j["nodes"].push_back(short_node); } + for (auto const &tensor : tensors) { + j["tensors"].push_back(tensor.second); + } }; #endif @@ -2463,6 +2472,8 @@ class Graph : public ICudnn, public INode { #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB error_t deserialize(const json &j) { + CHECK_CUDNN_FRONTEND_ERROR(check_graph_json_version(j)); + if (j.contains("context")) { const auto &j_context = j["context"]; if (j_context.contains("compute_data_type") && !j_context["compute_data_type"].is_null()) { @@ -2488,51 +2499,51 @@ class Graph : public ICudnn, public INode { } } - if (j.contains("graph_uid") && !j["graph_uid"].is_null()) { - graph_uid = j["graph_uid"].get(); + gid = j["gid"].get(); + + std::map tensor_table; + for (auto const &tensor_info : j.at("tensors")) { + tensor_table[tensor_info.at("uid").get()] = tensor_info; + } + + std::map> created_tensors; + for (auto const &[uid, tensor_info] : tensor_table) { + auto tensor_attributes = std::make_shared(); + from_json(tensor_info, *tensor_attributes); + created_tensors[uid] = tensor_attributes; + } + + std::vector, std::shared_ptr>> ragged_offsets; + for (auto const &[_, tensor] : created_tensors) { + auto ragged_offset = tensor->get_ragged_offset(); + if (ragged_offset != nullptr) { + ragged_offsets.emplace_back(tensor, ragged_offset); + } + } + for (auto const &[tensor, ragged_offset] : ragged_offsets) { + auto [created_tensor, inserted] = created_tensors.emplace(ragged_offset->get_uid(), ragged_offset); + (void)inserted; + tensor->set_ragged_offset(created_tensor->second); } - std::map> created_tensors; - // Iterate through each sub-node in the full JSON + auto fill_tensor_refs = [&tensor_table](json const &tensor_refs, json &tensor_infos) -> error_t { + if (!tensor_refs.is_object()) { + return {error_code_t::OK, ""}; + } + for (auto &[port_name, tensor_ref] : tensor_refs.items()) { + tensor_infos.push_back({port_name, tensor_table.at(tensor_ref.get())}); + } + return {error_code_t::OK, ""}; + }; + if (j.contains("nodes") && j["nodes"].is_array()) { for (auto j_sub_node : j["nodes"]) { - // Create a JSON object for inputs json inputs; + CHECK_CUDNN_FRONTEND_ERROR(fill_tensor_refs(j_sub_node["inputs"], inputs)); - // Iterate through each input of the sub-node - if (j_sub_node.contains("inputs") && j_sub_node["inputs"].is_object()) { - for (auto &[port_name, tensor_ref] : j_sub_node["inputs"].items()) { - // Convert tensor reference (either numeric UID or string name) to string key - std::string tensor_key = tensor_ref.is_number_integer() - ? std::to_string(tensor_ref.get()) - : tensor_ref.get(); - - if (j.contains("tensors") && j["tensors"].contains(tensor_key)) { - // Add the input to the inputs JSON object - inputs.push_back({port_name, j["tensors"][tensor_key]}); - } - } - } - - // Create a JSON object for outputs json outputs; + CHECK_CUDNN_FRONTEND_ERROR(fill_tensor_refs(j_sub_node["outputs"], outputs)); - // Iterate through each output of the sub-node - if (j_sub_node.contains("outputs") && j_sub_node["outputs"].is_object()) { - for (auto &[port_name, tensor_ref] : j_sub_node["outputs"].items()) { - // Convert tensor reference (either numeric UID or string name) to string key - std::string tensor_key = tensor_ref.is_number_integer() - ? std::to_string(tensor_ref.get()) - : tensor_ref.get(); - - if (j.contains("tensors") && j["tensors"].contains(tensor_key)) { - // Add the output to the outputs JSON object - outputs.push_back({port_name, j["tensors"][tensor_key]}); - } - } - } - - // Replace the original inputs and outputs of the sub-node with the new JSON objects j_sub_node["inputs"] = inputs; j_sub_node["outputs"] = outputs; @@ -2541,12 +2552,10 @@ class Graph : public ICudnn, public INode { return t; } - if (created_tensors.find(t->get_name()) == created_tensors.end()) { - created_tensors.insert({t->get_name(), t}); - return t; - } else { - return created_tensors[t->get_name()]; - } + auto const uid = t->get_uid(); + auto [created_tensor, inserted] = created_tensors.emplace(uid, t); + (void)inserted; + return created_tensor->second; }; #define CHECK_TENSORS(attributes) \ diff --git a/include/cudnn_frontend/graph_properties.h b/include/cudnn_frontend/graph_properties.h index c5dd2a41..f6e1c4ad 100644 --- a/include/cudnn_frontend/graph_properties.h +++ b/include/cudnn_frontend/graph_properties.h @@ -123,8 +123,7 @@ class Tensor_attributes { std::optional compile_time_constant_value = std::nullopt; TensorReordering_t reordering_type = TensorReordering_t::NONE; - uid_t uid = 0; - bool uid_assigned = false; + std::optional uid = std::nullopt; std::shared_ptr ragged_offset; int64_t ragged_offset_multiplier = 1; @@ -440,25 +439,23 @@ class Tensor_attributes { uid_t get_uid() const { - return uid; + return uid.value_or(0); } - uid_t + bool has_uid() const { - return uid_assigned; + return uid.has_value(); } auto clear_uid(void) -> Tensor_attributes& { - uid = 0; - uid_assigned = false; + uid = std::nullopt; return *this; } auto set_uid(uid_t value) -> Tensor_attributes& { - uid = value; - uid_assigned = true; + uid = value; return *this; } @@ -549,6 +546,32 @@ class Attributes { } public: + void + fill_tensors(std::vector>& tensors) const { + auto derived = static_cast(this); + + if constexpr (std::is_same_v) { + for (auto& tensor : derived->inputs) { + tensors.push_back(tensor); + } + } else { + for (auto& [name, tensor] : derived->inputs) { + (void)name; + tensors.push_back(tensor); + } + } + for (auto& [name, tensor] : derived->outputs) { + (void)name; + tensors.push_back(tensor); + } + if constexpr (std::is_same_v || + std::is_same_v) { + for (auto& tensor : derived->peer_stats) { + tensors.push_back(tensor); + } + } + } + error_t fill_pass_by_value(std::unordered_map& tensor_to_pass_by_value) const { diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index 6e14aab0..93ef173e 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -2182,6 +2182,48 @@ class CompositeSDPABackwardNode : public NodeCRTP { return {error_code_t::OK, ""}; } + error_t + collect_tensors_to_dump_node( + std::vector, char>>& tensors_to_dump) const override final { + std::unordered_set seen_uids; + auto add_tensor = [&tensors_to_dump, &seen_uids](std::shared_ptr const& tensor) { + if (tensor != nullptr && seen_uids.insert(tensor->get_uid()).second) { + tensors_to_dump.emplace_back(tensor, 'd'); + } + }; + auto add_input = [&](input_names name) { + auto it = attributes.inputs.find(name); + if (it != attributes.inputs.end()) { + add_tensor(it->second); + } + }; + auto add_input_offset = [&](input_names name) { + auto it = attributes.inputs.find(name); + if (it != attributes.inputs.end() && it->second != nullptr) { + add_tensor(it->second->get_ragged_offset()); + } + }; + auto add_output_offset = [&](output_names name) { + auto it = attributes.outputs.find(name); + if (it != attributes.outputs.end() && it->second != nullptr) { + add_tensor(it->second->get_ragged_offset()); + } + }; + + add_input(input_names::SEQ_LEN_Q); + add_input(input_names::SEQ_LEN_KV); + + for (auto name : + {input_names::Q, input_names::K, input_names::V, input_names::O, input_names::dO, input_names::Stats}) { + add_input_offset(name); + } + for (auto name : {output_names::dQ, output_names::dK, output_names::dV}) { + add_output_offset(name); + } + + return {error_code_t::OK, ""}; + } + #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { diff --git a/include/cudnn_frontend/node/slice.h b/include/cudnn_frontend/node/slice.h index a1aafb16..10503b62 100644 --- a/include/cudnn_frontend/node/slice.h +++ b/include/cudnn_frontend/node/slice.h @@ -78,25 +78,21 @@ class SliceNode : public NodeCRTP { } error_t - create_cudnn_tensors_node(std::unordered_map>& tensors, - int64_t& potential_uid, - std::unordered_set const& used_uids) const override final { + create_cudnn_tensors_node( + std::unordered_map>& tensors) const override final { getLogger() << "[cudnn_frontend] INFO: Creating cudnn tensors for SliceNode " << attributes.name << std::endl; auto const input = attributes.inputs.at(Slice_attributes::input_names::X); auto const output = attributes.outputs.at(Slice_attributes::output_names::Y); if (detail::get_backend_version() >= 92200 && detail::get_compiled_version() >= 92200) { - CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(input, tensors, potential_uid, used_uids)); - CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors, potential_uid, used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(input, tensors)); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors)); return {error_code_t::OK, ""}; } - if (input->has_uid() == false) { - detail::assign_uid(input.get(), potential_uid, used_uids); - } output->set_is_virtual(false); - CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors, potential_uid, used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors)); return {error_code_t::OK, ""}; } @@ -221,4 +217,4 @@ class SliceNode : public NodeCRTP { #endif }; -} // namespace cudnn_frontend::graph \ No newline at end of file +} // namespace cudnn_frontend::graph diff --git a/include/cudnn_frontend/node_interface.h b/include/cudnn_frontend/node_interface.h index 8d46ece3..dad9b598 100644 --- a/include/cudnn_frontend/node_interface.h +++ b/include/cudnn_frontend/node_interface.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -38,6 +39,7 @@ class CompositeSoftmaxNode; class UnifiedSoftmaxNode; class MoeGroupedMatmulNode; class UnifiedDiagonalBandMaskNode; + class TransposeNode; class SliceNode; @@ -100,11 +102,14 @@ class INode { return {error_code_t::OK, ""}; }; + virtual error_t + collect_tensor_attributes_node(std::vector>&) const { + return {error_code_t::OK, ""}; + }; + virtual error_t create_cudnn_tensors_node( - std::unordered_map>& uid_to_backend_tensors, - int64_t& potential_uid, - std::unordered_set const& used_uids) const = 0; + std::unordered_map>& uid_to_backend_tensors) const = 0; virtual error_t collect_tensors_in_workspace_node( @@ -299,16 +304,22 @@ class INode { return {error_code_t::OK, ""}; } + error_t + collect_tensor_attributes_subtree(std::vector>& tensors) const { + CHECK_CUDNN_FRONTEND_ERROR(collect_tensor_attributes_node(tensors)); + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR(sub_node->collect_tensor_attributes_subtree(tensors)); + } + return {error_code_t::OK, ""}; + } + // Creates cudnn tensors for each node (and its sub nodes) error_t create_cudnn_tensors_subtree( - std::unordered_map>& uid_to_backend_tensors, - int64_t& potential_uid, - std::unordered_set const& used_uids) const { - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensors_node(uid_to_backend_tensors, potential_uid, used_uids)); + std::unordered_map>& uid_to_backend_tensors) const { + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensors_node(uid_to_backend_tensors)); for (auto const& sub_node : sub_nodes) { - CHECK_CUDNN_FRONTEND_ERROR( - sub_node->create_cudnn_tensors_subtree(uid_to_backend_tensors, potential_uid, used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(sub_node->create_cudnn_tensors_subtree(uid_to_backend_tensors)); } return {error_code_t::OK, ""}; } @@ -482,40 +493,21 @@ class NodeCRTP : public INode { } error_t - create_cudnn_tensors_node(std::unordered_map>& tensors, - int64_t& potential_uid, - std::unordered_set const& used_uids) const override { - CUDNN_FE_LOG_LABEL_ENDL("INFO: Creating cudnn tensors for node named '" << self().attributes.name << "':"); + collect_tensor_attributes_node(std::vector>& tensors) const override { + self().attributes.fill_tensors(tensors); + return {error_code_t::OK, ""}; + } - if constexpr (std::is_same_v) { - for (auto const& tensor : self().attributes.inputs) { - if (tensor) { - CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); - } - } - } else { - for (auto const& [name, tensor] : self().attributes.inputs) { - (void)name; - if (tensor) { - CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); - } - } - } + error_t + create_cudnn_tensors_node( + std::unordered_map>& tensors) const override { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Creating cudnn tensors for node named '" << self().attributes.name << "':"); + std::vector> node_tensors; + self().attributes.fill_tensors(node_tensors); - for (auto const& [name, tensor] : self().attributes.outputs) { - (void)name; + for (auto const& tensor : node_tensors) { if (tensor) { - CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); - } - } - - // Handle special case of BN where peer_stats is also an input - if constexpr (std::is_same_v || std::is_same_v) { - // Special case in BN where peer stats is also an input but is not present in inputs map - for (auto const& tensor : self().attributes.peer_stats) { - if (tensor) { - CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); - } + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors)); } } diff --git a/include/cudnn_frontend/utils/serialize.h b/include/cudnn_frontend/utils/serialize.h index a9061ffe..ce3d7806 100644 --- a/include/cudnn_frontend/utils/serialize.h +++ b/include/cudnn_frontend/utils/serialize.h @@ -592,19 +592,22 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Moe_grouped_matmul_bwd_attributes::output_names, inline void to_json(nlohmann::json& j, const Tensor_attributes& ta) { - j = nlohmann::json{{"name", ta.name}, - {"data_type", ta.data_type}, + j = nlohmann::json{{"data_type", ta.data_type}, {"dim", ta.dim}, {"stride", ta.stride}, {"is_virtual", ta.is_virtual}, {"pass_by_value", ta.pass_by_value}, {"is_pass_by_value", ta.is_pass_by_value}, {"reordering_type", ta.reordering_type}, - {"uid", ta.uid}, - {"uid_assigned", ta.uid_assigned}}; + {"uid", ta.get_uid()}}; + if (!ta.name.empty()) { + j["name"] = ta.name; + } if (ta.ragged_offset) { - j["ragged_offset_uid"] = ta.ragged_offset->get_uid(); - j["ragged_offset_name"] = ta.ragged_offset->get_name(); + j["ragged_offset_uid"] = ta.ragged_offset->get_uid(); + if (!ta.ragged_offset->get_name().empty()) { + j["ragged_offset_name"] = ta.ragged_offset->get_name(); + } } if (ta.has_ragged_offset_multiplier()) { j["ragged_offset_multiplier"] = ta.ragged_offset_multiplier; @@ -613,23 +616,21 @@ to_json(nlohmann::json& j, const Tensor_attributes& ta) { inline void from_json(const nlohmann::json& j, Tensor_attributes& ta) { - ta.name = j.at("name").get(); + ta.name = j.contains("name") && !j["name"].is_null() ? j.at("name").get() : ""; ta.data_type = j.at("data_type").get(); ta.dim = j.at("dim").get>(); ta.stride = j.at("stride").get>(); ta.is_virtual = j.at("is_virtual").get(); ta.is_pass_by_value = j.at("is_pass_by_value").get(); ta.reordering_type = j.at("reordering_type").get(); - ta.uid = j.at("uid").get(); - ta.uid_assigned = j.at("uid_assigned").get(); + ta.set_uid(j.at("uid").get()); if (ta.is_pass_by_value && !j["pass_by_value"].is_null()) { ta.pass_by_value = j.at("pass_by_value"); } if (j.contains("ragged_offset_uid")) { - auto ragged_offset = std::make_shared(); - ragged_offset->uid = j.at("ragged_offset_uid").get(); - ragged_offset->uid_assigned = true; + auto ragged_offset = std::make_shared(); + ragged_offset->set_uid(j.at("ragged_offset_uid").get()); if (j.contains("ragged_offset_name")) { ragged_offset->name = j.at("ragged_offset_name").get(); } diff --git a/test/python/sdpa/fp8.py b/test/python/sdpa/fp8.py index 92e954d4..6c6a59bf 100644 --- a/test/python/sdpa/fp8.py +++ b/test/python/sdpa/fp8.py @@ -293,6 +293,7 @@ def create_paged_container_and_block_table(tensor, block_size): def exec_sdpa_fp8(cfg, request, cudnn_handle): if request.config.option.dryrun: pytest.skip("dryrun") + perf = request.config.getoption("--perf") cudnn_version = LooseVersion(cudnn.backend_version_string()) if cudnn_version < "9.14.0": @@ -385,12 +386,15 @@ def exec_sdpa_fp8(cfg, request, cudnn_handle): else: padding = None - o_ref, stats_ref, o_amax = compute_ref(q_fp8, k_fp8, v_fp8, attn_scale=attn_scale, - q_descale=q_descale_gpu, k_descale=k_descale_gpu, v_descale=v_descale_gpu, - s_scale=s_scale_gpu, s_descale=s_descale_gpu, torch_itype=torch_itype, - torch_otype=torch_otype, padding=padding, - left_bound=left_bound, right_bound=right_bound, diag_align=diag_align, - sink_token=sink_token_gpu, rescale_threshold=rescale_threshold) + if perf: + o_amax = 1.0 + else: + o_ref, stats_ref, o_amax = compute_ref(q_fp8, k_fp8, v_fp8, attn_scale=attn_scale, + q_descale=q_descale_gpu, k_descale=k_descale_gpu, v_descale=v_descale_gpu, + s_scale=s_scale_gpu, s_descale=s_descale_gpu, torch_itype=torch_itype, + torch_otype=torch_otype, padding=padding, + left_bound=left_bound, right_bound=right_bound, diag_align=diag_align, + sink_token=sink_token_gpu, rescale_threshold=rescale_threshold) o_scale_gpu = torch.tensor([get_fp8_scale_factor(o_amax, torch_otype)], dtype=torch.float, device="cuda") @@ -458,7 +462,7 @@ def exec_sdpa_fp8(cfg, request, cudnn_handle): variant_pack[int(GraphFwdUid.sink_token)] = sink_token_gpu workspace = torch.empty(graph_fwd.get_workspace_size(), dtype=torch.uint8, device="cuda") - if request.config.getoption("--perf"): + if perf: times_ms = time_execution(graph_fwd.execute, variant_pack, workspace, cudnn_handle) print(f"@@@@ FP8 Fwd graph_fwd.execute avg_time_ms={times_ms.mean().item():.3f}") profile_execution(graph_fwd.execute, variant_pack, workspace, cudnn_handle) @@ -466,21 +470,22 @@ def exec_sdpa_fp8(cfg, request, cudnn_handle): torch.cuda.synchronize() # Compare forward output - if is_ragged: - o_ref_comp = convert_uniform_to_packed(torch.einsum("bshd->bhsd", o_ref), seq_len_q_ref, max_t_q) - else: - o_ref_comp = o_ref + if not perf: + if is_ragged: + o_ref_comp = convert_uniform_to_packed(torch.einsum("bshd->bhsd", o_ref), seq_len_q_ref, max_t_q) + else: + o_ref_comp = o_ref - o_gpu_float = o_gpu.detach().float() * get_fp8_descale_factor(o_amax, torch_otype) - o_ref_float = o_ref_comp.detach().float() * get_fp8_descale_factor(o_amax, torch_otype) + o_gpu_float = o_gpu.detach().float() * get_fp8_descale_factor(o_amax, torch_otype) + o_ref_float = o_ref_comp.detach().float() * get_fp8_descale_factor(o_amax, torch_otype) - if is_ragged: - t_idx = sum(seq_len_q_list) - o_gpu_float[t_idx:] = 0 - o_ref_float[t_idx:] = 0 + if is_ragged: + t_idx = sum(seq_len_q_list) + o_gpu_float[t_idx:] = 0 + o_ref_float[t_idx:] = 0 - atol, rtol = 0.08, 0.2 - torch.testing.assert_close(o_gpu_float, o_ref_float, atol=atol, rtol=rtol) + atol, rtol = 0.08, 0.2 + torch.testing.assert_close(o_gpu_float, o_ref_float, atol=atol, rtol=rtol) # Backward pass if not cfg.is_infer: @@ -491,30 +496,36 @@ def exec_sdpa_fp8(cfg, request, cudnn_handle): o_descale_gpu = torch.tensor([get_fp8_descale_factor(o_amax, torch_otype)], dtype=torch.float, device="cuda") dO_descale_gpu = torch.tensor([get_fp8_descale_factor(dO_amax, torch_itype)], dtype=torch.float, device="cuda") - # Get unpacked BSHD references for backward - if is_ragged: - q_ref_bwd = torch.einsum("bhsd->bshd", convert_packed_to_uniform(q_gpu, seq_len_q_ref, s_qo)) - k_ref_bwd = torch.einsum("bhsd->bshd", convert_packed_to_uniform(k_gpu, seq_len_kv_ref, s_kv)) - v_ref_bwd = torch.einsum("bhsd->bshd", convert_packed_to_uniform(v_gpu, seq_len_kv_ref, s_kv)) - o_ref_bwd = torch.einsum("bhsd->bshd", convert_packed_to_uniform(o_gpu, seq_len_q_ref, s_qo)) - dO_ref_bwd = dO_fp8 + if perf: + dP_amax = 1.0 + dQ_amax = 1.0 + dK_amax = 1.0 + dV_amax = 1.0 else: - q_ref_bwd = q_gpu - k_ref_bwd = k_gpu - v_ref_bwd = v_gpu - o_ref_bwd = o_gpu - dO_ref_bwd = dO_fp8 - - padding_bwd = (seq_len_q_ref, seq_len_kv_ref) if is_ragged else None - dQ_ref, dK_ref, dV_ref, dSink_token_ref, dP_amax, dQ_amax, dK_amax, dV_amax = compute_ref_backward( - q_ref_bwd, k_ref_bwd, v_ref_bwd, o_ref_bwd, dO_ref_bwd, attn_scale=attn_scale, - q_descale=q_descale_gpu, k_descale=k_descale_gpu, v_descale=v_descale_gpu, - s_scale=s_scale_gpu, s_descale=s_descale_gpu, torch_itype=torch_itype, - o_descale=o_descale_gpu, dO_descale=dO_descale_gpu, - torch_otype=torch_otype, padding=padding_bwd, - left_bound=left_bound, right_bound=right_bound, diag_align=diag_align, - sink_token=sink_token_gpu - ) + # Get unpacked BSHD references for backward + if is_ragged: + q_ref_bwd = torch.einsum("bhsd->bshd", convert_packed_to_uniform(q_gpu, seq_len_q_ref, s_qo)) + k_ref_bwd = torch.einsum("bhsd->bshd", convert_packed_to_uniform(k_gpu, seq_len_kv_ref, s_kv)) + v_ref_bwd = torch.einsum("bhsd->bshd", convert_packed_to_uniform(v_gpu, seq_len_kv_ref, s_kv)) + o_ref_bwd = torch.einsum("bhsd->bshd", convert_packed_to_uniform(o_gpu, seq_len_q_ref, s_qo)) + dO_ref_bwd = dO_fp8 + else: + q_ref_bwd = q_gpu + k_ref_bwd = k_gpu + v_ref_bwd = v_gpu + o_ref_bwd = o_gpu + dO_ref_bwd = dO_fp8 + + padding_bwd = (seq_len_q_ref, seq_len_kv_ref) if is_ragged else None + dQ_ref, dK_ref, dV_ref, dSink_token_ref, dP_amax, dQ_amax, dK_amax, dV_amax = compute_ref_backward( + q_ref_bwd, k_ref_bwd, v_ref_bwd, o_ref_bwd, dO_ref_bwd, attn_scale=attn_scale, + q_descale=q_descale_gpu, k_descale=k_descale_gpu, v_descale=v_descale_gpu, + s_scale=s_scale_gpu, s_descale=s_descale_gpu, torch_itype=torch_itype, + o_descale=o_descale_gpu, dO_descale=dO_descale_gpu, + torch_otype=torch_otype, padding=padding_bwd, + left_bound=left_bound, right_bound=right_bound, diag_align=diag_align, + sink_token=sink_token_gpu + ) dP_descale_gpu = torch.tensor([get_fp8_descale_factor(dP_amax, torch_itype)], dtype=torch.float, device="cuda") dQ_scale_gpu = torch.tensor([get_fp8_scale_factor(dQ_amax, torch_otype)], dtype=torch.float, device="cuda") @@ -584,7 +595,7 @@ def exec_sdpa_fp8(cfg, request, cudnn_handle): variant_pack_bwd[int(GraphBwdUid.dSink_token)] = dSink_token_gpu workspace_bwd = torch.empty(graph_bwd.get_workspace_size(), dtype=torch.uint8, device="cuda") - if request.config.getoption("--perf"): + if perf: times_ms = time_execution(graph_bwd.execute, variant_pack_bwd, workspace_bwd, cudnn_handle) print(f"@@@@ FP8 Bwd graph.execute avg_time_ms={times_ms.mean().item():.3f}") profile_execution(graph_bwd.execute, variant_pack_bwd, workspace_bwd, cudnn_handle) @@ -613,36 +624,37 @@ def exec_sdpa_fp8(cfg, request, cudnn_handle): pytest.fail("determinism check failed", pytrace=False) print("@@@@ Determinism check: PASSED, dQ, dK, dV bitwise match between runs.") - if is_ragged: - dQ_ref = convert_uniform_to_packed(torch.einsum("bshd->bhsd", dQ_ref), seq_len_q_ref, max_t_q) - dK_ref = convert_uniform_to_packed(torch.einsum("bshd->bhsd", dK_ref), seq_len_kv_ref, max_t_kv) - dV_ref = convert_uniform_to_packed(torch.einsum("bshd->bhsd", dV_ref), seq_len_kv_ref, max_t_kv) - - dQ_out = dQ_gpu.detach().float() * get_fp8_descale_factor(dQ_amax, torch_otype) - dK_out = dK_gpu.detach().float() * get_fp8_descale_factor(dK_amax, torch_otype) - dV_out = dV_gpu.detach().float() * get_fp8_descale_factor(dV_amax, torch_otype) - - dQ_ref_float = dQ_ref.detach().float() * get_fp8_descale_factor(dQ_amax, torch_otype) - dK_ref_float = dK_ref.detach().float() * get_fp8_descale_factor(dK_amax, torch_otype) - dV_ref_float = dV_ref.detach().float() * get_fp8_descale_factor(dV_amax, torch_otype) - - if is_ragged: - t_idx_q = sum(seq_len_q_list) - dQ_out[t_idx_q:] = 0 - dQ_ref_float[t_idx_q:] = 0 - t_idx_kv = sum(seq_len_kv_list) - dK_out[t_idx_kv:] = 0 - dK_ref_float[t_idx_kv:] = 0 - dV_out[t_idx_kv:] = 0 - dV_ref_float[t_idx_kv:] = 0 - - atol, rtol = 0.04, 0.2 - torch.testing.assert_close(dQ_out, dQ_ref_float, atol=atol, rtol=rtol) - torch.testing.assert_close(dK_out, dK_ref_float, atol=atol, rtol=rtol) - torch.testing.assert_close(dV_out, dV_ref_float, atol=atol, rtol=rtol) - - if with_sink_token: - torch.testing.assert_close(dSink_token_gpu, dSink_token_ref, atol=0.02, rtol=0.2) + if not perf: + if is_ragged: + dQ_ref = convert_uniform_to_packed(torch.einsum("bshd->bhsd", dQ_ref), seq_len_q_ref, max_t_q) + dK_ref = convert_uniform_to_packed(torch.einsum("bshd->bhsd", dK_ref), seq_len_kv_ref, max_t_kv) + dV_ref = convert_uniform_to_packed(torch.einsum("bshd->bhsd", dV_ref), seq_len_kv_ref, max_t_kv) + + dQ_out = dQ_gpu.detach().float() * get_fp8_descale_factor(dQ_amax, torch_otype) + dK_out = dK_gpu.detach().float() * get_fp8_descale_factor(dK_amax, torch_otype) + dV_out = dV_gpu.detach().float() * get_fp8_descale_factor(dV_amax, torch_otype) + + dQ_ref_float = dQ_ref.detach().float() * get_fp8_descale_factor(dQ_amax, torch_otype) + dK_ref_float = dK_ref.detach().float() * get_fp8_descale_factor(dK_amax, torch_otype) + dV_ref_float = dV_ref.detach().float() * get_fp8_descale_factor(dV_amax, torch_otype) + + if is_ragged: + t_idx_q = sum(seq_len_q_list) + dQ_out[t_idx_q:] = 0 + dQ_ref_float[t_idx_q:] = 0 + t_idx_kv = sum(seq_len_kv_list) + dK_out[t_idx_kv:] = 0 + dK_ref_float[t_idx_kv:] = 0 + dV_out[t_idx_kv:] = 0 + dV_ref_float[t_idx_kv:] = 0 + + atol, rtol = 0.04, 0.2 + torch.testing.assert_close(dQ_out, dQ_ref_float, atol=atol, rtol=rtol) + torch.testing.assert_close(dK_out, dK_ref_float, atol=atol, rtol=rtol) + torch.testing.assert_close(dV_out, dV_ref_float, atol=atol, rtol=rtol) + + if with_sink_token: + torch.testing.assert_close(dSink_token_gpu, dSink_token_ref, atol=0.02, rtol=0.2) # Print hash and stats for determinism verification print_tensor_stats(o_gpu, tag="o_gpu") diff --git a/tools/cudnn_repro/cudnn_repro/log_parser.py b/tools/cudnn_repro/cudnn_repro/log_parser.py index b9ed9093..63cb75a9 100644 --- a/tools/cudnn_repro/cudnn_repro/log_parser.py +++ b/tools/cudnn_repro/cudnn_repro/log_parser.py @@ -1,10 +1,11 @@ """Log reading and JSON context entry extraction.""" +import copy import json import re import sys from pathlib import Path -from typing import Iterable, List, Tuple +from typing import Dict, Iterable, List, Tuple def read_lines(source: str) -> List[str]: @@ -17,7 +18,8 @@ def read_lines(source: str) -> List[str]: return path.read_text().splitlines() -EXECUTE_GRAPH_UID_PATTERN = re.compile(r"Executing graph_uid (\d+)") +EXECUTE_GRAPH_PATTERN = re.compile(r"Executing gid (\d+)") +TENSOR_DUMP_PATTERN = re.compile(r"Tensor Dump uid:\s*(-?\d+).*?Data:\s*(\[.*\])") def _parse_context_entry(line: str) -> Tuple[str, dict] | None: @@ -31,6 +33,38 @@ def _parse_context_entry(line: str) -> Tuple[str, dict] | None: return stripped, payload +def _parse_tensor_dump(line: str) -> Tuple[int, List[int]] | None: + match = TENSOR_DUMP_PATTERN.search(line) + if match is None: + return None + return int(match.group(1)), [int(value) for value in json.loads(match.group(2))] + + +def _apply_tensor_dumps(entry: Tuple[str, dict], tensor_dumps_by_uid: Dict[int, List[int]]) -> Tuple[str, dict]: + if not tensor_dumps_by_uid: + return entry + raw_line, payload = entry + payload = copy.deepcopy(payload) + tensors = payload.get("tensors", []) + tensor_uids = set() + ragged_offset_uids = set() + for tensor in tensors: + uid = tensor.get("uid") + if uid is not None: + uid = int(uid) + tensor_uids.add(uid) + ragged_offset_uid = tensor.get("ragged_offset_uid") + if ragged_offset_uid is not None: + ragged_offset_uids.add(int(ragged_offset_uid)) + if uid is not None and uid in tensor_dumps_by_uid: + tensor["pass_by_value"] = tensor_dumps_by_uid[uid] + + for uid in sorted(ragged_offset_uids - tensor_uids): + if uid in tensor_dumps_by_uid: + tensors.append({"uid": uid, "pass_by_value": tensor_dumps_by_uid[uid]}) + return raw_line, payload + + def iter_graph_entries(lines: Iterable[str]) -> Iterable[Tuple[str, dict]]: """Extract serialized graph JSON entries from log lines.""" for line in lines: @@ -42,25 +76,36 @@ def iter_graph_entries(lines: Iterable[str]) -> Iterable[Tuple[str, dict]]: def iter_context_entries(lines: Iterable[str]) -> Iterable[Tuple[str, dict]]: """Extract execution-linked context entries from log lines. - Prefer execution order when `Executing graph_uid ...` markers are present. + Prefer execution order when `Executing gid ...` markers are present. Fall back to serialized graph order for older logs. """ graph_entries = list(iter_graph_entries(lines)) - graph_entries_by_uid = {} + graph_entries_by_gid = {} for raw_line, payload in graph_entries: - graph_uid = payload.get("graph_uid") - if graph_uid is not None: - graph_entries_by_uid[int(graph_uid)] = (raw_line, payload) + gid = payload.get("gid") + if gid is not None: + graph_entries_by_gid[int(gid)] = (raw_line, payload) execution_entries = [] + current_entry = None + current_dumps = {} for line in lines: - match = EXECUTE_GRAPH_UID_PATTERN.search(line) - if match is None: + match = EXECUTE_GRAPH_PATTERN.search(line) + if match is not None: + if current_entry is not None: + execution_entries.append(_apply_tensor_dumps(current_entry, current_dumps)) + current_dumps = {} + gid = int(match.group(1)) + current_entry = graph_entries_by_gid.get(gid) continue - graph_uid = int(match.group(1)) - entry = graph_entries_by_uid.get(graph_uid) - if entry is not None: - execution_entries.append(entry) + + dump = _parse_tensor_dump(line) + if dump is not None: + uid, values = dump + current_dumps[uid] = values + + if current_entry is not None: + execution_entries.append(_apply_tensor_dumps(current_entry, current_dumps)) if execution_entries: yield from execution_entries diff --git a/tools/cudnn_repro/cudnn_repro/operations.py b/tools/cudnn_repro/cudnn_repro/operations.py index 50ad03ba..a3c2de6a 100644 --- a/tools/cudnn_repro/cudnn_repro/operations.py +++ b/tools/cudnn_repro/cudnn_repro/operations.py @@ -4,10 +4,12 @@ from . import sdpa_fp8_bwd from . import sdpa_fp8_fwd from . import sdpa_fwd +from . import utils def detect_operation_key(payload: dict) -> str: """Detect the operation key from the JSON payload.""" + utils.validate_payload_version(payload) for node in payload.get("nodes", []): tag = node.get("tag", "") if tag in ("SDPA_FP8_FWD", "SDPA_MXFP8_FWD"): diff --git a/tools/cudnn_repro/cudnn_repro/sdpa_bwd.py b/tools/cudnn_repro/cudnn_repro/sdpa_bwd.py index 525b358a..6c45bd79 100644 --- a/tools/cudnn_repro/cudnn_repro/sdpa_bwd.py +++ b/tools/cudnn_repro/cudnn_repro/sdpa_bwd.py @@ -20,16 +20,9 @@ def _find_bwd_node(payload: dict) -> dict: def _unsupported_features(payload: dict, node: dict) -> list[str]: - tensors = payload.get("tensors", {}) - node_name = node.get("name") inputs = node.get("inputs", {}) - ragged_q_entry = utils.tensor_entry(tensors, node_name, "RAGGED_OFFSET_Q", inputs.get("RAGGED_OFFSET_Q")) - ragged_kv_entry = utils.tensor_entry(tensors, node_name, "RAGGED_OFFSET_KV", inputs.get("RAGGED_OFFSET_KV")) - unsupported = [] - if ragged_q_entry is not None or ragged_kv_entry is not None: - unsupported.append("ragged") if any("PAGED_ATTENTION" in key for key in inputs): unsupported.append("paged_attention") return unsupported @@ -39,7 +32,24 @@ def _as_forward_payload(payload: dict, node: dict) -> dict: inputs = node.get("inputs", {}) fwd_node = dict(node) fwd_node["tag"] = "SDPA_FWD" - fwd_node["inputs"] = {name: inputs[name] for name in ("Q", "K", "V", "SEQ_LEN_Q", "SEQ_LEN_KV", "SINK_TOKEN", "BIAS", "BLOCK_MASK") if name in inputs} + fwd_node["inputs"] = { + name: inputs[name] + for name in ( + "Q", + "K", + "V", + "SEQ_LEN_Q", + "SEQ_LEN_KV", + "RAGGED_OFFSET_Q", + "RAGGED_OFFSET_KV", + "RAGGED_OFFSETS_Q", + "RAGGED_OFFSETS_KV", + "SINK_TOKEN", + "BIAS", + "BLOCK_MASK", + ) + if name in inputs + } fwd_node["outputs"] = {"O": inputs.get("O")} fwd_node["generate_stats"] = True @@ -50,6 +60,7 @@ def _as_forward_payload(payload: dict, node: dict) -> dict: def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: """Build test configuration from JSON payload for simple SDPA backward.""" + utils.validate_payload_version(payload) node = _find_bwd_node(payload) unsupported = _unsupported_features(payload, node) if unsupported: @@ -57,7 +68,7 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: raise NotImplementedError(f"Simple SDPA backward repro does not yet support: {joined}") cfg = sdpa_fwd.build_cfg(raw_line, _as_forward_payload(payload, node), seed) - stats_entry = utils.tensor_entry(payload.get("tensors", {}), node.get("name"), "Stats", node.get("inputs", {}).get("Stats")) + stats_entry = utils.tensor_entry(payload.get("tensors", []), node.get("inputs", {}).get("Stats")) cfg["is_determin"] = bool(node.get("is_deterministic_algorithm", False)) cfg["shape_stats"] = utils.shape(stats_entry) cfg["stride_stats"] = utils.stride(stats_entry) @@ -68,14 +79,8 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: def extract_seq_and_ragged(payload: dict, seed: int) -> dict: """Extract sequence lengths and ragged offsets for simple backward.""" - _find_bwd_node(payload) - return { - "seq_len_q": [], - "seq_len_kv": [], - "ragged_offset_q": [], - "ragged_offset_kv": [], - "rng_data_seed": seed, - } + node = _find_bwd_node(payload) + return sdpa_fwd.extract_seq_and_ragged(_as_forward_payload(payload, node), seed) def extract_and_annotate(raw_line: str, payload: dict, full_log_text: Optional[str] = None) -> dict: diff --git a/tools/cudnn_repro/cudnn_repro/sdpa_fp8_bwd.py b/tools/cudnn_repro/cudnn_repro/sdpa_fp8_bwd.py index 07682cec..1162a2c7 100644 --- a/tools/cudnn_repro/cudnn_repro/sdpa_fp8_bwd.py +++ b/tools/cudnn_repro/cudnn_repro/sdpa_fp8_bwd.py @@ -16,25 +16,25 @@ def _find_node(payload: dict) -> dict: def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: """Build FP8 backward test configuration from JSON payload.""" + utils.validate_payload_version(payload) node = _find_node(payload) is_mxfp8 = utils.is_mxfp8_payload(payload, node) - tensors = payload.get("tensors", {}) - node_name = node.get("name") + tensors = payload.get("tensors", []) inputs = node.get("inputs", {}) outputs = node.get("outputs", {}) - q_entry = utils.tensor_entry(tensors, node_name, "Q", inputs.get("Q")) - k_entry = utils.tensor_entry(tensors, node_name, "K", inputs.get("K")) - v_entry = utils.tensor_entry(tensors, node_name, "V", inputs.get("V")) - o_entry = utils.tensor_entry(tensors, node_name, "O", inputs.get("O")) - stats_entry = utils.tensor_entry(tensors, node_name, "Stats", inputs.get("Stats")) - dq_entry = utils.tensor_entry(tensors, node_name, "dQ", outputs.get("dQ")) - dk_entry = utils.tensor_entry(tensors, node_name, "dK", outputs.get("dK")) - dv_entry = utils.tensor_entry(tensors, node_name, "dV", outputs.get("dV")) - seq_q_entry = utils.tensor_entry(tensors, node_name, "SEQ_LEN_Q", inputs.get("SEQ_LEN_Q")) - seq_kv_entry = utils.tensor_entry(tensors, node_name, "SEQ_LEN_KV", inputs.get("SEQ_LEN_KV")) - page_table_k_entry = utils.tensor_entry(tensors, node_name, "Page_table_K", inputs.get("Page_table_K")) + q_entry = utils.tensor_entry(tensors, inputs.get("Q")) + k_entry = utils.tensor_entry(tensors, inputs.get("K")) + v_entry = utils.tensor_entry(tensors, inputs.get("V")) + o_entry = utils.tensor_entry(tensors, inputs.get("O")) + stats_entry = utils.tensor_entry(tensors, inputs.get("Stats")) + dq_entry = utils.tensor_entry(tensors, outputs.get("dQ")) + dk_entry = utils.tensor_entry(tensors, outputs.get("dK")) + dv_entry = utils.tensor_entry(tensors, outputs.get("dV")) + seq_q_entry = utils.tensor_entry(tensors, inputs.get("SEQ_LEN_Q")) + seq_kv_entry = utils.tensor_entry(tensors, inputs.get("SEQ_LEN_KV")) + page_table_k_entry = utils.tensor_entry(tensors, inputs.get("Page_table_K")) output_dtypes = {dtype for dtype in (utils.tensor_dtype(dq_entry), utils.tensor_dtype(dk_entry), utils.tensor_dtype(dv_entry)) if dtype is not None} if len(output_dtypes) > 1: @@ -57,12 +57,7 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: seq_len_kv = utils.seq_len(seq_kv_entry) is_paged = any(label.startswith("Page_table_") or "PAGED_ATTENTION" in label for label in inputs) - repro_metadata = payload.get("repro_metadata", {}) - ragged_tensor_names = set(repro_metadata.get("ragged_tensor_names", [])) - is_ragged = any( - entry is not None and (utils.parse_optional_int(entry.get("ragged_offset_uid")) is not None or entry.get("name") in ragged_tensor_names) - for entry in (q_entry, k_entry, v_entry, o_entry, dq_entry, dk_entry, dv_entry) - ) + is_ragged = utils.is_ragged_payload(inputs, (q_entry, k_entry, v_entry, o_entry, dq_entry, dk_entry, dv_entry), payload) batches = shape_q[0] if shape_q else None h_q = shape_q[1] if shape_q else None @@ -89,7 +84,7 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: cfg["is_paged"] = is_paged cfg["is_bias"] = utils.bool_from_inputs(inputs, "BIAS") cfg["is_block_mask"] = utils.bool_from_inputs(inputs, "BLOCK_MASK") - cfg["is_padding"] = node.get("padding_mask") or bool(seq_len_q or seq_len_kv) + cfg["is_padding"] = is_ragged or node.get("padding_mask") or bool(seq_len_q or seq_len_kv) cfg["is_ragged"] = is_ragged cfg["is_dropout"] = dropout_prob > 0.0 cfg["is_determin"] = bool(node.get("is_deterministic_algorithm", False)) @@ -146,14 +141,20 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: def extract_seq_and_ragged(payload: dict, seed: int) -> dict: """Extract sequence lengths and ragged offsets from an FP8 backward payload.""" node = _find_node(payload) - tensors = payload.get("tensors", {}) - node_name = node.get("name") + tensors = payload.get("tensors", []) inputs = node.get("inputs", {}) + ragged_offset_q = inputs.get("RAGGED_OFFSET_Q") or inputs.get("RAGGED_OFFSETS_Q") + ragged_offset_kv = inputs.get("RAGGED_OFFSET_KV") or inputs.get("RAGGED_OFFSETS_KV") + q_entry = utils.tensor_entry(tensors, inputs.get("Q")) + k_entry = utils.tensor_entry(tensors, inputs.get("K")) + v_entry = utils.tensor_entry(tensors, inputs.get("V")) + ragged_q_entry = utils.tensor_entry(tensors, ragged_offset_q) or utils.ragged_offset_entry(tensors, q_entry) + ragged_kv_entry = utils.tensor_entry(tensors, ragged_offset_kv) or utils.ragged_offset_entry(tensors, k_entry, v_entry) return { - "seq_len_q": utils.seq_len(utils.tensor_entry(tensors, node_name, "SEQ_LEN_Q", inputs.get("SEQ_LEN_Q"))), - "seq_len_kv": utils.seq_len(utils.tensor_entry(tensors, node_name, "SEQ_LEN_KV", inputs.get("SEQ_LEN_KV"))), - "ragged_offset_q": utils.seq_len(utils.tensor_entry(tensors, node_name, "RAGGED_OFFSET_Q", inputs.get("RAGGED_OFFSET_Q"))), - "ragged_offset_kv": utils.seq_len(utils.tensor_entry(tensors, node_name, "RAGGED_OFFSET_KV", inputs.get("RAGGED_OFFSET_KV"))), + "seq_len_q": utils.seq_len(utils.tensor_entry(tensors, inputs.get("SEQ_LEN_Q"))), + "seq_len_kv": utils.seq_len(utils.tensor_entry(tensors, inputs.get("SEQ_LEN_KV"))), + "ragged_offset_q": utils.seq_len(ragged_q_entry), + "ragged_offset_kv": utils.seq_len(ragged_kv_entry), "rng_data_seed": seed, } diff --git a/tools/cudnn_repro/cudnn_repro/sdpa_fp8_fwd.py b/tools/cudnn_repro/cudnn_repro/sdpa_fp8_fwd.py index 68218c1a..371fbc87 100644 --- a/tools/cudnn_repro/cudnn_repro/sdpa_fp8_fwd.py +++ b/tools/cudnn_repro/cudnn_repro/sdpa_fp8_fwd.py @@ -16,22 +16,22 @@ def _find_node(payload: dict) -> dict: def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: """Build FP8 forward test configuration from JSON payload.""" + utils.validate_payload_version(payload) node = _find_node(payload) is_mxfp8 = utils.is_mxfp8_payload(payload, node) - tensors = payload.get("tensors", {}) - node_name = node.get("name") + tensors = payload.get("tensors", []) inputs = node.get("inputs", {}) outputs = node.get("outputs", {}) - q_entry = utils.tensor_entry(tensors, node_name, "Q", inputs.get("Q")) - k_entry = utils.tensor_entry(tensors, node_name, "K", inputs.get("K")) - v_entry = utils.tensor_entry(tensors, node_name, "V", inputs.get("V")) - o_entry = utils.tensor_entry(tensors, node_name, "O", outputs.get("O")) - stats_entry = utils.tensor_entry(tensors, node_name, "Stats", outputs.get("Stats")) - seq_q_entry = utils.tensor_entry(tensors, node_name, "SEQ_LEN_Q", inputs.get("SEQ_LEN_Q")) - seq_kv_entry = utils.tensor_entry(tensors, node_name, "SEQ_LEN_KV", inputs.get("SEQ_LEN_KV")) - page_table_k_entry = utils.tensor_entry(tensors, node_name, "Page_table_K", inputs.get("Page_table_K")) + q_entry = utils.tensor_entry(tensors, inputs.get("Q")) + k_entry = utils.tensor_entry(tensors, inputs.get("K")) + v_entry = utils.tensor_entry(tensors, inputs.get("V")) + o_entry = utils.tensor_entry(tensors, outputs.get("O")) + stats_entry = utils.tensor_entry(tensors, outputs.get("Stats")) + seq_q_entry = utils.tensor_entry(tensors, inputs.get("SEQ_LEN_Q")) + seq_kv_entry = utils.tensor_entry(tensors, inputs.get("SEQ_LEN_KV")) + page_table_k_entry = utils.tensor_entry(tensors, inputs.get("Page_table_K")) shape_q = utils.shape(q_entry) shape_k = utils.shape(k_entry) @@ -69,12 +69,7 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: diag_align_map = {"TOP_LEFT": 0, "BOTTOM_RIGHT": 1} diag_align = diag_align_map.get(node.get("diagonal_alignment", "TOP_LEFT"), 0) dropout_prob = utils.parse_hex_float(node.get("dropout_probability")) or 0.0 - repro_metadata = payload.get("repro_metadata", {}) - ragged_tensor_names = set(repro_metadata.get("ragged_tensor_names", [])) - is_ragged = any( - entry is not None and (utils.parse_optional_int(entry.get("ragged_offset_uid")) is not None or entry.get("name") in ragged_tensor_names) - for entry in (q_entry, k_entry, v_entry, o_entry) - ) + is_ragged = utils.is_ragged_payload(inputs, (q_entry, k_entry, v_entry, o_entry), payload) block_size = (shape_k[2] if shape_k and len(shape_k) > 2 else None) if is_paged else None if block_size is None and is_paged: @@ -89,7 +84,7 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: cfg["is_paged"] = is_paged cfg["is_bias"] = utils.bool_from_inputs(inputs, "BIAS") cfg["is_block_mask"] = utils.bool_from_inputs(inputs, "BLOCK_MASK") - cfg["is_padding"] = node.get("padding_mask") or bool(seq_len_q or seq_len_kv) + cfg["is_padding"] = is_ragged or node.get("padding_mask") or bool(seq_len_q or seq_len_kv) cfg["is_ragged"] = is_ragged cfg["is_dropout"] = dropout_prob > 0.0 cfg["is_determin"] = None @@ -148,14 +143,20 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: def extract_seq_and_ragged(payload: dict, seed: int) -> dict: """Extract sequence lengths and ragged offsets from an FP8 forward payload.""" node = _find_node(payload) - tensors = payload.get("tensors", {}) - node_name = node.get("name") + tensors = payload.get("tensors", []) inputs = node.get("inputs", {}) + ragged_offset_q = inputs.get("RAGGED_OFFSET_Q") or inputs.get("RAGGED_OFFSETS_Q") + ragged_offset_kv = inputs.get("RAGGED_OFFSET_KV") or inputs.get("RAGGED_OFFSETS_KV") + q_entry = utils.tensor_entry(tensors, inputs.get("Q")) + k_entry = utils.tensor_entry(tensors, inputs.get("K")) + v_entry = utils.tensor_entry(tensors, inputs.get("V")) + ragged_q_entry = utils.tensor_entry(tensors, ragged_offset_q) or utils.ragged_offset_entry(tensors, q_entry) + ragged_kv_entry = utils.tensor_entry(tensors, ragged_offset_kv) or utils.ragged_offset_entry(tensors, k_entry, v_entry) return { - "seq_len_q": utils.seq_len(utils.tensor_entry(tensors, node_name, "SEQ_LEN_Q", inputs.get("SEQ_LEN_Q"))), - "seq_len_kv": utils.seq_len(utils.tensor_entry(tensors, node_name, "SEQ_LEN_KV", inputs.get("SEQ_LEN_KV"))), - "ragged_offset_q": utils.seq_len(utils.tensor_entry(tensors, node_name, "RAGGED_OFFSET_Q", inputs.get("RAGGED_OFFSET_Q"))), - "ragged_offset_kv": utils.seq_len(utils.tensor_entry(tensors, node_name, "RAGGED_OFFSET_KV", inputs.get("RAGGED_OFFSET_KV"))), + "seq_len_q": utils.seq_len(utils.tensor_entry(tensors, inputs.get("SEQ_LEN_Q"))), + "seq_len_kv": utils.seq_len(utils.tensor_entry(tensors, inputs.get("SEQ_LEN_KV"))), + "ragged_offset_q": utils.seq_len(ragged_q_entry), + "ragged_offset_kv": utils.seq_len(ragged_kv_entry), "rng_data_seed": seed, } diff --git a/tools/cudnn_repro/cudnn_repro/sdpa_fwd.py b/tools/cudnn_repro/cudnn_repro/sdpa_fwd.py index 1bc073f9..1b4e61f4 100644 --- a/tools/cudnn_repro/cudnn_repro/sdpa_fwd.py +++ b/tools/cudnn_repro/cudnn_repro/sdpa_fwd.py @@ -9,6 +9,7 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: """Build test configuration from JSON payload.""" + utils.validate_payload_version(payload) node = None for candidate in payload.get("nodes", []): if candidate.get("tag") == "SDPA_FWD": @@ -19,18 +20,17 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: if node is None: raise ValueError("SDPA node not found in log") - tensors = payload.get("tensors", {}) - node_name = node.get("name") + tensors = payload.get("tensors", []) inputs = node.get("inputs", {}) outputs = node.get("outputs", {}) - q_entry = utils.tensor_entry(tensors, node_name, "Q", inputs.get("Q")) - k_entry = utils.tensor_entry(tensors, node_name, "K", inputs.get("K")) - v_entry = utils.tensor_entry(tensors, node_name, "V", inputs.get("V")) - o_entry = utils.tensor_entry(tensors, node_name, "O", outputs.get("O")) + q_entry = utils.tensor_entry(tensors, inputs.get("Q")) + k_entry = utils.tensor_entry(tensors, inputs.get("K")) + v_entry = utils.tensor_entry(tensors, inputs.get("V")) + o_entry = utils.tensor_entry(tensors, outputs.get("O")) - seq_q_entry = utils.tensor_entry(tensors, node_name, "SEQ_LEN_Q", inputs.get("SEQ_LEN_Q")) - seq_kv_entry = utils.tensor_entry(tensors, node_name, "SEQ_LEN_KV", inputs.get("SEQ_LEN_KV")) + seq_q_entry = utils.tensor_entry(tensors, inputs.get("SEQ_LEN_Q")) + seq_kv_entry = utils.tensor_entry(tensors, inputs.get("SEQ_LEN_KV")) shape_q = utils.shape(q_entry) shape_k = utils.shape(k_entry) @@ -64,12 +64,7 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: diag_align = diag_align_map.get(node.get("diagonal_alignment", "TOP_LEFT"), 0) dropout_prob = utils.parse_hex_float(node.get("dropout_probability")) or 0.0 - repro_metadata = payload.get("repro_metadata", {}) - ragged_tensor_names = set(repro_metadata.get("ragged_tensor_names", [])) - is_ragged = any( - entry is not None and (utils.parse_optional_int(entry.get("ragged_offset_uid")) is not None or entry.get("name") in ragged_tensor_names) - for entry in (q_entry, k_entry, v_entry, o_entry) - ) + is_ragged = utils.is_ragged_payload(inputs, (q_entry, k_entry, v_entry, o_entry), payload) cfg = OrderedDict() cfg["data_type"] = utils.torch_dtype(payload.get("context", {}).get("io_data_type")) @@ -79,7 +74,7 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: cfg["is_paged"] = any("PAGED_ATTENTION" in key for key in inputs) cfg["is_bias"] = utils.bool_from_inputs(inputs, "BIAS") cfg["is_block_mask"] = utils.bool_from_inputs(inputs, "BLOCK_MASK") - cfg["is_padding"] = node.get("padding_mask") or bool(seq_len_q or seq_len_kv) + cfg["is_padding"] = is_ragged or node.get("padding_mask") or bool(seq_len_q or seq_len_kv) cfg["is_ragged"] = is_ragged cfg["is_dropout"] = dropout_prob > 0.0 cfg["is_determin"] = None @@ -149,19 +144,25 @@ def extract_seq_and_ragged(payload: dict, seed: int) -> dict: "ragged_offset_kv": [], } - tensors = payload.get("tensors", {}) - node_name = node.get("name") + tensors = payload.get("tensors", []) inputs = node.get("inputs", {}) - seq_q_entry = utils.tensor_entry(tensors, node_name, "SEQ_LEN_Q", inputs.get("SEQ_LEN_Q")) - seq_kv_entry = utils.tensor_entry(tensors, node_name, "SEQ_LEN_KV", inputs.get("SEQ_LEN_KV")) - ragged_q_entry = utils.tensor_entry(tensors, node_name, "RAGGED_OFFSET_Q", inputs.get("RAGGED_OFFSET_Q")) - ragged_kv_entry = utils.tensor_entry(tensors, node_name, "RAGGED_OFFSET_KV", inputs.get("RAGGED_OFFSET_KV")) + seq_q_entry = utils.tensor_entry(tensors, inputs.get("SEQ_LEN_Q")) + seq_kv_entry = utils.tensor_entry(tensors, inputs.get("SEQ_LEN_KV")) + ragged_q_entry = utils.tensor_entry(tensors, inputs.get("RAGGED_OFFSET_Q")) + ragged_kv_entry = utils.tensor_entry(tensors, inputs.get("RAGGED_OFFSET_KV")) if ragged_q_entry is None: - ragged_q_entry = utils.tensor_entry(tensors, node_name, "RAGGED_OFFSETS_Q", inputs.get("RAGGED_OFFSETS_Q")) + ragged_q_entry = utils.tensor_entry(tensors, inputs.get("RAGGED_OFFSETS_Q")) if ragged_kv_entry is None: - ragged_kv_entry = utils.tensor_entry(tensors, node_name, "RAGGED_OFFSETS_KV", inputs.get("RAGGED_OFFSETS_KV")) + ragged_kv_entry = utils.tensor_entry(tensors, inputs.get("RAGGED_OFFSETS_KV")) + if ragged_q_entry is None: + q_entry = utils.tensor_entry(tensors, inputs.get("Q")) + ragged_q_entry = utils.ragged_offset_entry(tensors, q_entry) + if ragged_kv_entry is None: + k_entry = utils.tensor_entry(tensors, inputs.get("K")) + v_entry = utils.tensor_entry(tensors, inputs.get("V")) + ragged_kv_entry = utils.ragged_offset_entry(tensors, k_entry, v_entry) return { "seq_len_q": utils.seq_len(seq_q_entry), diff --git a/tools/cudnn_repro/cudnn_repro/utils.py b/tools/cudnn_repro/cudnn_repro/utils.py index 73bf8314..92bc9467 100644 --- a/tools/cudnn_repro/cudnn_repro/utils.py +++ b/tools/cudnn_repro/cudnn_repro/utils.py @@ -2,12 +2,20 @@ import hashlib import json -import re import struct import sys from pathlib import Path from typing import Any, Optional, Tuple +GRAPH_JSON_VERSION = "2.0" +RAGGED_OFFSET_INPUTS = ("RAGGED_OFFSET_Q", "RAGGED_OFFSET_KV", "RAGGED_OFFSETS_Q", "RAGGED_OFFSETS_KV") + + +def validate_payload_version(payload: dict) -> None: + version = payload.get("json_version") + if version != GRAPH_JSON_VERSION: + raise ValueError(f"Unsupported graph JSON version. Expected {GRAPH_JSON_VERSION}, got {version!r}.") + def sha1_seed(raw: str) -> int: """Generate a deterministic seed from a string using SHA1.""" @@ -76,48 +84,14 @@ def parse_optional_int(value: Any) -> Optional[int]: return None -def tensor_entry(tensors: dict, node_name: Optional[str], label: str, hint: Optional[str]) -> Optional[dict]: - """Find a tensor entry in the tensors dict by various lookup strategies.""" - if not tensors: - return None - - def _from_key(key: Any) -> Optional[dict]: - if key is None: - return None - str_key = str(int(key)) if isinstance(key, (int, float)) else str(key) - return tensors.get(str_key) - - def _from_uid(uid: Any) -> Optional[dict]: - try: - uid_int = int(uid) if uid is not None else None - except (TypeError, ValueError): - return None - for value in tensors.values(): - if value.get("uid") == uid_int: - return value +def tensor_entry(tensors: list, hint: Optional[str]) -> Optional[dict]: + """Find a tensor entry by uid reference.""" + tensor_uid = parse_optional_int(hint) + if tensor_uid is None or not isinstance(tensors, list): return None - - candidates = [] - if hint: - candidates.append(hint) - candidates.append(str(hint)) - if node_name: - candidates.append(f"{node_name}::{label}") - candidates.append(f"{node_name}::{label.lower()}") - candidates.append(f"{node_name}::{label.upper()}") - candidates.extend([label, label.lower(), label.upper()]) - for key in candidates: - entry = _from_key(key) - if entry: + for entry in tensors: + if isinstance(entry, dict) and parse_optional_int(entry.get("uid")) == tensor_uid: return entry - direct_uid = _from_uid(hint) - if direct_uid: - return direct_uid - suffix = f"::{label}" - for key, value in tensors.items(): - skey = str(key) - if skey.endswith(suffix) or skey == label: - return value return None @@ -192,6 +166,32 @@ def bool_from_inputs(inputs: dict, target: str) -> Optional[bool]: return target in inputs +def has_ragged_offset_inputs(inputs: dict) -> bool: + return any(name in inputs for name in RAGGED_OFFSET_INPUTS) + + +def ragged_offset_entry(tensors: list, *entries: Optional[dict]) -> Optional[dict]: + for entry in entries: + if entry is None: + continue + tensor = tensor_entry(tensors, entry.get("ragged_offset_uid")) + if tensor is not None: + return tensor + return None + + +def is_ragged_payload(inputs: dict, entries: tuple[Optional[dict], ...], payload: dict) -> bool: + if has_ragged_offset_inputs(inputs): + return True + if any(entry is not None and entry.get("ragged_offset_uid") is not None for entry in entries): + return True + repro_metadata = payload.get("repro_metadata", {}) + if repro_metadata.get("ragged_offset_q") or repro_metadata.get("ragged_offset_kv"): + return True + ragged_tensor_names = set(repro_metadata.get("ragged_tensor_names", [])) + return any(entry is not None and entry.get("name") in ragged_tensor_names for entry in entries) + + def infer_block_size(page_table_entry: Optional[dict], seq_len_kv: list[int], k_entry: Optional[dict]) -> Optional[int]: """Infer paged-attention block size from serialized tensors.""" if page_table_entry is None or k_entry is None: @@ -217,12 +217,11 @@ def is_mxfp8_payload(payload: dict, node: dict) -> bool: if node.get("is_mxfp8") is True: return True - tensors = payload.get("tensors", {}) - node_name = node.get("name") + tensors = payload.get("tensors", []) for label, hint in node.get("inputs", {}).items(): if not label.startswith(("Descale_", "Scale_")): continue - entry = tensor_entry(tensors, node_name, label, hint) + entry = tensor_entry(tensors, hint) if entry is None: continue data_type = (entry.get("data_type") or "").upper() @@ -232,39 +231,9 @@ def is_mxfp8_payload(payload: dict, node: dict) -> bool: return False -def parse_ragged_tensor_names(log_text: Optional[str]) -> list[str]: - """Extract tensor names that have ragged offsets enabled from FE log text.""" - if not log_text: - return [] - - uid_to_name = {} - current_descriptor_uid = None - ragged_tensor_names = [] - - for line in log_text.splitlines(): - match = re.search(r"Backend Tensor named '([^']+)' with UID (\d+)", line) - if match: - uid_to_name[int(match.group(2))] = match.group(1) - continue - - match = re.search(r"Id:\s+(\d+)", line) - if match: - current_descriptor_uid = int(match.group(1)) - continue - - if current_descriptor_uid is not None and "raggedOffset: Enabled UID:" in line: - tensor_name = uid_to_name.get(current_descriptor_uid) - if tensor_name is not None: - ragged_tensor_names.append(tensor_name) - - return ragged_tensor_names - - def add_ragged_tensor_names(payload: dict, log_text: Optional[str]) -> None: - ragged_uids = (entry.get("ragged_offset_uid") for entry in payload.get("tensors", {}).values()) - has_ragged_uids = any(parse_optional_int(uid) is not None for uid in ragged_uids) - names = [] if payload.get("graph_uid") is not None or has_ragged_uids else parse_ragged_tensor_names(log_text) - payload["repro_metadata"]["ragged_tensor_names"] = names + del log_text + payload["repro_metadata"]["ragged_tensor_names"] = [] def json_with_max_indent(value: Any, depth: int = 0, indent: int = 2, max_indent_level: int = 3) -> str: diff --git a/tools/cudnn_repro/tests/helpers.py b/tools/cudnn_repro/tests/helpers.py new file mode 100644 index 00000000..b92a7007 --- /dev/null +++ b/tools/cudnn_repro/tests/helpers.py @@ -0,0 +1,2 @@ +def tensor_list(tensors): + return [entry for _, entry in sorted(tensors.items(), key=lambda item: int(item[0]))] diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_bwd.py b/tools/cudnn_repro/tests/test_cudnn_repro_bwd.py index e169b9f8..5e8ea78c 100644 --- a/tools/cudnn_repro/tests/test_cudnn_repro_bwd.py +++ b/tools/cudnn_repro/tests/test_cudnn_repro_bwd.py @@ -1,18 +1,20 @@ -import pytest - import cudnn_repro.repro_command as repro_command import cudnn_repro.sdpa_bwd as sdpa_bwd +from .helpers import tensor_list + def test_build_bwd_cfg_simple_case(): payload = { + "json_version": "2.0", + "gid": 1, "context": {"io_data_type": "BFLOAT16"}, "nodes": [ { "tag": "SDPA_BWD", "name": "sdpa_backward", - "inputs": {"Q": 0, "K": 1, "V": 2, "O": 3, "Stats": 4, "dO": 5}, - "outputs": {"dQ": 6, "dK": 7, "dV": 8}, + "inputs": {"Q": 1, "K": 2, "V": 3, "O": 4, "Stats": 5, "dO": 6}, + "outputs": {"dQ": 7, "dK": 8, "dV": 9}, "diagonal_alignment": "TOP_LEFT", "is_deterministic_algorithm": False, "left_bound": None, @@ -20,17 +22,19 @@ def test_build_bwd_cfg_simple_case(): "padding_mask": False, } ], - "tensors": { - "0": {"uid": 0, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "1": {"uid": 1, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "2": {"uid": 2, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "3": {"uid": 3, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "4": {"uid": 4, "dim": [2, 4, 16, 1], "stride": [64, 16, 1, 1]}, - "5": {"uid": 5, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "6": {"uid": 6, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "7": {"uid": 7, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "8": {"uid": 8, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - }, + "tensors": tensor_list( + { + "1": {"uid": 1, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "2": {"uid": 2, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "3": {"uid": 3, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "4": {"uid": 4, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "5": {"uid": 5, "dim": [2, 4, 16, 1], "stride": [64, 16, 1, 1]}, + "6": {"uid": 6, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "7": {"uid": 7, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "8": {"uid": 8, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "9": {"uid": 9, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + } + ), } cfg = sdpa_bwd.build_cfg("{}", payload, seed=123) @@ -49,29 +53,33 @@ def test_build_bwd_cfg_simple_case(): def test_build_bwd_cfg_preserves_rope(): payload = { + "json_version": "2.0", + "gid": 1, "context": {"io_data_type": "BFLOAT16"}, "nodes": [ { "tag": "SDPA_BWD", "name": "sdpa_backward", - "inputs": {"Q": 0, "K": 1, "V": 2, "O": 3, "Stats": 4, "dO": 5}, - "outputs": {"dQ": 6, "dK": 7, "dV": 8}, + "inputs": {"Q": 1, "K": 2, "V": 3, "O": 4, "Stats": 5, "dO": 6}, + "outputs": {"dQ": 7, "dK": 8, "dV": 9}, "diagonal_alignment": "TOP_LEFT", "padding_mask": False, }, {"tag": "ROPE_BWD", "name": "RoPE_BWD_Q"}, ], - "tensors": { - "0": {"uid": 0, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "1": {"uid": 1, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "2": {"uid": 2, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "3": {"uid": 3, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "4": {"uid": 4, "dim": [2, 4, 16, 1], "stride": [64, 16, 1, 1]}, - "5": {"uid": 5, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "6": {"uid": 6, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "7": {"uid": 7, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "8": {"uid": 8, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - }, + "tensors": tensor_list( + { + "1": {"uid": 1, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "2": {"uid": 2, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "3": {"uid": 3, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "4": {"uid": 4, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "5": {"uid": 5, "dim": [2, 4, 16, 1], "stride": [64, 16, 1, 1]}, + "6": {"uid": 6, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "7": {"uid": 7, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "8": {"uid": 8, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "9": {"uid": 9, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + } + ), } cfg = sdpa_bwd.build_cfg("{}", payload, seed=123) @@ -81,57 +89,80 @@ def test_build_bwd_cfg_preserves_rope(): assert "'with_rope': True" in command -def test_build_bwd_cfg_rejects_padding(): +def test_build_bwd_cfg_supports_ragged(): payload = { + "json_version": "2.0", + "gid": 1, "context": {"io_data_type": "FLOAT16"}, "nodes": [ { "tag": "SDPA_BWD", "name": "sdpa_backward", - "inputs": {"Q": 0, "K": 1, "V": 2, "O": 3, "Stats": 4, "dO": 5, "RAGGED_OFFSET_Q": 9, "RAGGED_OFFSET_KV": 10}, - "outputs": {"dQ": 6, "dK": 7, "dV": 8}, + "inputs": { + "Q": 1, + "K": 2, + "V": 3, + "O": 4, + "Stats": 5, + "dO": 6, + "SEQ_LEN_Q": 10, + "SEQ_LEN_KV": 11, + "RAGGED_OFFSET_Q": 12, + "RAGGED_OFFSET_KV": 13, + }, + "outputs": {"dQ": 7, "dK": 8, "dV": 9}, "diagonal_alignment": "TOP_LEFT", - "padding_mask": False, + "padding_mask": True, } ], - "tensors": { - "0": {"uid": 0, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "1": {"uid": 1, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "2": {"uid": 2, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "3": {"uid": 3, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "4": {"uid": 4, "dim": [2, 4, 16, 1], "stride": [64, 16, 1, 1]}, - "5": {"uid": 5, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "6": {"uid": 6, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "7": {"uid": 7, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "8": {"uid": 8, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "9": {"uid": 9, "dim": [3, 1, 1, 1], "stride": [1, 1, 1, 1]}, - "10": {"uid": 10, "dim": [3, 1, 1, 1], "stride": [1, 1, 1, 1]}, - }, + "tensors": tensor_list( + { + "1": {"uid": 1, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "2": {"uid": 2, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "3": {"uid": 3, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "4": {"uid": 4, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "5": {"uid": 5, "dim": [2, 4, 16, 1], "stride": [64, 16, 1, 1]}, + "6": {"uid": 6, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "7": {"uid": 7, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "8": {"uid": 8, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "9": {"uid": 9, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "10": {"uid": 10, "data_type": "INT32", "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [13, 11]}, + "11": {"uid": 11, "data_type": "INT32", "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [15, 9]}, + "12": {"uid": 12, "data_type": "INT64", "dim": [3, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [0, 13, 24]}, + "13": {"uid": 13, "data_type": "INT64", "dim": [3, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [0, 15, 24]}, + } + ), } - with pytest.raises(NotImplementedError, match="ragged"): - sdpa_bwd.build_cfg("{}", payload, seed=123) + cfg = sdpa_bwd.build_cfg("{}", payload, seed=123) + + assert cfg["is_ragged"] is True + assert cfg["is_padding"] is True + assert cfg["seq_len_q"] == [13, 11] + assert cfg["seq_len_kv"] == [15, 9] def test_build_bwd_cfg_supports_padding_sink_and_sliding_window(): payload = { + "json_version": "2.0", + "gid": 1, "context": {"io_data_type": "BFLOAT16"}, "nodes": [ { "tag": "SDPA_BWD", "name": "sdpa_backward", "inputs": { - "Q": 0, - "K": 1, - "V": 2, - "O": 3, - "Stats": 4, - "dO": 5, - "SEQ_LEN_Q": 9, - "SEQ_LEN_KV": 10, - "SINK_TOKEN": 11, + "Q": 1, + "K": 2, + "V": 3, + "O": 4, + "Stats": 5, + "dO": 6, + "SEQ_LEN_Q": 10, + "SEQ_LEN_KV": 11, + "SINK_TOKEN": 12, }, - "outputs": {"dQ": 6, "dK": 7, "dV": 8, "DSINK_TOKEN": 12}, + "outputs": {"dQ": 7, "dK": 8, "dV": 9, "DSINK_TOKEN": 13}, "diagonal_alignment": "BOTTOM_RIGHT", "is_deterministic_algorithm": True, "left_bound": 8, @@ -139,21 +170,23 @@ def test_build_bwd_cfg_supports_padding_sink_and_sliding_window(): "padding_mask": True, } ], - "tensors": { - "0": {"uid": 0, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "1": {"uid": 1, "dim": [2, 1, 16, 64], "stride": [1024, 65536, 64, 1]}, - "2": {"uid": 2, "dim": [2, 1, 16, 64], "stride": [1024, 65536, 64, 1]}, - "3": {"uid": 3, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "4": {"uid": 4, "dim": [2, 4, 16, 1], "stride": [64, 16, 1, 1]}, - "5": {"uid": 5, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "6": {"uid": 6, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, - "7": {"uid": 7, "dim": [2, 1, 16, 64], "stride": [1024, 65536, 64, 1]}, - "8": {"uid": 8, "dim": [2, 1, 16, 64], "stride": [1024, 65536, 64, 1]}, - "9": {"uid": 9, "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1]}, - "10": {"uid": 10, "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1]}, - "11": {"uid": 11, "dim": [1, 4, 1, 1], "stride": [4, 1, 1, 1]}, - "12": {"uid": 12, "dim": [1, 4, 1, 1], "stride": [4, 1, 1, 1]}, - }, + "tensors": tensor_list( + { + "1": {"uid": 1, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "2": {"uid": 2, "dim": [2, 1, 16, 64], "stride": [1024, 65536, 64, 1]}, + "3": {"uid": 3, "dim": [2, 1, 16, 64], "stride": [1024, 65536, 64, 1]}, + "4": {"uid": 4, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "5": {"uid": 5, "dim": [2, 4, 16, 1], "stride": [64, 16, 1, 1]}, + "6": {"uid": 6, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "7": {"uid": 7, "dim": [2, 4, 16, 64], "stride": [4096, 1024, 64, 1]}, + "8": {"uid": 8, "dim": [2, 1, 16, 64], "stride": [1024, 65536, 64, 1]}, + "9": {"uid": 9, "dim": [2, 1, 16, 64], "stride": [1024, 65536, 64, 1]}, + "10": {"uid": 10, "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1]}, + "11": {"uid": 11, "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1]}, + "12": {"uid": 12, "dim": [1, 4, 1, 1], "stride": [4, 1, 1, 1]}, + "13": {"uid": 13, "dim": [1, 4, 1, 1], "stride": [4, 1, 1, 1]}, + } + ), } cfg = sdpa_bwd.build_cfg("{}", payload, seed=123) diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_cli.py b/tools/cudnn_repro/tests/test_cudnn_repro_cli.py index 930913ce..1833cfa1 100644 --- a/tools/cudnn_repro/tests/test_cudnn_repro_cli.py +++ b/tools/cudnn_repro/tests/test_cudnn_repro_cli.py @@ -4,13 +4,16 @@ import sys from pathlib import Path +from .helpers import tensor_list + PACKAGE_ROOT = Path(__file__).resolve().parents[1] -def fwd_payload(graph_uid, diagonal_alignment): +def fwd_payload(gid, diagonal_alignment): return { + "json_version": "2.0", "context": {"io_data_type": "FLOAT16"}, - "graph_uid": graph_uid, + "gid": gid, "nodes": [ { "tag": "SDPA_FWD", @@ -24,12 +27,14 @@ def fwd_payload(graph_uid, diagonal_alignment): "padding_mask": False, } ], - "tensors": { - "1": {"uid": 1, "dim": [1, 2, 16, 64], "stride": [2048, 1024, 64, 1]}, - "2": {"uid": 2, "dim": [1, 2, 16, 64], "stride": [2048, 1024, 64, 1]}, - "3": {"uid": 3, "dim": [1, 2, 16, 64], "stride": [2048, 1024, 64, 1]}, - "4": {"uid": 4, "dim": [1, 2, 16, 64], "stride": [2048, 1024, 64, 1]}, - }, + "tensors": tensor_list( + { + "1": {"uid": 1, "dim": [1, 2, 16, 64], "stride": [2048, 1024, 64, 1]}, + "2": {"uid": 2, "dim": [1, 2, 16, 64], "stride": [2048, 1024, 64, 1]}, + "3": {"uid": 3, "dim": [1, 2, 16, 64], "stride": [2048, 1024, 64, 1]}, + "4": {"uid": 4, "dim": [1, 2, 16, 64], "stride": [2048, 1024, 64, 1]}, + } + ), } @@ -86,7 +91,7 @@ def test_cli_debug_writes_default_files(tmp_path): payload = json.loads((tmp_path / "cudnn_repro_payload.json").read_text()) assert (tmp_path / "cudnn_repro_log.txt").read_text().splitlines() == log_path.read_text().splitlines() assert "test/python/test_mhas_v2.py::test_repro" in (tmp_path / "cudnn_repro_command.txt").read_text() - assert payload["graph_uid"] == 22 + assert payload["gid"] == 22 assert "repro_metadata" in payload @@ -95,8 +100,8 @@ def test_cli_debug_all_writes_indexed_files(tmp_path): run_cli(tmp_path, "--all", str(log_path), debug=True) - for idx, graph_uid in enumerate((11, 22)): + for idx, gid in enumerate((11, 22)): payload = json.loads((tmp_path / f"cudnn_repro_payload_{idx}.json").read_text()) assert (tmp_path / f"cudnn_repro_log_{idx}.txt").read_text().splitlines() == log_path.read_text().splitlines() assert "test/python/test_mhas_v2.py::test_repro" in (tmp_path / f"cudnn_repro_command_{idx}.txt").read_text() - assert payload["graph_uid"] == graph_uid + assert payload["gid"] == gid diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py b/tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py index 454ec958..69a02020 100644 --- a/tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py +++ b/tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py @@ -50,15 +50,16 @@ def _normalize_tensor_entry(entry): def _normalize_payload(payload): - tensors = payload.get("tensors", {}) + tensors = payload.get("tensors", []) + tensor_by_uid = {int(entry["uid"]): entry for entry in tensors} def resolve(uid): - return _normalize_tensor_entry(tensors[str(uid)]) + return _normalize_tensor_entry(tensor_by_uid[int(uid)]) normalized = { "context": payload.get("context"), "nodes": [], - "tensors": sorted(json.dumps(_normalize_tensor_entry(entry), sort_keys=True) for entry in tensors.values()), + "tensors": sorted(json.dumps(_normalize_tensor_entry(entry), sort_keys=True) for entry in tensors), } for node in payload.get("nodes", []): normalized_node = {} diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_fp8.py b/tools/cudnn_repro/tests/test_cudnn_repro_fp8.py index aa334dea..febaab18 100644 --- a/tools/cudnn_repro/tests/test_cudnn_repro_fp8.py +++ b/tools/cudnn_repro/tests/test_cudnn_repro_fp8.py @@ -5,6 +5,8 @@ import cudnn_repro.sdpa_fp8_bwd as sdpa_fp8_bwd import cudnn_repro.sdpa_fp8_fwd as sdpa_fp8_fwd +from .helpers import tensor_list + def _fp8_fwd_payload(*, tag="SDPA_FP8_FWD", ragged=False, paged=False, output_dtype="HALF", mxfp8=False): if mxfp8 and tag == "SDPA_FP8_FWD": @@ -23,6 +25,8 @@ def _fp8_fwd_payload(*, tag="SDPA_FP8_FWD", ragged=False, paged=False, output_dt if ragged: inputs["SEQ_LEN_Q"] = 13 inputs["SEQ_LEN_KV"] = 14 + inputs["RAGGED_OFFSET_Q"] = 17 + inputs["RAGGED_OFFSET_KV"] = 18 if paged: inputs["SEQ_LEN_Q"] = 13 inputs["SEQ_LEN_KV"] = 14 @@ -58,9 +62,11 @@ def _fp8_fwd_payload(*, tag="SDPA_FP8_FWD", ragged=False, paged=False, output_dt tensors["15"] = {"uid": 15, "data_type": "INT32", "dim": [2, 1, 5, 1], "stride": [5, 5, 1, 1]} tensors["16"] = {"uid": 16, "data_type": "INT32", "dim": [2, 1, 5, 1], "stride": [5, 5, 1, 1]} if ragged: - for key in ("1", "2", "3", "4"): - tensors[key]["ragged_offset_uid"] = 99 + tensors["17"] = {"uid": 17, "data_type": "INT64", "dim": [3, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [0, 13, 24]} + tensors["18"] = {"uid": 18, "data_type": "INT64", "dim": [3, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [0, 19, 36]} return { + "json_version": "2.0", + "gid": 1, "context": {"io_data_type": "FP8_E4M3"}, "nodes": [ { @@ -78,7 +84,7 @@ def _fp8_fwd_payload(*, tag="SDPA_FP8_FWD", ragged=False, paged=False, output_dt "right_bound": None, } ], - "tensors": tensors, + "tensors": tensor_list(tensors), "repro_metadata": {"ragged_tensor_names": [""] if ragged else []}, } @@ -107,7 +113,11 @@ def _fp8_bwd_payload(*, ragged=False, output_dtype="HALF", mxfp8=False): if ragged: inputs["SEQ_LEN_Q"] = 19 inputs["SEQ_LEN_KV"] = 20 + inputs["RAGGED_OFFSET_Q"] = 28 + inputs["RAGGED_OFFSET_KV"] = 29 payload = { + "json_version": "2.0", + "gid": 1, "context": {"io_data_type": "FP8_E4M3"}, "nodes": [ { @@ -159,21 +169,26 @@ def _fp8_bwd_payload(*, ragged=False, output_dtype="HALF", mxfp8=False): }, "repro_metadata": {"ragged_tensor_names": [""] if ragged else []}, } + tensors = payload["tensors"] if ragged: - payload["tensors"]["19"] = {"uid": 19, "data_type": "INT32", "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [9, 7]} - payload["tensors"]["20"] = {"uid": 20, "data_type": "INT32", "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [15, 11]} - for key in ("1", "2", "3", "4", "21", "22", "23"): - payload["tensors"][key]["ragged_offset_uid"] = 99 + tensors["19"] = {"uid": 19, "data_type": "INT32", "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [9, 7]} + tensors["20"] = {"uid": 20, "data_type": "INT32", "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [15, 11]} + tensors["28"] = {"uid": 28, "data_type": "INT64", "dim": [3, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [0, 9, 16]} + tensors["29"] = {"uid": 29, "data_type": "INT64", "dim": [3, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [0, 15, 26]} + payload["tensors"] = tensor_list(tensors) return payload def test_operations_distinguish_fp8_and_non_fp8_tags(): - assert operations.detect_operation_key({"nodes": [{"tag": "SDPA_FWD"}]}) == "sdpa_fwd" - assert operations.detect_operation_key({"nodes": [{"tag": "SDPA_BWD"}]}) == "sdpa_bwd" - assert operations.detect_operation_key({"nodes": [{"tag": "SDPA_FP8_FWD"}]}) == "sdpa_fp8_fwd" - assert operations.detect_operation_key({"nodes": [{"tag": "SDPA_FP8_BWD"}]}) == "sdpa_fp8_bwd" - assert operations.detect_operation_key({"nodes": [{"tag": "SDPA_MXFP8_FWD"}]}) == "sdpa_fp8_fwd" - assert operations.detect_operation_key({"nodes": [{"tag": "SDPA_MXFP8_BWD"}]}) == "sdpa_fp8_bwd" + def payload(tag): + return {"json_version": "2.0", "gid": 1, "nodes": [{"tag": tag}], "tensors": []} + + assert operations.detect_operation_key(payload("SDPA_FWD")) == "sdpa_fwd" + assert operations.detect_operation_key(payload("SDPA_BWD")) == "sdpa_bwd" + assert operations.detect_operation_key(payload("SDPA_FP8_FWD")) == "sdpa_fp8_fwd" + assert operations.detect_operation_key(payload("SDPA_FP8_BWD")) == "sdpa_fp8_bwd" + assert operations.detect_operation_key(payload("SDPA_MXFP8_FWD")) == "sdpa_fp8_fwd" + assert operations.detect_operation_key(payload("SDPA_MXFP8_BWD")) == "sdpa_fp8_bwd" def test_build_fp8_fwd_cfg_extracts_output_type_and_stats(): @@ -191,6 +206,16 @@ def test_build_fp8_fwd_cfg_extracts_output_type_and_stats(): assert cfg["seq_len_kv"] == [19, 17] +def test_build_fp8_fwd_cfg_detects_ragged_from_offset_inputs(): + payload = _fp8_fwd_payload(ragged=True) + + cfg = sdpa_fp8_fwd.build_cfg("{}", payload, seed=123) + + assert cfg["is_ragged"] is True + assert cfg["seq_len_q"] == [13, 11] + assert cfg["seq_len_kv"] == [19, 17] + + def test_build_fp8_fwd_cfg_infers_paged_block_size(): cfg = sdpa_fp8_fwd.build_cfg("{}", _fp8_fwd_payload(paged=True), seed=123) assert cfg["is_paged"] is True diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_graph_uid.py b/tools/cudnn_repro/tests/test_cudnn_repro_graph_uid.py deleted file mode 100644 index 9788fb55..00000000 --- a/tools/cudnn_repro/tests/test_cudnn_repro_graph_uid.py +++ /dev/null @@ -1,29 +0,0 @@ -import json - -import cudnn_repro.log_parser as log_parser - - -def test_iter_context_entries_prefers_execution_order_with_graph_uid(): - payload1 = {"context": {"io_data_type": "HALF"}, "graph_uid": 11, "nodes": [{"tag": "SDPA_FWD"}], "tensors": {}} - payload2 = {"context": {"io_data_type": "BFLOAT16"}, "graph_uid": 22, "nodes": [{"tag": "SDPA_BWD"}], "tensors": {}} - lines = [ - json.dumps(payload1), - "[cudnn_frontend] INFO: Executing graph_uid 11", - json.dumps(payload2), - "[cudnn_frontend] INFO: Executing graph_uid 22", - "[cudnn_frontend] INFO: Executing graph_uid 11", - ] - - entries = list(log_parser.iter_context_entries(lines)) - - assert [payload.get("graph_uid") for _, payload in entries] == [11, 22, 11] - assert [raw_line for raw_line, _ in entries] == [json.dumps(payload1), json.dumps(payload2), json.dumps(payload1)] - - -def test_iter_context_entries_falls_back_without_execution_markers(): - payload1 = {"context": {"io_data_type": "HALF"}, "graph_uid": 11, "nodes": [{"tag": "SDPA_FWD"}], "tensors": {}} - payload2 = {"context": {"io_data_type": "BFLOAT16"}, "graph_uid": 22, "nodes": [{"tag": "SDPA_BWD"}], "tensors": {}} - - entries = list(log_parser.iter_context_entries([json.dumps(payload1), json.dumps(payload2)])) - - assert [payload.get("graph_uid") for _, payload in entries] == [11, 22] diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_log_parser.py b/tools/cudnn_repro/tests/test_cudnn_repro_log_parser.py new file mode 100644 index 00000000..8303ed0c --- /dev/null +++ b/tools/cudnn_repro/tests/test_cudnn_repro_log_parser.py @@ -0,0 +1,100 @@ +import json + +import cudnn_repro.log_parser as log_parser + + +def payload(gid, tag, dtype): + return {"context": {"io_data_type": dtype}, "gid": gid, "nodes": [{"tag": tag}], "tensors": []} + + +def test_iter_context_entries_prefers_execution_order_with_gid(): + payload1 = payload(11, "SDPA_FWD", "HALF") + payload2 = payload(22, "SDPA_BWD", "BFLOAT16") + lines = [ + json.dumps(payload1), + "[cudnn_frontend] INFO: Executing gid 11", + json.dumps(payload2), + "[cudnn_frontend] INFO: Executing gid 22", + "[cudnn_frontend] INFO: Executing gid 11", + ] + + entries = list(log_parser.iter_context_entries(lines)) + + assert [payload.get("gid") for _, payload in entries] == [11, 22, 11] + assert [raw_line for raw_line, _ in entries] == [json.dumps(payload1), json.dumps(payload2), json.dumps(payload1)] + + +def test_iter_context_entries_falls_back_without_execution_markers(): + payload1 = payload(11, "SDPA_FWD", "HALF") + payload2 = payload(22, "SDPA_BWD", "BFLOAT16") + + entries = list(log_parser.iter_context_entries([json.dumps(payload1), json.dumps(payload2)])) + + assert [payload.get("gid") for _, payload in entries] == [11, 22] + + +def test_iter_context_entries_does_not_reuse_tensor_dumps_across_gids(): + payload1 = payload(11, "SDPA_FWD", "HALF") + payload2 = payload(22, "SDPA_BWD", "HALF") + payload1["tensors"] = [{"uid": 5}] + payload2["tensors"] = [{"uid": 5}] + lines = [ + json.dumps(payload1), + json.dumps(payload2), + "[cudnn_frontend] INFO: Executing gid 11", + "[cudnn_frontend] INFO: Tensor Dump uid: 5 Name: Data: [13, 11]", + "[cudnn_frontend] INFO: Executing gid 22", + ] + + entries = list(log_parser.iter_context_entries(lines)) + + assert entries[0][1]["tensors"][0]["pass_by_value"] == [13, 11] + assert "pass_by_value" not in entries[-1][1]["tensors"][0] + + +def test_iter_context_entries_prefers_current_tensor_dump(): + payload1 = payload(11, "SDPA_FWD", "HALF") + payload2 = payload(22, "SDPA_BWD", "HALF") + payload1["tensors"] = [{"uid": 5}] + payload2["tensors"] = [{"uid": 5}] + lines = [ + json.dumps(payload1), + json.dumps(payload2), + "[cudnn_frontend] INFO: Executing gid 11", + "[cudnn_frontend] INFO: Tensor Dump uid: 5 Name: Data: [1]", + "[cudnn_frontend] INFO: Executing gid 22", + "[cudnn_frontend] INFO: Tensor Dump uid: 5 Name: Data: [2]", + ] + + entries = list(log_parser.iter_context_entries(lines)) + + assert entries[-1][1]["tensors"][0]["pass_by_value"] == [2] + + +def test_iter_context_entries_attaches_dumped_ragged_offset_tensor(): + payload1 = payload(11, "SDPA_FWD", "HALF") + payload1["tensors"] = [{"uid": 5, "ragged_offset_uid": 15}] + lines = [ + json.dumps(payload1), + "[cudnn_frontend] INFO: Executing gid 11", + "[cudnn_frontend] INFO: Tensor Dump uid: 15 Name: Data: [0, 128]", + ] + + entries = list(log_parser.iter_context_entries(lines)) + + assert entries[-1][1]["tensors"][-1] == {"uid": 15, "pass_by_value": [0, 128]} + + +def test_iter_context_entries_ignores_dumps_for_unknown_gid(): + payload1 = payload(22, "SDPA_BWD", "HALF") + payload1["tensors"] = [{"uid": 5}] + lines = [ + json.dumps(payload1), + "[cudnn_frontend] INFO: Executing gid 11", + "[cudnn_frontend] INFO: Tensor Dump uid: 5 Name: Data: [1]", + "[cudnn_frontend] INFO: Executing gid 22", + ] + + entries = list(log_parser.iter_context_entries(lines)) + + assert "pass_by_value" not in entries[-1][1]["tensors"][0] diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_schema.py b/tools/cudnn_repro/tests/test_cudnn_repro_schema.py index cd7915ff..2d57be4f 100644 --- a/tools/cudnn_repro/tests/test_cudnn_repro_schema.py +++ b/tools/cudnn_repro/tests/test_cudnn_repro_schema.py @@ -1,10 +1,16 @@ +import pytest + import cudnn_repro.operations as operations import cudnn_repro.repro_command as repro_command import cudnn_repro.sdpa_fwd as sdpa_fwd +from .helpers import tensor_list + -def fwd_payload(*, graph_uid=1, unfuse_fma=False): +def fwd_payload(*, gid=1, unfuse_fma=False): payload = { + "json_version": "2.0", + "gid": gid, "context": {"io_data_type": "FLOAT16"}, "nodes": [ { @@ -20,20 +26,22 @@ def fwd_payload(*, graph_uid=1, unfuse_fma=False): "padding_mask": False, } ], - "tensors": { - "1": {"uid": 1, "name": "", "dim": [1, 3, 16, 64], "stride": [3072, 64, 192, 1]}, - "2": {"uid": 2, "name": "", "dim": [1, 1, 16, 64], "stride": [1024, 64, 64, 1]}, - "3": {"uid": 3, "name": "", "dim": [1, 1, 16, 64], "stride": [1024, 64, 64, 1]}, - "4": {"uid": 4, "name": "sdpa_fwd::O", "dim": [1, 3, 16, 64], "stride": [3072, 64, 192, 1]}, - }, + "tensors": tensor_list( + { + "1": {"uid": 1, "dim": [1, 3, 16, 64], "stride": [3072, 64, 192, 1]}, + "2": {"uid": 2, "dim": [1, 1, 16, 64], "stride": [1024, 64, 64, 1]}, + "3": {"uid": 3, "dim": [1, 1, 16, 64], "stride": [1024, 64, 64, 1]}, + "4": {"uid": 4, "name": "sdpa_fwd::O", "dim": [1, 3, 16, 64], "stride": [3072, 64, 192, 1]}, + } + ), } - if graph_uid is not None: - payload["graph_uid"] = graph_uid return payload def test_build_cfg_maps_causal_without_explicit_right_bound(): payload = { + "json_version": "2.0", + "gid": 1, "context": {"io_data_type": "FLOAT16"}, "nodes": [ { @@ -47,12 +55,14 @@ def test_build_cfg_maps_causal_without_explicit_right_bound(): "right_bound": None, } ], - "tensors": { - "1": {"uid": 1, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, - "2": {"uid": 2, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, - "3": {"uid": 3, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, - "4": {"uid": 4, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, - }, + "tensors": tensor_list( + { + "1": {"uid": 1, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, + "2": {"uid": 2, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, + "3": {"uid": 3, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, + "4": {"uid": 4, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, + } + ), } cfg = operations.select_operation(payload).build_cfg("{}", payload) @@ -61,8 +71,18 @@ def test_build_cfg_maps_causal_without_explicit_right_bound(): assert cfg["diag_align"] == 0 +def test_select_operation_rejects_old_graph_json(): + payload = fwd_payload() + payload["json_version"] = "1.0" + + with pytest.raises(ValueError, match="Unsupported graph JSON version"): + operations.select_operation(payload) + + def test_build_cfg_preserves_logged_tensor_layout(): payload = { + "json_version": "2.0", + "gid": 1, "context": {"io_data_type": "FLOAT16"}, "nodes": [ { @@ -75,12 +95,14 @@ def test_build_cfg_preserves_logged_tensor_layout(): "right_bound": None, } ], - "tensors": { - "1": {"uid": 1, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, - "2": {"uid": 2, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, - "3": {"uid": 3, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, - "4": {"uid": 4, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, - }, + "tensors": tensor_list( + { + "1": {"uid": 1, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, + "2": {"uid": 2, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, + "3": {"uid": 3, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, + "4": {"uid": 4, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, + } + ), } cfg = operations.select_operation(payload).build_cfg("{}", payload) @@ -92,6 +114,56 @@ def test_build_cfg_preserves_logged_tensor_layout(): assert cfg["right_bound"] is None +def test_build_cfg_detects_ragged_from_offset_inputs(): + payload = fwd_payload() + node = payload["nodes"][0] + node["padding_mask"] = True + node["inputs"].update({"SEQ_LEN_Q": 5, "SEQ_LEN_KV": 6, "RAGGED_OFFSET_Q": 7, "RAGGED_OFFSET_KV": 8}) + payload["tensors"].extend( + tensor_list( + { + "5": {"uid": 5, "data_type": "INT32", "dim": [1, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [13]}, + "6": {"uid": 6, "data_type": "INT32", "dim": [1, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [11]}, + "7": {"uid": 7, "data_type": "INT64", "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [0, 13]}, + "8": {"uid": 8, "data_type": "INT64", "dim": [2, 1, 1, 1], "stride": [1, 1, 1, 1], "pass_by_value": [0, 11]}, + } + ) + ) + + cfg = operations.select_operation(payload).build_cfg("{}", payload) + + assert cfg["is_ragged"] is True + assert cfg["is_padding"] is True + assert cfg["seq_len_q"] == [13] + assert cfg["seq_len_kv"] == [11] + + +def test_build_cfg_detects_ragged_from_tensor_offset_refs(): + payload = fwd_payload() + node = payload["nodes"][0] + node["padding_mask"] = True + node["inputs"].update({"SEQ_LEN_Q": 5, "SEQ_LEN_KV": 6}) + for entry, offset_uid in zip(payload["tensors"][:4], [7, 8, 8, 7]): + entry["ragged_offset_uid"] = offset_uid + payload["tensors"].extend( + tensor_list( + { + "5": {"uid": 5, "data_type": "INT32", "dim": [1], "stride": [1], "pass_by_value": [13]}, + "6": {"uid": 6, "data_type": "INT32", "dim": [1], "stride": [1], "pass_by_value": [11]}, + "7": {"uid": 7, "pass_by_value": [0, 13]}, + "8": {"uid": 8, "pass_by_value": [0, 11]}, + } + ) + ) + + annotated = operations.select_operation(payload).extract_and_annotate("{}", payload, "") + cfg = operations.select_operation(annotated).build_cfg("{}", annotated) + + assert cfg["is_ragged"] is True + assert annotated["repro_metadata"]["ragged_offset_q"] == [0, 13] + assert annotated["repro_metadata"]["ragged_offset_kv"] == [0, 11] + + def test_build_command_normalizes_enum_fields(): cfg = { "data_type": "torch.float16", @@ -145,8 +217,8 @@ def test_build_cfg_preserves_rope(): assert "'with_rope': True" in command -def test_graph_uid_payload_ignores_unrelated_ragged_log_text(): - payload = fwd_payload(graph_uid=9) +def test_payload_ignores_unrelated_ragged_log_text(): + payload = fwd_payload(gid=9) log_text = "Backend Tensor named 'sdpa_fwd::O' with UID 4\n" "Id: 4\n" "raggedOffset: Enabled UID: 99\n" annotated_payload = sdpa_fwd.extract_and_annotate("{}", payload, log_text) @@ -154,14 +226,3 @@ def test_graph_uid_payload_ignores_unrelated_ragged_log_text(): assert annotated_payload["repro_metadata"]["ragged_tensor_names"] == [] assert cfg["is_ragged"] is False - - -def test_legacy_payload_keeps_ragged_log_text_fallback(): - payload = fwd_payload(graph_uid=None) - log_text = "Backend Tensor named 'sdpa_fwd::O' with UID 4\n" "Id: 4\n" "raggedOffset: Enabled UID: 99\n" - - annotated_payload = sdpa_fwd.extract_and_annotate("{}", payload, log_text) - cfg = sdpa_fwd.build_cfg("{}", annotated_payload, seed=123) - - assert annotated_payload["repro_metadata"]["ragged_tensor_names"] == ["sdpa_fwd::O"] - assert cfg["is_ragged"] is True