From d41138d800f9786e7271c4eeae379889e30551aa Mon Sep 17 00:00:00 2001 From: Jiakun Yan Date: Tue, 16 Jan 2024 16:18:08 -0600 Subject: [PATCH] add(putmac): put medium with completion --- lci/runtime/1sided_primitive.c | 91 ++++++++++++++++------------------ lci/runtime/2sided_primitive.c | 4 +- 2 files changed, 43 insertions(+), 52 deletions(-) diff --git a/lci/runtime/1sided_primitive.c b/lci/runtime/1sided_primitive.c index 4bc0da8e..73d698a8 100644 --- a/lci/runtime/1sided_primitive.c +++ b/lci/runtime/1sided_primitive.c @@ -35,8 +35,8 @@ LCI_error_t LCI_putm(LCI_endpoint_t ep, LCI_mbuffer_t mbuffer, int rank, return LCI_ERR_FEATURE_NA; } -LCI_error_t LCI_putma(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, - LCI_tag_t tag, uintptr_t remote_completion) +LCI_error_t LCI_putmac(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, + LCI_tag_t tag, uintptr_t remote_completion, LCI_comp_t local_completion, void *user_context) { LCI_DBG_Assert(tag <= LCI_MAX_TAG, "tag %d is too large (maximum: %d)\n", tag, LCI_MAX_TAG); @@ -48,34 +48,55 @@ LCI_error_t LCI_putma(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, "(set by LCI_plist_set_default_comp, " "the default value is LCI_UR_CQ)\n"); LCI_error_t ret = LCI_OK; - if (buffer.length <= LCI_SHORT_SIZE) { + bool is_user_provided_packet = LCII_is_packet(ep->device, buffer.address); + if (local_completion == NULL && buffer.length <= LCI_SHORT_SIZE) { /* if data is this short, we will be able to inline it * no reason to get a packet, allocate a ctx, etc */ ret = LCIS_post_sends(ep->device->endpoint_worker->endpoint, rank, buffer.address, buffer.length, LCII_MAKE_PROTO(ep->gid, LCI_MSG_RDMA_MEDIUM, tag)); + if (ret == LCI_OK && is_user_provided_packet) { + LCII_packet_t* packet = LCII_mbuffer2packet(buffer); + packet->context.poolid = -1; + LCII_free_packet(packet); + } } else { - LCII_packet_t* packet = LCII_alloc_packet_nb(ep->pkpool); - if (packet == NULL) { - // no packet is available - return LCI_ERR_RETRY; + LCII_packet_t* packet; + if (is_user_provided_packet) { + packet = LCII_mbuffer2packet(buffer); + } else { + packet = LCII_alloc_packet_nb(ep->pkpool); + if (packet == NULL) { + // no packet is available + return LCI_ERR_RETRY; + } + memcpy(packet->data.address, buffer.address, buffer.length); } packet->context.poolid = (buffer.length > LCI_PACKET_RETURN_THRESHOLD) ? lc_pool_get_local(ep->pkpool) : -1; - memcpy(packet->data.address, buffer.address, buffer.length); LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t)); ctx->data.packet = packet; LCII_initilize_comp_attr(ctx->comp_attr); - LCII_comp_attr_set_free_packet(ctx->comp_attr, 1); - + if (!(is_user_provided_packet && local_completion)) { + LCII_comp_attr_set_free_packet(ctx->comp_attr, 1); + } + if (local_completion) { + ctx->data_type = LCI_MEDIUM; + ctx->data.mbuffer = buffer; + ctx->rank = rank; + ctx->tag = tag; + ctx->user_context = user_context; + LCII_comp_attr_set_comp_type(ctx->comp_attr, ep->cmd_comp_type); + ctx->completion = local_completion; + } ret = LCIS_post_send( ep->device->endpoint_worker->endpoint, rank, packet->data.address, buffer.length, ep->device->heap.segment->mr, LCII_MAKE_PROTO(ep->gid, LCI_MSG_RDMA_MEDIUM, tag), ctx); if (ret == LCI_ERR_RETRY) { - LCII_free_packet(packet); + if (!is_user_provided_packet) LCII_free_packet(packet); LCIU_free(ctx); } } @@ -83,51 +104,23 @@ LCI_error_t LCI_putma(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, LCII_PCOUNTER_ADD(put, (int64_t)buffer.length); } LCI_DBG_Log(LCI_LOG_TRACE, "comm", - "LCI_putm(ep %p, buffer {%p, %lu}, rank %d, tag %u, " - "remote_completion %p) -> %d\n", + "LCI_putmac(ep %p, buffer {%p, %lu}, rank %d, tag %u, " + "remote_completion %p, local_completion %p, user_context %p) -> %d\n", ep, buffer.address, buffer.length, rank, tag, - (void*)remote_completion, ret); + (void*)remote_completion, local_completion, user_context, ret); return ret; } +LCI_error_t LCI_putma(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, + LCI_tag_t tag, uintptr_t remote_completion) +{ + return LCI_putmac(ep, buffer, rank, tag, remote_completion, NULL, NULL); +} + LCI_error_t LCI_putmna(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, LCI_tag_t tag, uintptr_t remote_completion) { - LCI_DBG_Assert(tag <= LCI_MAX_TAG, "tag %d is too large (maximum: %d)\n", tag, - LCI_MAX_TAG); - LCI_DBG_Assert(buffer.length <= LCI_MEDIUM_SIZE, - "buffer is too large %lu (maximum: %d)\n", buffer.length, - LCI_MEDIUM_SIZE); - LCI_DBG_Assert(remote_completion == LCI_DEFAULT_COMP_REMOTE, - "Only support default remote completion " - "(set by LCI_plist_set_default_comp, " - "the default value is LCI_UR_CQ)\n"); - LCII_packet_t* packet = LCII_mbuffer2packet(buffer); - packet->context.poolid = (buffer.length > LCI_PACKET_RETURN_THRESHOLD) - ? lc_pool_get_local(ep->pkpool) - : -1; - - LCII_context_t* ctx = LCIU_malloc(sizeof(LCII_context_t)); - ctx->data.packet = packet; - LCII_initilize_comp_attr(ctx->comp_attr); - LCII_comp_attr_set_free_packet(ctx->comp_attr, 1); - - LCI_error_t ret = LCIS_post_send( - ep->device->endpoint_worker->endpoint, rank, packet->data.address, - buffer.length, ep->device->heap.segment->mr, - LCII_MAKE_PROTO(ep->gid, LCI_MSG_RDMA_MEDIUM, tag), ctx); - if (ret == LCI_ERR_RETRY) { - LCIU_free(ctx); - } - if (ret == LCI_OK) { - LCII_PCOUNTER_ADD(put, (int64_t)buffer.length); - } - LCI_DBG_Log(LCI_LOG_TRACE, "comm", - "LCI_putmna(ep %p, buffer {%p, %lu}, rank %d, tag %u, " - "remote_completion %p) -> %d\n", - ep, buffer.address, buffer.length, rank, tag, - (void*)remote_completion, ret); - return ret; + return LCI_putmac(ep, buffer, rank, tag, remote_completion, NULL, NULL); } LCI_error_t LCI_putl(LCI_endpoint_t ep, LCI_lbuffer_t local_buffer, diff --git a/lci/runtime/2sided_primitive.c b/lci/runtime/2sided_primitive.c index 4e550b02..f74a4231 100644 --- a/lci/runtime/2sided_primitive.c +++ b/lci/runtime/2sided_primitive.c @@ -35,9 +35,7 @@ LCI_error_t LCI_sendmc(LCI_endpoint_t ep, LCI_mbuffer_t buffer, int rank, LCII_MAKE_PROTO(ep->gid, LCI_MSG_MEDIUM, tag)); if (ret == LCI_OK && is_user_provided_packet) { LCII_packet_t* packet = LCII_mbuffer2packet(buffer); - packet->context.poolid = (buffer.length > LCI_PACKET_RETURN_THRESHOLD) - ? lc_pool_get_local(ep->pkpool) - : -1; + packet->context.poolid = -1; LCII_free_packet(packet); } } else {