diff --git a/maint/tuning/coll/mpir/generic.json b/maint/tuning/coll/mpir/generic.json index 02b1bf36155..4a0bebf9d9d 100644 --- a/maint/tuning/coll/mpir/generic.json +++ b/maint/tuning/coll/mpir/generic.json @@ -36,13 +36,16 @@ { "algorithm=MPIR_Bcast_intra_binomial":{} }, - "avg_msg_size<=524288": - { - "algorithm=MPIR_Bcast_intra_scatter_recursive_doubling_allgather":{} - }, "avg_msg_size=any": { - "algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{} + "per_proc_msg_size<=65536": + { + "algorithm=MPIR_Bcast_intra_scatter_recursive_doubling_allgather":{} + }, + "per_proc_msg_size=any": + { + "algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{} + } } }, "comm_hierarchy=any": @@ -51,13 +54,16 @@ { "algorithm=MPIR_Bcast_intra_binomial":{} }, - "avg_msg_size<=524288": - { - "algorithm=MPIR_Bcast_intra_scatter_recursive_doubling_allgather":{} - }, "avg_msg_size=any": { - "algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{} + "per_proc_msg_size<=65536": + { + "algorithm=MPIR_Bcast_intra_scatter_recursive_doubling_allgather":{} + }, + "per_proc_msg_size=any": + { + "algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{} + } } } }, @@ -69,22 +75,25 @@ { "algorithm=MPIR_Bcast_intra_smp":{} }, - "avg_msg_size<=12288": - { - "algorithm=MPIR_Bcast_intra_binomial":{} - }, "avg_msg_size=any": { - "algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{} + "per_proc_msg_size<=12288": + { + "algorithm=MPIR_Bcast_intra_binomial":{} + }, + "per_proc_msg_size=any": + { + "algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{} + } } }, "comm_hierarchy=any": { - "avg_msg_size<=12288": + "per_proc_msg_size<=12288": { "algorithm=MPIR_Bcast_intra_binomial":{} }, - "avg_msg_size=any": + "per_proc_msg_size=any": { "algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{} } diff --git a/src/mpi/coll/src/csel.c b/src/mpi/coll/src/csel.c index 1253f847f66..8074b687f70 100644 --- a/src/mpi/coll/src/csel.c +++ b/src/mpi/coll/src/csel.c @@ -38,6 +38,8 @@ typedef enum { CSEL_NODE_TYPE__OPERATOR__AVG_MSG_SIZE_LT, CSEL_NODE_TYPE__OPERATOR__TOTAL_MSG_SIZE_LE, CSEL_NODE_TYPE__OPERATOR__TOTAL_MSG_SIZE_LT, + CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LE, + CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LT, CSEL_NODE_TYPE__OPERATOR__COUNT_LE, CSEL_NODE_TYPE__OPERATOR__COUNT_LT_POW2, @@ -89,6 +91,12 @@ typedef struct csel_node { struct { int val; } total_msg_size_lt; + struct { + int val; + } per_proc_msg_size_le; + struct { + int val; + } per_proc_msg_size_lt; struct { int val; } count_le; @@ -197,6 +205,12 @@ static void print_tree(csel_node_s * node) case CSEL_NODE_TYPE__OPERATOR__TOTAL_MSG_SIZE_LT: nprintf("total_msg_size < %d\n", node->u.total_msg_size_lt.val); break; + case CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LE: + nprintf("avg_msg_size <= %d\n", node->u.avg_msg_size_le.val); + break; + case CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LT: + nprintf("avg_msg_size < %d\n", node->u.avg_msg_size_lt.val); + break; case CSEL_NODE_TYPE__CONTAINER: nprintf("container\n"); break; @@ -488,6 +502,12 @@ static csel_node_s *parse_json_tree(struct json_object *obj, } else if (!strncmp(ckey, "total_msg_size<", strlen("total_msg_size<"))) { tmp->type = CSEL_NODE_TYPE__OPERATOR__TOTAL_MSG_SIZE_LT; tmp->u.total_msg_size_lt.val = atoi(ckey + strlen("total_msg_size<")); + } else if (!strncmp(ckey, "per_proc_msg_size<=", strlen("per_proc_msg_size<="))) { + tmp->type = CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LE; + tmp->u.avg_msg_size_le.val = atoi(ckey + strlen("per_proc_msg_size<=")); + } else if (!strncmp(ckey, "per_proc_msg_size<", strlen("per_proc_msg_size<"))) { + tmp->type = CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LT; + tmp->u.avg_msg_size_lt.val = atoi(ckey + strlen("per_proc_msg_size<")); } else if (!strcmp(ckey, "is_commutative=yes")) { tmp->type = CSEL_NODE_TYPE__OPERATOR__IS_COMMUTATIVE; tmp->u.is_commutative.val = true; @@ -1177,6 +1197,11 @@ static inline MPI_Aint get_total_msgsize(MPIR_Csel_coll_sig_s coll_info) return total_bytes; } +static inline MPI_Aint get_perproc_msgsize(MPIR_Csel_coll_sig_s coll_info) +{ + return get_avg_msgsize(coll_info) / coll_info.comm_ptr->local_size; +} + void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) { csel_s *csel = (csel_s *) csel_; @@ -1278,6 +1303,20 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info) node = node->failure; break; + case CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LE: + if (get_perproc_msgsize(coll_info) <= node->u.per_proc_msg_size_le.val) + node = node->success; + else + node = node->failure; + break; + + case CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LT: + if (get_perproc_msgsize(coll_info) < node->u.per_proc_msg_size_lt.val) + node = node->success; + else + node = node->failure; + break; + case CSEL_NODE_TYPE__OPERATOR__COUNT_LE: if (get_count(coll_info) <= node->u.count_le.val) node = node->success;