diff --git a/source/lib/messaging/irpc.h b/source/lib/messaging/irpc.h index 5b4abc556fc..4e775bfe069 100644 --- a/source/lib/messaging/irpc.h +++ b/source/lib/messaging/irpc.h @@ -76,14 +76,18 @@ struct irpc_request { } async; }; +typedef void (*msg_callback_t)(struct messaging_context *msg, void *private, + uint32_t msg_type, uint32_t server_id, DATA_BLOB *data); struct messaging_context *messaging_init(TALLOC_CTX *mem_ctx, uint32_t server_id, struct event_context *ev); NTSTATUS messaging_send(struct messaging_context *msg, uint32_t server, uint32_t msg_type, DATA_BLOB *data); -void messaging_register(struct messaging_context *msg, void *private, - uint32_t msg_type, - void (*fn)(struct messaging_context *, void *, uint32_t, uint32_t, DATA_BLOB *)); +NTSTATUS messaging_register(struct messaging_context *msg, void *private, + uint32_t msg_type, + msg_callback_t fn); +NTSTATUS messaging_register_tmp(struct messaging_context *msg, void *private, + msg_callback_t fn, uint32_t *msg_type); struct messaging_context *messaging_init(TALLOC_CTX *mem_ctx, uint32_t server_id, struct event_context *ev); struct messaging_context *messaging_client_init(TALLOC_CTX *mem_ctx, diff --git a/source/lib/messaging/messaging.c b/source/lib/messaging/messaging.c index f3296c6938d..f8f998f5cf1 100644 --- a/source/lib/messaging/messaging.c +++ b/source/lib/messaging/messaging.c @@ -41,7 +41,9 @@ struct messaging_context { struct socket_context *sock; const char *base_path; const char *path; - struct dispatch_fn *dispatch; + struct dispatch_fn **dispatch; + uint32_t num_types; + struct idr_context *dispatch_tree; struct messaging_rec *pending; struct irpc_list *irpc; struct idr_context *idr; @@ -54,14 +56,13 @@ struct messaging_context { } event; }; -/* we have a linked list of dispatch handlers that this messaging - server can deal with */ +/* we have a linked list of dispatch handlers for each msg_type that + this messaging server can deal with */ struct dispatch_fn { struct dispatch_fn *next, *prev; uint32_t msg_type; void *private; - void (*fn)(struct messaging_context *msg, void *private, - uint32_t msg_type, uint32_t server_id, DATA_BLOB *data); + msg_callback_t fn; }; /* an individual message */ @@ -127,14 +128,22 @@ static char *messaging_path(struct messaging_context *msg, uint32_t server_id) static void messaging_dispatch(struct messaging_context *msg, struct messaging_rec *rec) { struct dispatch_fn *d, *next; - for (d=msg->dispatch;d;d=next) { + + /* temporary IDs use an idtree, the rest use a array of pointers */ + if (rec->header->msg_type >= MSG_TMP_BASE) { + d = idr_find(msg->dispatch_tree, rec->header->msg_type); + } else if (rec->header->msg_type < msg->num_types) { + d = msg->dispatch[rec->header->msg_type]; + } else { + d = NULL; + } + + for (; d; d = next) { + DATA_BLOB data; next = d->next; - if (d->msg_type == rec->header->msg_type) { - DATA_BLOB data; - data.data = rec->packet.data + sizeof(*rec->header); - data.length = rec->header->length; - d->fn(msg, d->private, d->msg_type, rec->header->from, &data); - } + data.data = rec->packet.data + sizeof(*rec->header); + data.length = rec->header->length; + d->fn(msg, d->private, d->msg_type, rec->header->from, &data); } rec->header->length = 0; } @@ -272,17 +281,61 @@ static void messaging_handler(struct event_context *ev, struct fd_event *fde, /* Register a dispatch function for a particular message type. */ -void messaging_register(struct messaging_context *msg, void *private, - uint32_t msg_type, - void (*fn)(struct messaging_context *, void *, uint32_t, uint32_t, DATA_BLOB *)) +NTSTATUS messaging_register(struct messaging_context *msg, void *private, + uint32_t msg_type, msg_callback_t fn) { struct dispatch_fn *d; - d = talloc(msg, struct dispatch_fn); + /* possibly expand dispatch array */ + if (msg_type >= msg->num_types) { + struct dispatch_fn **dp; + int i; + dp = talloc_realloc(msg, msg->dispatch, struct dispatch_fn *, msg_type+1); + NT_STATUS_HAVE_NO_MEMORY(dp); + msg->dispatch = dp; + for (i=msg->num_types;i<=msg_type;i++) { + msg->dispatch[i] = NULL; + } + msg->num_types = msg_type+1; + } + + + d = talloc(msg->dispatch, struct dispatch_fn); + NT_STATUS_HAVE_NO_MEMORY(d); d->msg_type = msg_type; d->private = private; d->fn = fn; - DLIST_ADD(msg->dispatch, d); + + DLIST_ADD(msg->dispatch[msg_type], d); + + return NT_STATUS_OK; +} + +/* + register a temporary message handler. The msg_type is allocated + above MSG_TMP_BASE +*/ +NTSTATUS messaging_register_tmp(struct messaging_context *msg, void *private, + msg_callback_t fn, uint32_t *msg_type) +{ + struct dispatch_fn *d; + int id; + + d = talloc_zero(msg->dispatch, struct dispatch_fn); + NT_STATUS_HAVE_NO_MEMORY(d); + d->private = private; + d->fn = fn; + + id = idr_get_new_above(msg->dispatch_tree, d, MSG_TMP_BASE, UINT16_MAX); + if (id == -1) { + talloc_free(d); + return NT_STATUS_TOO_MANY_CONTEXT_IDS; + } + + d->msg_type = (uint32_t)id; + (*msg_type) = d->msg_type; + + return NT_STATUS_OK; } /* @@ -290,16 +343,34 @@ void messaging_register(struct messaging_context *msg, void *private, */ void messaging_deregister(struct messaging_context *msg, uint32_t msg_type, void *private) { - struct dispatch_fn *d, *next; + struct dispatch_fn *d, *list, *next; - for (d = msg->dispatch; d; d = next) { + if (msg_type >= msg->num_types) { + list = idr_find(msg->dispatch_tree, msg_type); + } else { + list = msg->dispatch[msg_type]; + } + + if (list == NULL) { + return; + } + + for (d = list; d; d = next) { next = d->next; - if (d->msg_type == msg_type && - d->private == private) { - DLIST_REMOVE(msg->dispatch, d); + if (d->private == private) { + DLIST_REMOVE(list, d); talloc_free(d); } } + + /* the list base possibly changed */ + if (list == NULL) { + if (msg_type >= msg->num_types) { + idr_remove(msg->dispatch_tree, msg_type); + } else { + msg->dispatch[msg_type] = NULL; + } + } } @@ -397,7 +468,7 @@ struct messaging_context *messaging_init(TALLOC_CTX *mem_ctx, uint32_t server_id struct socket_address *path; char *dir; - msg = talloc(mem_ctx, struct messaging_context); + msg = talloc_zero(mem_ctx, struct messaging_context); if (msg == NULL) { return NULL; } @@ -411,15 +482,12 @@ struct messaging_context *messaging_init(TALLOC_CTX *mem_ctx, uint32_t server_id mkdir(dir, 0700); talloc_free(dir); - msg->base_path = smbd_tmp_path(msg, "messaging"); - msg->path = messaging_path(msg, server_id); - msg->server_id = server_id; - msg->dispatch = NULL; - msg->pending = NULL; - msg->idr = idr_init(msg); - msg->irpc = NULL; - msg->names = NULL; - msg->start_time = timeval_current(); + msg->base_path = smbd_tmp_path(msg, "messaging"); + msg->path = messaging_path(msg, server_id); + msg->server_id = server_id; + msg->idr = idr_init(msg); + msg->dispatch_tree = idr_init(msg); + msg->start_time = timeval_current(); status = socket_create("unix", SOCKET_TYPE_DGRAM, &msg->sock, 0); if (!NT_STATUS_IS_OK(status)) { diff --git a/source/lib/messaging/messaging.h b/source/lib/messaging/messaging.h index 86f5db2c171..5324c530eaa 100644 --- a/source/lib/messaging/messaging.h +++ b/source/lib/messaging/messaging.h @@ -34,4 +34,7 @@ struct messaging_context; #define MSG_PVFS_NOTIFY 7 #define MSG_NTVFS_OPLOCK_BREAK 8 +/* temporary messaging endpoints are allocated above this line */ +#define MSG_TMP_BASE 1000 + #endif diff --git a/source/torture/local/messaging.c b/source/torture/local/messaging.c index 77bce155bc5..70bfd090f00 100644 --- a/source/torture/local/messaging.c +++ b/source/torture/local/messaging.c @@ -25,13 +25,14 @@ #include "lib/messaging/irpc.h" #include "torture/torture.h" -enum {MY_PING=1000, MY_PONG, MY_EXIT}; + +static uint32_t msg_pong; static void ping_message(struct messaging_context *msg, void *private, uint32_t msg_type, uint32_t src, DATA_BLOB *data) { NTSTATUS status; - status = messaging_send(msg, src, MY_PONG, data); + status = messaging_send(msg, src, msg_pong, data); if (!NT_STATUS_IS_OK(status)) { printf("pong failed - %s\n", nt_errstr(status)); } @@ -64,6 +65,7 @@ static BOOL test_ping_speed(TALLOC_CTX *mem_ctx) BOOL ret = True; struct timeval tv; int timelimit = lp_parm_int(-1, "torture", "timelimit", 10); + uint32_t msg_ping, msg_exit; lp_set_cmdline("lock dir", "lockdir.tmp"); @@ -77,8 +79,8 @@ static BOOL test_ping_speed(TALLOC_CTX *mem_ctx) return False; } - messaging_register(msg_server_ctx, NULL, MY_PING, ping_message); - messaging_register(msg_server_ctx, mem_ctx, MY_EXIT, exit_message); + messaging_register_tmp(msg_server_ctx, NULL, ping_message, &msg_ping); + messaging_register_tmp(msg_server_ctx, mem_ctx, exit_message, &msg_exit); msg_client_ctx = messaging_init(mem_ctx, 2, ev); @@ -87,7 +89,7 @@ static BOOL test_ping_speed(TALLOC_CTX *mem_ctx) return False; } - messaging_register(msg_client_ctx, &pong_count, MY_PONG, pong_message); + messaging_register_tmp(msg_client_ctx, &pong_count, pong_message, &msg_pong); tv = timeval_current(); @@ -99,8 +101,8 @@ static BOOL test_ping_speed(TALLOC_CTX *mem_ctx) data.data = discard_const_p(uint8_t, "testing"); data.length = strlen((const char *)data.data); - status1 = messaging_send(msg_client_ctx, 1, MY_PING, &data); - status2 = messaging_send(msg_client_ctx, 1, MY_PING, NULL); + status1 = messaging_send(msg_client_ctx, 1, msg_ping, &data); + status2 = messaging_send(msg_client_ctx, 1, msg_ping, NULL); if (!NT_STATUS_IS_OK(status1)) { printf("msg1 failed - %s\n", nt_errstr(status1)); @@ -126,7 +128,7 @@ static BOOL test_ping_speed(TALLOC_CTX *mem_ctx) } printf("sending exit\n"); - messaging_send(msg_client_ctx, 1, MY_EXIT, NULL); + messaging_send(msg_client_ctx, 1, msg_exit, NULL); if (ping_count != pong_count) { printf("ping test failed! received %d, sent %d\n",