Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions maint/tuning/coll/mpir/generic.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@
{
"algorithm=MPIR_Bcast_intra_binomial":{}
},
"avg_msg_size<=524288":
{
"algorithm=MPIR_Bcast_intra_scatter_recursive_doubling_allgather":{}
},
"avg_msg_size=any":
{
"algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{}
"per_proc_msg_size<=65536":
{
"algorithm=MPIR_Bcast_intra_scatter_recursive_doubling_allgather":{}
},
"per_proc_msg_size=any":
{
"algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{}
}
}
},
"comm_hierarchy=any":
Expand All @@ -51,13 +54,16 @@
{
"algorithm=MPIR_Bcast_intra_binomial":{}
},
"avg_msg_size<=524288":
{
"algorithm=MPIR_Bcast_intra_scatter_recursive_doubling_allgather":{}
},
"avg_msg_size=any":
{
"algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{}
"per_proc_msg_size<=65536":
{
"algorithm=MPIR_Bcast_intra_scatter_recursive_doubling_allgather":{}
},
"per_proc_msg_size=any":
{
"algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{}
}
}
}
},
Expand All @@ -69,22 +75,25 @@
{
"algorithm=MPIR_Bcast_intra_smp":{}
},
"avg_msg_size<=12288":
{
"algorithm=MPIR_Bcast_intra_binomial":{}
},
"avg_msg_size=any":
{
"algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{}
"per_proc_msg_size<=12288":
{
"algorithm=MPIR_Bcast_intra_binomial":{}
},
"per_proc_msg_size=any":
{
"algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{}
}
}
},
"comm_hierarchy=any":
{
"avg_msg_size<=12288":
"per_proc_msg_size<=12288":
{
"algorithm=MPIR_Bcast_intra_binomial":{}
},
"avg_msg_size=any":
"per_proc_msg_size=any":
{
"algorithm=MPIR_Bcast_intra_scatter_ring_allgather":{}
}
Expand Down
39 changes: 39 additions & 0 deletions src/mpi/coll/src/csel.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ typedef enum {
CSEL_NODE_TYPE__OPERATOR__AVG_MSG_SIZE_LT,
CSEL_NODE_TYPE__OPERATOR__TOTAL_MSG_SIZE_LE,
CSEL_NODE_TYPE__OPERATOR__TOTAL_MSG_SIZE_LT,
CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LE,
CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LT,

CSEL_NODE_TYPE__OPERATOR__COUNT_LE,
CSEL_NODE_TYPE__OPERATOR__COUNT_LT_POW2,
Expand Down Expand Up @@ -89,6 +91,12 @@ typedef struct csel_node {
struct {
int val;
} total_msg_size_lt;
struct {
int val;
} per_proc_msg_size_le;
struct {
int val;
} per_proc_msg_size_lt;
struct {
int val;
} count_le;
Expand Down Expand Up @@ -197,6 +205,12 @@ static void print_tree(csel_node_s * node)
case CSEL_NODE_TYPE__OPERATOR__TOTAL_MSG_SIZE_LT:
nprintf("total_msg_size < %d\n", node->u.total_msg_size_lt.val);
break;
case CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LE:
nprintf("avg_msg_size <= %d\n", node->u.avg_msg_size_le.val);
break;
case CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LT:
nprintf("avg_msg_size < %d\n", node->u.avg_msg_size_lt.val);
break;
case CSEL_NODE_TYPE__CONTAINER:
nprintf("container\n");
break;
Expand Down Expand Up @@ -488,6 +502,12 @@ static csel_node_s *parse_json_tree(struct json_object *obj,
} else if (!strncmp(ckey, "total_msg_size<", strlen("total_msg_size<"))) {
tmp->type = CSEL_NODE_TYPE__OPERATOR__TOTAL_MSG_SIZE_LT;
tmp->u.total_msg_size_lt.val = atoi(ckey + strlen("total_msg_size<"));
} else if (!strncmp(ckey, "per_proc_msg_size<=", strlen("per_proc_msg_size<="))) {
tmp->type = CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LE;
tmp->u.avg_msg_size_le.val = atoi(ckey + strlen("per_proc_msg_size<="));
} else if (!strncmp(ckey, "per_proc_msg_size<", strlen("per_proc_msg_size<"))) {
tmp->type = CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LT;
tmp->u.avg_msg_size_lt.val = atoi(ckey + strlen("per_proc_msg_size<"));
} else if (!strcmp(ckey, "is_commutative=yes")) {
tmp->type = CSEL_NODE_TYPE__OPERATOR__IS_COMMUTATIVE;
tmp->u.is_commutative.val = true;
Expand Down Expand Up @@ -1177,6 +1197,11 @@ static inline MPI_Aint get_total_msgsize(MPIR_Csel_coll_sig_s coll_info)
return total_bytes;
}

static inline MPI_Aint get_perproc_msgsize(MPIR_Csel_coll_sig_s coll_info)
{
return get_avg_msgsize(coll_info) / coll_info.comm_ptr->local_size;
}

void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info)
{
csel_s *csel = (csel_s *) csel_;
Expand Down Expand Up @@ -1278,6 +1303,20 @@ void *MPIR_Csel_search(void *csel_, MPIR_Csel_coll_sig_s coll_info)
node = node->failure;
break;

case CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LE:
if (get_perproc_msgsize(coll_info) <= node->u.per_proc_msg_size_le.val)
node = node->success;
else
node = node->failure;
break;

case CSEL_NODE_TYPE__OPERATOR__PER_PROC_MSG_SIZE_LT:
if (get_perproc_msgsize(coll_info) < node->u.per_proc_msg_size_lt.val)
node = node->success;
else
node = node->failure;
break;

case CSEL_NODE_TYPE__OPERATOR__COUNT_LE:
if (get_count(coll_info) <= node->u.count_le.val)
node = node->success;
Expand Down