diff --git a/src/mpid/ch4/netmod/ucx/ucx_am.c b/src/mpid/ch4/netmod/ucx/ucx_am.c index 05f849eac6d..571687e9e41 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_am.c +++ b/src/mpid/ch4/netmod/ucx/ucx_am.c @@ -93,9 +93,11 @@ int MPIDI_UCX_do_am_recv(MPIR_Request * rreq) MPIDI_UCX_ucp_request_t *ucp_request; size_t received_length; ucp_request_param_t param = { - .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_RECV_INFO, + .op_attr_mask = + UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_RECV_INFO | UCP_OP_ATTR_FIELD_USER_DATA, .cb.recv_am = &MPIDI_UCX_am_recv_callback_nbx, .recv_info.length = &received_length, + .user_data = rreq, }; void *data_desc = MPIDI_UCX_AM_RECV_REQUEST(rreq, data_desc); /* note: use in_data_sz to match promised data size */ @@ -103,11 +105,7 @@ int MPIDI_UCX_do_am_recv(MPIR_Request * rreq) data_desc, recv_buf, in_data_sz, &param); if (ucp_request == NULL) { /* completed immediately */ - MPIDI_UCX_ucp_request_t tmp_ucp_request; - tmp_ucp_request.req = rreq; - MPIDI_UCX_am_recv_callback_nbx(&tmp_ucp_request, UCS_OK, received_length, NULL); - } else { - ucp_request->req = rreq; + MPIDI_UCX_am_recv_callback_nbx(NULL, UCS_OK, received_length, rreq); } return MPI_SUCCESS; @@ -163,8 +161,7 @@ ucs_status_t MPIDI_UCX_am_nbx_handler(void *arg, const void *header, size_t head void MPIDI_UCX_am_recv_callback_nbx(void *request, ucs_status_t status, size_t length, void *user_data) { - MPIDI_UCX_ucp_request_t *ucp_request = (MPIDI_UCX_ucp_request_t *) request; - MPIR_Request *rreq = ucp_request->req; + MPIR_Request *rreq = user_data; /* FIXME: proper error handling */ MPIR_Assert(status == UCS_OK); @@ -177,8 +174,9 @@ void MPIDI_UCX_am_recv_callback_nbx(void *request, ucs_status_t status, size_t l MPIDIG_recv_done(length, rreq); } MPIDIG_REQUEST(rreq, req->target_cmpl_cb) (rreq); - ucp_request->req = NULL; - ucp_request_release(ucp_request); + if (request) { + ucp_request_release(request); + } } void MPIDI_UCX_am_isend_callback_nbx(void *request, 
ucs_status_t status, void *user_data) @@ -186,13 +184,11 @@ void MPIDI_UCX_am_isend_callback_nbx(void *request, ucs_status_t status, void *u /* note: only difference from MPIDI_UCX_am_isend_callback is we need * MPL_free in stead of MPIR_gpu_free_host */ - MPIDI_UCX_ucp_request_t *ucp_request = (MPIDI_UCX_ucp_request_t *) request; - MPIR_Request *req = ucp_request->req; + MPIR_Request *req = user_data; int handler_id = MPIDI_UCX_AM_SEND_REQUEST(req, handler_id); MPL_free(MPIDI_UCX_AM_SEND_REQUEST(req, pack_buffer)); MPIDI_UCX_AM_SEND_REQUEST(req, pack_buffer) = NULL; MPIDIG_global.origin_cbs[handler_id] (req); - ucp_request->req = NULL; } #endif diff --git a/src/mpid/ch4/netmod/ucx/ucx_am.h b/src/mpid/ch4/netmod/ucx/ucx_am.h index ecfcb295401..f3699662f9c 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_am.h +++ b/src/mpid/ch4/netmod/ucx/ucx_am.h @@ -42,53 +42,42 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_am_isend(int rank, #ifdef HAVE_UCP_AM_NBX size_t header_size = sizeof(ucx_hdr) + am_hdr_sz; - void *send_buf, *header, *data_ptr; - /* note: since we are not copying large contig gpu data, it is less useful - * to use MPIR_gpu_malloc_host */ - if (dt_contig) { - /* only need copy headers */ - send_buf = MPL_malloc(header_size, MPL_MEM_OTHER); - MPIR_Assert(send_buf); - header = send_buf; + void *header; + const void *data_ptr; + ucp_request_param_t param = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, + .cb.send = &MPIDI_UCX_am_isend_callback_nbx, + .user_data = sreq, + }; + + header = MPL_malloc(header_size, MPL_MEM_OTHER); + MPIR_Assert(header); - MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr)); - MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz); + MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr)); + MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz); + if (dt_contig) { data_ptr = (char *) data + dt_true_lb; } else { - /* need copy headers and pack data */ - send_buf = MPL_malloc(header_size + data_sz, 
MPL_MEM_OTHER); - MPIR_Assert(send_buf); - header = send_buf; - data_ptr = (char *) send_buf + header_size; - - MPIR_Memcpy(header, &ucx_hdr, sizeof(ucx_hdr)); - MPIR_Memcpy((char *) header + sizeof(ucx_hdr), am_hdr, am_hdr_sz); - - MPI_Aint actual_pack_bytes; - mpi_errno = MPIR_Typerep_pack(data, count, datatype, 0, data_ptr, data_sz, - &actual_pack_bytes, MPIR_TYPEREP_FLAG_NONE); - MPIR_ERR_CHECK(mpi_errno); - MPIR_Assert(actual_pack_bytes == data_sz); + param.op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE; + param.datatype = dt_ptr->dev.netmod.ucx.ucp_datatype; + MPIR_Datatype_ptr_add_ref(dt_ptr); + data_ptr = data; + data_sz = count; } - ucp_request_param_t param = { - .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK, - .cb.send = &MPIDI_UCX_am_isend_callback_nbx, - }; ucp_request = (MPIDI_UCX_ucp_request_t *) ucp_am_send_nbx(ep, MPIDI_UCX_AM_NBX_HANDLER_ID, header, header_size, data_ptr, data_sz, &param); MPIDI_UCX_CHK_REQUEST(ucp_request); /* if send is done, free all resources and complete the request */ if (ucp_request == NULL) { - MPL_free(send_buf); + MPL_free(header); MPIDIG_global.origin_cbs[handler_id] (sreq); goto fn_exit; } - MPIDI_UCX_AM_SEND_REQUEST(sreq, pack_buffer) = send_buf; + MPIDI_UCX_AM_SEND_REQUEST(sreq, pack_buffer) = header; MPIDI_UCX_AM_SEND_REQUEST(sreq, handler_id) = handler_id; - ucp_request->req = sreq; ucp_request_release(ucp_request); #else /* !HAVE_UCP_AM_NBX */