diff --git a/daemon/libvirtd.aug b/daemon/libvirtd.aug index ce00db5394..9d78bd7fe1 100644 --- a/daemon/libvirtd.aug +++ b/daemon/libvirtd.aug @@ -66,6 +66,10 @@ module Libvirtd = let auditing_entry = int_entry "audit_level" | bool_entry "audit_logging" + let keepalive_entry = int_entry "keepalive_interval" + | int_entry "keepalive_count" + | bool_entry "keepalive_required" + (* Each enty in the config is one of the following three ... *) let entry = network_entry | sock_acl_entry @@ -75,6 +79,7 @@ module Libvirtd = | processing_entry | logging_entry | auditing_entry + | keepalive_entry let comment = [ label "#comment" . del /#[ \t]*/ "# " . store /([^ \t\n][^\n]*)?/ . del /\n/ "\n" ] let empty = [ label "#empty" . eol ] diff --git a/daemon/libvirtd.c b/daemon/libvirtd.c index 5e1fc965e5..d7a03d71c5 100644 --- a/daemon/libvirtd.c +++ b/daemon/libvirtd.c @@ -146,6 +146,10 @@ struct daemonConfig { int audit_level; int audit_logging; + + int keepalive_interval; + unsigned int keepalive_count; + int keepalive_required; }; enum { @@ -899,6 +903,10 @@ daemonConfigNew(bool privileged ATTRIBUTE_UNUSED) data->audit_level = 1; data->audit_logging = 0; + data->keepalive_interval = 5; + data->keepalive_count = 5; + data->keepalive_required = 0; + localhost = virGetHostname(NULL); if (localhost == NULL) { /* we couldn't resolve the hostname; assume that we are @@ -1062,6 +1070,10 @@ daemonConfigLoad(struct daemonConfig *data, GET_CONF_STR (conf, filename, log_outputs); GET_CONF_INT (conf, filename, log_buffer_size); + GET_CONF_INT (conf, filename, keepalive_interval); + GET_CONF_INT (conf, filename, keepalive_count); + GET_CONF_INT (conf, filename, keepalive_required); + virConfFree (conf); return 0; @@ -1452,6 +1464,9 @@ int main(int argc, char **argv) { config->max_workers, config->prio_workers, config->max_clients, + config->keepalive_interval, + config->keepalive_count, + !!config->keepalive_required, config->mdns_adv ? config->mdns_name : NULL, use_polkit_dbus, remoteClientInitHook))) { diff --git a/daemon/libvirtd.conf b/daemon/libvirtd.conf index da3983ecef..f218454dc9 100644 --- a/daemon/libvirtd.conf +++ b/daemon/libvirtd.conf @@ -366,3 +366,28 @@ # it with the output of the 'uuidgen' command and then # uncomment this entry #host_uuid = "00000000-0000-0000-0000-000000000000" + +################################################################### +# Keepalive protocol: +# This allows libvirtd to detect broken client connections or even +# dead client. A keepalive message is sent to a client after +# keepalive_interval seconds of inactivity to check if the client is +# still responding; keepalive_count is a maximum number of keepalive +# messages that are allowed to be sent to the client without getting +# any response before the connection is considered broken. In other +# words, the connection is automatically closed approximately after +# keepalive_interval * (keepalive_count + 1) seconds since the last +# message received from the client. If keepalive_interval is set to +# -1, libvirtd will never send keepalive requests; however clients +# can still send them and the deamon will send responses. When +# keepalive_count is set to 0, connections will be automatically +# closed after keepalive_interval seconds of inactivity without +# sending any keepalive messages. +# +#keepalive_interval = 5 +#keepalive_count = 5 +# +# If set to 1, libvirtd will refuse to talk to clients that do not +# support keepalive protocol. Defaults to 0. +# +#keepalive_required = 1 diff --git a/daemon/libvirtd.h b/daemon/libvirtd.h index ce787a47ff..c8d3ca225c 100644 --- a/daemon/libvirtd.h +++ b/daemon/libvirtd.h @@ -61,6 +61,7 @@ struct daemonClientPrivate { virConnectPtr conn; daemonClientStreamPtr streams; + bool keepalive_supported; }; # if HAVE_SASL diff --git a/daemon/remote.c b/daemon/remote.c index 97c953823f..429979754a 100644 --- a/daemon/remote.c +++ b/daemon/remote.c @@ -581,7 +581,7 @@ int remoteClientInitHook(virNetServerPtr srv ATTRIBUTE_UNUSED, /*----- Functions. -----*/ static int -remoteDispatchOpen(virNetServerPtr server ATTRIBUTE_UNUSED, +remoteDispatchOpen(virNetServerPtr server, virNetServerClientPtr client, virNetMessagePtr msg ATTRIBUTE_UNUSED, virNetMessageErrorPtr rerr, @@ -600,6 +600,12 @@ remoteDispatchOpen(virNetServerPtr server ATTRIBUTE_UNUSED, goto cleanup; } + if (virNetServerKeepAliveRequired(server) && !priv->keepalive_supported) { + virNetError(VIR_ERR_OPERATION_FAILED, "%s", + _("keepalive support is required to connect")); + goto cleanup; + } + name = args->name ? *args->name : NULL; /* If this connection arrived on a readonly socket, force @@ -3226,6 +3232,16 @@ static int remoteDispatchSupportsFeature( struct daemonClientPrivate *priv = virNetServerClientGetPrivateData(client); + /* This feature is checked before opening the connection, thus we must + * check it first. + */ + if (args->feature == VIR_DRV_FEATURE_PROGRAM_KEEPALIVE) { + if (virNetServerClientStartKeepAlive(client) < 0) + goto cleanup; + supported = 1; + goto done; + } + if (!priv->conn) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", _("connection not open")); goto cleanup; @@ -3242,6 +3258,7 @@ static int remoteDispatchSupportsFeature( break; } +done: ret->supported = supported; rv = 0; diff --git a/src/libvirt_private.syms b/src/libvirt_private.syms index 687c3795cd..0b21cdc2b6 100644 --- a/src/libvirt_private.syms +++ b/src/libvirt_private.syms @@ -1262,6 +1262,7 @@ virNetServerAutoShutdown; virNetServerClose; virNetServerFree; virNetServerIsPrivileged; +virNetServerKeepAliveRequired; virNetServerNew; virNetServerQuit; virNetServerRef; @@ -1294,6 +1295,7 @@ virNetServerClientSendMessage; virNetServerClientSetCloseHook; virNetServerClientSetIdentity; virNetServerClientSetPrivateData; +virNetServerClientStartKeepAlive; # virnetserverprogram.h diff --git a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c index 6533b5aa74..f761e6ba9a 100644 --- a/src/rpc/virnetserver.c +++ b/src/rpc/virnetserver.c @@ -102,6 +102,10 @@ struct _virNetServer { size_t nclients_max; virNetServerClientPtr *clients; + int keepaliveInterval; + unsigned int keepaliveCount; + bool keepaliveRequired; + unsigned int quit :1; virNetTLSContextPtr tls; @@ -261,6 +265,9 @@ static int virNetServerDispatchNewClient(virNetServerServicePtr svc ATTRIBUTE_UN virNetServerDispatchNewMessage, srv); + virNetServerClientInitKeepAlive(client, srv->keepaliveInterval, + srv->keepaliveCount); + virNetServerUnlock(srv); return 0; @@ -300,6 +307,9 @@ virNetServerPtr virNetServerNew(size_t min_workers, size_t max_workers, size_t priority_workers, size_t max_clients, + int keepaliveInterval, + unsigned int keepaliveCount, + bool keepaliveRequired, const char *mdnsGroupName, bool connectDBus ATTRIBUTE_UNUSED, virNetServerClientInitHook clientInitHook) @@ -321,6 +331,9 @@ virNetServerPtr virNetServerNew(size_t min_workers, goto error; srv->nclients_max = max_clients; + srv->keepaliveInterval = keepaliveInterval; + srv->keepaliveCount = keepaliveCount; + srv->keepaliveRequired = keepaliveRequired; srv->sigwrite = srv->sigread = -1; srv->clientInitHook = clientInitHook; srv->privileged = geteuid() == 0 ? true : false; @@ -840,3 +853,12 @@ void virNetServerClose(virNetServerPtr srv) virNetServerUnlock(srv); } + +bool virNetServerKeepAliveRequired(virNetServerPtr srv) +{ + bool required; + virNetServerLock(srv); + required = srv->keepaliveRequired; + virNetServerUnlock(srv); + return required; +} diff --git a/src/rpc/virnetserver.h b/src/rpc/virnetserver.h index cc9d039adc..a04ffddab2 100644 --- a/src/rpc/virnetserver.h +++ b/src/rpc/virnetserver.h @@ -41,6 +41,9 @@ virNetServerPtr virNetServerNew(size_t min_workers, size_t max_workers, size_t priority_workers, size_t max_clients, + int keepaliveInterval, + unsigned int keepaliveCount, + bool keepaliveRequired, const char *mdnsGroupName, bool connectDBus, virNetServerClientInitHook clientInitHook); @@ -88,4 +91,6 @@ void virNetServerFree(virNetServerPtr srv); void virNetServerClose(virNetServerPtr srv); +bool virNetServerKeepAliveRequired(virNetServerPtr srv); + #endif diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c index cf97b58854..cb07dd91ed 100644 --- a/src/rpc/virnetserverclient.c +++ b/src/rpc/virnetserverclient.c @@ -33,6 +33,7 @@ #include "virterror_internal.h" #include "memory.h" #include "threads.h" +#include "virkeepalive.h" #define VIR_FROM_THIS VIR_FROM_RPC #define virNetError(code, ...) \ @@ -100,6 +101,9 @@ struct _virNetServerClient void *privateData; virNetServerClientFreeFunc privateDataFreeFunc; virNetServerClientCloseFunc privateDataCloseFunc; + + virKeepAlivePtr keepalive; + int keepaliveFilter; }; @@ -213,15 +217,15 @@ static void virNetServerClientUpdateEvent(virNetServerClientPtr client) } -int virNetServerClientAddFilter(virNetServerClientPtr client, - virNetServerClientFilterFunc func, - void *opaque) +static int +virNetServerClientAddFilterLocked(virNetServerClientPtr client, + virNetServerClientFilterFunc func, + void *opaque) { virNetServerClientFilterPtr filter; + virNetServerClientFilterPtr *place; int ret = -1; - virNetServerClientLock(client); - if (VIR_ALLOC(filter) < 0) { virReportOOMError(); goto cleanup; @@ -231,22 +235,34 @@ int virNetServerClientAddFilter(virNetServerClientPtr client, filter->func = func; filter->opaque = opaque; - filter->next = client->filters; - client->filters = filter; + place = &client->filters; + while (*place) + place = &(*place)->next; + *place = filter; ret = filter->id; cleanup: + return ret; +} + +int virNetServerClientAddFilter(virNetServerClientPtr client, + virNetServerClientFilterFunc func, + void *opaque) +{ + int ret; + + virNetServerClientLock(client); + ret = virNetServerClientAddFilterLocked(client, func, opaque); virNetServerClientUnlock(client); return ret; } - -void virNetServerClientRemoveFilter(virNetServerClientPtr client, - int filterID) +static void +virNetServerClientRemoveFilterLocked(virNetServerClientPtr client, + int filterID) { virNetServerClientFilterPtr tmp, prev; - virNetServerClientLock(client); prev = NULL; tmp = client->filters; @@ -263,7 +279,13 @@ void virNetServerClientRemoveFilter(virNetServerClientPtr client, prev = tmp; tmp = tmp->next; } +} +void virNetServerClientRemoveFilter(virNetServerClientPtr client, + int filterID) +{ + virNetServerClientLock(client); + virNetServerClientRemoveFilterLocked(client, filterID); virNetServerClientUnlock(client); } @@ -337,6 +359,7 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock, client->readonly = readonly; client->tlsCtxt = tls; client->nrequests_max = nrequests_max; + client->keepaliveFilter = -1; client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc, client, NULL); @@ -603,6 +626,7 @@ void virNetServerClientFree(virNetServerClientPtr client) void virNetServerClientClose(virNetServerClientPtr client) { virNetServerClientCloseFunc cf; + virKeepAlivePtr ka; virNetServerClientLock(client); VIR_DEBUG("client=%p refs=%d", client, client->refs); @@ -611,6 +635,20 @@ void virNetServerClientClose(virNetServerClientPtr client) return; } + if (client->keepaliveFilter >= 0) + virNetServerClientRemoveFilterLocked(client, client->keepaliveFilter); + + if (client->keepalive) { + virKeepAliveStop(client->keepalive); + ka = client->keepalive; + client->keepalive = NULL; + client->refs++; + virNetServerClientUnlock(client); + virKeepAliveFree(ka); + virNetServerClientLock(client); + client->refs--; + } + if (client->privateDataCloseFunc) { cf = client->privateDataCloseFunc; client->refs++; @@ -1066,6 +1104,7 @@ int virNetServerClientSendMessage(virNetServerClientPtr client, VIR_DEBUG("msg=%p proc=%d len=%zu offset=%zu", msg, msg->header.proc, msg->bufferLength, msg->bufferOffset); + virNetServerClientLock(client); msg->donefds = 0; @@ -1082,6 +1121,7 @@ int virNetServerClientSendMessage(virNetServerClientPtr client, } virNetServerClientUnlock(client); + return ret; } @@ -1095,3 +1135,84 @@ bool virNetServerClientNeedAuth(virNetServerClientPtr client) virNetServerClientUnlock(client); return need; } + + +static void +virNetServerClientKeepAliveDeadCB(void *opaque) +{ + virNetServerClientImmediateClose(opaque); +} + +static int +virNetServerClientKeepAliveSendCB(void *opaque, + virNetMessagePtr msg) +{ + return virNetServerClientSendMessage(opaque, msg); +} + +static void +virNetServerClientFreeCB(void *opaque) +{ + virNetServerClientFree(opaque); +} + +static int +virNetServerClientKeepAliveFilter(virNetServerClientPtr client, + virNetMessagePtr msg, + void *opaque ATTRIBUTE_UNUSED) +{ + if (virKeepAliveCheckMessage(client->keepalive, msg)) { + virNetMessageFree(msg); + client->nrequests--; + return 1; + } + + return 0; +} + +int +virNetServerClientInitKeepAlive(virNetServerClientPtr client, + int interval, + unsigned int count) +{ + virKeepAlivePtr ka; + int ret = -1; + + virNetServerClientLock(client); + + if (!(ka = virKeepAliveNew(interval, count, client, + virNetServerClientKeepAliveSendCB, + virNetServerClientKeepAliveDeadCB, + virNetServerClientFreeCB))) + goto cleanup; + /* keepalive object has a reference to client */ + client->refs++; + + client->keepaliveFilter = + virNetServerClientAddFilterLocked(client, + virNetServerClientKeepAliveFilter, + NULL); + if (client->keepaliveFilter < 0) + goto cleanup; + + client->keepalive = ka; + ka = NULL; + +cleanup: + virNetServerClientUnlock(client); + if (ka) + virKeepAliveStop(ka); + virKeepAliveFree(ka); + + return ret; +} + +int +virNetServerClientStartKeepAlive(virNetServerClientPtr client) +{ + int ret; + virNetServerClientLock(client); + ret = virKeepAliveStart(client->keepalive, 0, 0); + virNetServerClientUnlock(client); + return ret; +} diff --git a/src/rpc/virnetserverclient.h b/src/rpc/virnetserverclient.h index bedb179e74..a201dca2fe 100644 --- a/src/rpc/virnetserverclient.h +++ b/src/rpc/virnetserverclient.h @@ -99,6 +99,13 @@ bool virNetServerClientWantClose(virNetServerClientPtr client); int virNetServerClientInit(virNetServerClientPtr client); +int virNetServerClientInitKeepAlive(virNetServerClientPtr client, + int interval, + unsigned int count); +bool virNetServerClientCheckKeepAlive(virNetServerClientPtr client, + virNetMessagePtr msg); +int virNetServerClientStartKeepAlive(virNetServerClientPtr client); + const char *virNetServerClientLocalAddrString(virNetServerClientPtr client); const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client);