Skip to content
88 changes: 37 additions & 51 deletions src/mpi/comm/ulfm_impl.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* implementation. Since ULFM require local discovery, we should remove that
*/

#ifndef MPID_COMM_AGREE
/* maintain a list of failed process in comm_world */
/* NOTE: we need maintain the order of failed_procs as the show up. We do it here because
* it isn't fair to require PMI to do it.
Expand Down Expand Up @@ -67,9 +68,13 @@ static void parse_failed_procs_string(char *failed_procs_string)
token = strtok(NULL, delim);
}
}
#endif

int MPIR_Comm_get_failed_impl(MPIR_Comm * comm_ptr, MPIR_Group ** failed_group_ptr)
{
#ifdef MPID_COMM_AGREE
return MPID_Comm_get_failed(comm_ptr, failed_group_ptr);
#else
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

Expand Down Expand Up @@ -119,80 +124,60 @@ int MPIR_Comm_get_failed_impl(MPIR_Comm * comm_ptr, MPIR_Group ** failed_group_p
return mpi_errno;
fn_fail:
goto fn_exit;
#endif
}

/* comm shrink impl; assumes that standard error checking has already taken
* place in the calling function */
/* Supposedly caller already agreed on the result of MPIX_Comm_get_failed
* by running MPIX_Comm_agree. Thus, shrink is merely MPI_Comm_create_group.
*/
int MPIR_Comm_shrink_impl(MPIR_Comm * comm_ptr, MPIR_Comm ** newcomm_ptr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_Group *global_failed = NULL, *comm_grp = NULL, *new_group_ptr = NULL;
int attempts = 0;

MPIR_FUNC_ENTER;

/* TODO - Implement this function for intercommunicators */
MPIR_Comm_group_impl(comm_ptr, &comm_grp);
MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM);

do {
int coll_attr = 0;

MPID_Comm_get_all_failed_procs(comm_ptr, &global_failed, MPIR_SHRINK_TAG);
/* Ignore the mpi_errno value here as it will definitely communicate
* with failed procs */
MPIR_Group *comm_grp;
mpi_errno = MPIR_Comm_group_impl(comm_ptr, &comm_grp);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIR_Group_difference_impl(comm_grp, global_failed, &new_group_ptr);
MPIR_ERR_CHECK(mpi_errno);
if (MPIR_Group_empty != global_failed)
MPIR_Group_release(global_failed);

mpi_errno = MPIR_Comm_create_group_impl(comm_ptr, new_group_ptr, MPIR_SHRINK_TAG,
newcomm_ptr);
if (*newcomm_ptr == NULL) {
coll_attr = MPIR_ERR_PROC_FAILED;
} else if (mpi_errno) {
coll_attr =
MPIX_ERR_PROC_FAILED ==
MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
MPIR_Comm_release(*newcomm_ptr);
}
MPIR_Group *global_failed;
mpi_errno = MPIR_Comm_get_failed_impl(comm_ptr, &global_failed);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPII_Allreduce_group(MPI_IN_PLACE, &coll_attr, 1, MPIR_INT_INTERNAL, MPI_MAX,
comm_ptr, new_group_ptr, MPIR_SHRINK_TAG,
MPIR_COLL_ATTR_SYNC);
MPIR_Group_release(new_group_ptr);
MPIR_Group *new_group_ptr;
mpi_errno = MPIR_Group_difference_impl(comm_grp, global_failed, &new_group_ptr);
MPIR_ERR_CHECK(mpi_errno);

if (coll_attr) {
if (*newcomm_ptr != NULL && MPIR_Object_get_ref(*newcomm_ptr) > 0) {
MPIR_Object_set_ref(*newcomm_ptr, 1);
MPIR_Comm_release(*newcomm_ptr);
}
if (MPIR_Object_get_ref(new_group_ptr) > 0) {
MPIR_Object_set_ref(new_group_ptr, 1);
MPIR_Group_release(new_group_ptr);
}
} else {
mpi_errno = MPI_SUCCESS;
goto fn_exit;
}
} while (++attempts < 5);
if (MPIR_Group_empty != global_failed) {
MPIR_Group_release(global_failed);
}

goto fn_fail;
mpi_errno = MPIR_Comm_create_group_impl(comm_ptr, new_group_ptr, MPIR_SHRINK_TAG, newcomm_ptr);
/* FIXME: what if user have not run MPIX_Comm_agree or there are new failed procs?
* We need handle MPIX_ERR_PROC_FAILED.
*/
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
MPIR_Group_release(comm_grp);
if (new_group_ptr) {
MPIR_Group_release(new_group_ptr);
}
if (comm_grp) {
MPIR_Group_release(comm_grp);
}
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
if (*newcomm_ptr)
MPIR_Object_set_ref(*newcomm_ptr, 0);
MPIR_Object_set_ref(global_failed, 0);
MPIR_Object_set_ref(new_group_ptr, 0);
goto fn_exit;
}

int MPIR_Comm_agree_impl(MPIR_Comm * comm_ptr, int *flag)
{
#ifdef MPID_COMM_AGREE
return MPID_Comm_agree(comm_ptr, flag);
#else
int mpi_errno = MPI_SUCCESS, mpi_errno_tmp = MPI_SUCCESS;
MPIR_Group *comm_grp = NULL, *failed_grp = NULL, *new_group_ptr = NULL, *global_failed = NULL;
int result, success = 1;
Expand Down Expand Up @@ -264,4 +249,5 @@ int MPIR_Comm_agree_impl(MPIR_Comm * comm_ptr, int *flag)
return mpi_errno;
fn_fail:
goto fn_exit;
#endif
}
5 changes: 5 additions & 0 deletions src/mpid/ch4/include/mpidch4.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,11 @@ int MPID_Waitall_enqueue(int count, MPI_Request * array_of_requests,
MPI_Status * array_of_statuses);
int MPID_Abort(struct MPIR_Comm *comm, int mpi_errno, int exit_code, const char *error_msg);

#define MPID_COMM_AGREE 1
int MPID_Comm_agree(MPIR_Comm * comm, int *flag);
int MPID_Comm_get_failed(MPIR_Comm * comm_ptr, MPIR_Group ** failed_group_ptr);


/* This function is not exposed to the upper layers but functions in a way
* similar to the functions above. Other CH4-level functions should call this
* function to query locality. This function will determine whether to call the
Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ typedef struct MPIDIG_comm_t {

typedef struct MPIDI_Devcomm_t {
struct {
int comm_agree_epoch;
/* The first fields are used by the AM(MPIDIG) apis */
MPIDIG_comm_t am;
/* for netmod internal send/recv (e.g. am_tag_{send,recv}, pipeline, rndv_{read,write} */
Expand Down
37 changes: 26 additions & 11 deletions src/mpid/ch4/netmod/ofi/ofi_events.c
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret)
struct fi_cq_err_entry e;
char err_data[MPIDI_OFI_MAX_ERR_DATA_SIZE];
MPIR_Request *req;
int event_id;
ssize_t ret_cqerr;
MPIR_FUNC_ENTER;

Expand Down Expand Up @@ -544,10 +545,9 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret)
break;

default:
MPIR_ERR_SETFATALANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**ofid_poll",
"**ofid_poll %s %s",
MPIDI_OFI_DEFAULT_NIC_NAME,
fi_strerror(e.err));
MPIR_ERR_SETANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**ofid_poll",
"**ofid_poll %s %s",
MPIDI_OFI_DEFAULT_NIC_NAME, fi_strerror(e.err));
}

break;
Expand All @@ -557,7 +557,7 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret)
/* Clean up the request. Reference MPIDI_OFI_recv_event.
* NOTE: assuming only the receive request can be cancelled and reach here
*/
int event_id = MPIDI_OFI_REQUEST(req, event_id);
event_id = MPIDI_OFI_REQUEST(req, event_id);
switch (event_id) {
case MPIDI_OFI_EVENT_DYNPROC_DONE:
dynproc_done_event(vci, e.op_context, req);
Expand Down Expand Up @@ -588,17 +588,32 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret)
break;

default:
MPIR_ERR_SETFATALANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**ofid_poll",
"**ofid_poll %s %s",
MPIDI_OFI_DEFAULT_NIC_NAME, fi_strerror(e.err));
req = MPIDI_OFI_context_to_request(e.op_context);
event_id = MPIDI_OFI_REQUEST(req, event_id);
switch (event_id) {
case MPIDI_OFI_EVENT_AM_SEND:
/* set req->status.MPI_ERROR */
MPIR_ERR_SET2(req->status.MPI_ERROR, MPI_ERR_OTHER, "**ofid_poll",
"**ofid_poll %s %s",
MPIDI_OFI_DEFAULT_NIC_NAME, fi_strerror(e.err));
mpi_errno = am_isend_event(vci, NULL, req);
break;
default:
/* FIXME: application can't handle error in progress due to loss of
* context. We should try best to set error in req->status instead.
*/
MPIR_ERR_SETANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**ofid_poll",
"**ofid_poll %s %s",
MPIDI_OFI_DEFAULT_NIC_NAME, fi_strerror(e.err));
}
}

break;

default:
MPIR_ERR_SETFATALANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**ofid_poll",
"**ofid_poll %s %s",
MPIDI_OFI_DEFAULT_NIC_NAME, fi_strerror(errno));
MPIR_ERR_SETANDJUMP2(mpi_errno, MPI_ERR_OTHER, "**ofid_poll",
"**ofid_poll %s %s",
MPIDI_OFI_DEFAULT_NIC_NAME, fi_strerror(errno));
}

fn_exit:
Expand Down
8 changes: 6 additions & 2 deletions src/mpid/ch4/netmod/ofi/ofi_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret);
#define MPIDI_OFI_CALL_RETRY_AM(FUNC,vci_,STR) \
do { \
ssize_t _ret; \
int _retry = MPIR_CVAR_CH4_OFI_MAX_EAGAIN_RETRY; \
do { \
_ret = FUNC; \
if (likely(_ret==0)) break; \
Expand All @@ -144,9 +145,12 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret);
"**ofid_"#STR" %s %s", \
MPIDI_OFI_DEFAULT_NIC_NAME, \
fi_strerror(-_ret)); \
if (_retry > 0) { \
_retry--; \
MPIR_ERR_CHKANDJUMP(_retry == 0, mpi_errno, MPIX_ERR_EAGAIN, "**eagain"); \
} \
mpi_errno = MPIDI_OFI_progress_do_queue(vci_); \
if (mpi_errno != MPI_SUCCESS) \
MPIR_ERR_CHECK(mpi_errno); \
MPIR_ERR_CHECK(mpi_errno); \
} while (_ret == -FI_EAGAIN); \
} while (0)

Expand Down
12 changes: 0 additions & 12 deletions src/mpid/ch4/netmod/ofi/subconfigure.m4
Original file line number Diff line number Diff line change
Expand Up @@ -340,18 +340,6 @@ AM_COND_IF([BUILD_CH4_NETMOD_OFI],[
PAC_LIBS_ADD([-lfabric])
fi

# check for libfabric dependence libs
pcdir=""
if test "${ofi_embedded}" = "yes" ; then
pcdir="${main_top_builddir}/modules/libfabric"
elif test -f ${with_libfabric}/lib/pkgconfig/libfabric.pc ; then
pcdir="${with_libfabric}/lib/pkgconfig"
fi
PAC_LIB_DEPS(fabric, $pcdir)
if test "x$ac_libfabric_deps" != "x"; then
PAC_APPEND_FLAG([${ac_libfabric_deps}],[WRAPPER_LIBS])
fi

AC_ARG_ENABLE(ofi-domain, [
--enable-ofi-domain - Use fi_domain for vci contexts. This is the default.
Use --disable-ofi-domain to use fi_contexts within
Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/src/Makefile.mk
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ mpi_core_sources += src/mpid/ch4/src/ch4_globals.c \
src/mpid/ch4/src/ch4_stream_enqueue.c \
src/mpid/ch4/src/ch4_persist.c \
src/mpid/ch4/src/ch4_vci.c \
src/mpid/ch4/src/ch4_ulfm.c \
src/mpid/ch4/src/mpidig_init.c \
src/mpid/ch4/src/mpidig_recvq.c \
src/mpid/ch4/src/mpidig_pt2pt_callbacks.c \
Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/src/ch4_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ int MPID_Comm_commit_pre_hook(MPIR_Comm * comm)
mpi_errno = MPIDIG_init_comm(comm);
MPIR_ERR_CHECK(mpi_errno);

MPIDI_COMM(comm, comm_agree_epoch) = 0;
/* initialize next_am_tag for internal messaging */
int total_tag_bits = get_num_bits(MPIR_Process.attrs.tag_ub);
MPIDI_COMM(comm, next_am_tag) = 0;
Expand Down
2 changes: 2 additions & 0 deletions src/mpid/ch4/src/ch4_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ int MPIDI_init_per_vci(int vci);
int MPIDI_destroy_per_vci(int vci);
int MPIDIG_get_context_index(uint64_t context_id);
uint64_t MPIDIG_generate_win_id(MPIR_Comm * comm_ptr);
int MPIDI_ulfm_init(void);
int MPIDI_ulfm_finalize(void);

/* define CH4_CALL to call netmod or shm API based on is_local */
#ifdef MPIDI_CH4_DIRECT_NETMOD
Expand Down
6 changes: 6 additions & 0 deletions src/mpid/ch4/src/ch4_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,9 @@ int MPID_Init(int requested, int *provided)
mpi_errno = MPIDU_stream_workq_init();
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIDI_ulfm_init();
MPIR_ERR_CHECK(mpi_errno);

/* Create genq for GPU collectives */
mpi_errno =
MPIDU_genq_private_pool_create(MPIR_CVAR_CH4_GPU_COLL_SWAP_BUFFER_SZ,
Expand Down Expand Up @@ -733,6 +736,9 @@ int MPID_Finalize(void)
mpi_errno = MPIDU_stream_workq_finalize();
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIDI_ulfm_finalize();
MPIR_ERR_CHECK(mpi_errno);

for (int i = 0; i < MAX_CH4_MUTEXES; i++) {
int err;
MPID_Thread_mutex_destroy(&MPIDI_global.m[i], &err);
Expand Down
Loading