diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index a8c0296317..055fca268b 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -59,6 +59,19 @@ struct _virNetSocket { virSocketAddr remoteAddr; char *localAddrStr; char *remoteAddrStr; + + virNetTLSSessionPtr tlsSession; +#if HAVE_SASL + virNetSASLSessionPtr saslSession; + + const char *saslDecoded; + size_t saslDecodedLength; + size_t saslDecodedOffset; + + const char *saslEncoded; + size_t saslEncodedLength; + size_t saslEncodedOffset; +#endif }; @@ -417,7 +430,7 @@ error: } -#if HAVE_SYS_UN_H +#ifdef HAVE_SYS_UN_H int virNetSocketNewConnectUNIX(const char *path, bool spawnDaemon, const char *binary, @@ -624,6 +637,14 @@ void virNetSocketFree(virNetSocketPtr sock) unlink(sock->localAddr.data.un.sun_path); #endif + /* Make sure it can't send any more I/O during shutdown */ + if (sock->tlsSession) + virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL); + virNetTLSSessionFree(sock->tlsSession); +#if HAVE_SASL + virNetSASLSessionFree(sock->saslSession); +#endif + VIR_FORCE_CLOSE(sock->fd); VIR_FORCE_CLOSE(sock->errfd); @@ -709,17 +730,77 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock) return sock->remoteAddrStr; } -ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) + +static ssize_t virNetSocketTLSSessionWrite(const char *buf, + size_t len, + void *opaque) +{ + virNetSocketPtr sock = opaque; + return write(sock->fd, buf, len); +} + + +static ssize_t virNetSocketTLSSessionRead(char *buf, + size_t len, + void *opaque) +{ + virNetSocketPtr sock = opaque; + return read(sock->fd, buf, len); +} + + +void virNetSocketSetTLSSession(virNetSocketPtr sock, + virNetTLSSessionPtr sess) +{ + virNetTLSSessionFree(sock->tlsSession); + sock->tlsSession = sess; + virNetTLSSessionSetIOCallbacks(sess, + virNetSocketTLSSessionWrite, + virNetSocketTLSSessionRead, + sock); + virNetTLSSessionRef(sess); +} + + +#if HAVE_SASL +void virNetSocketSetSASLSession(virNetSocketPtr sock, + virNetSASLSessionPtr sess) +{ + virNetSASLSessionFree(sock->saslSession); + sock->saslSession = sess; + virNetSASLSessionRef(sess); +} +#endif + + +bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED) +{ +#if HAVE_SASL + if (sock->saslDecoded) + return true; +#endif + return false; +} + + +static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len) { char *errout = NULL; ssize_t ret; reread: - ret = read(sock->fd, buf, len); + if (sock->tlsSession && + virNetTLSSessionGetHandshakeStatus(sock->tlsSession) == + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + ret = virNetTLSSessionRead(sock->tlsSession, buf, len); + } else { + ret = read(sock->fd, buf, len); + } if ((ret < 0) && (errno == EINTR)) goto reread; if ((ret < 0) && (errno == EAGAIN)) return 0; + if (ret <= 0 && sock->errfd != -1 && virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 && @@ -751,11 +832,17 @@ reread: return ret; } -ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) +static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len) { ssize_t ret; rewrite: - ret = write(sock->fd, buf, len); + if (sock->tlsSession && + virNetTLSSessionGetHandshakeStatus(sock->tlsSession) == + VIR_NET_TLS_HANDSHAKE_COMPLETE) { + ret = virNetTLSSessionWrite(sock->tlsSession, buf, len); + } else { + ret = write(sock->fd, buf, len); + } if (ret < 0) { if (errno == EINTR) @@ -777,6 +864,127 @@ rewrite: } +#if HAVE_SASL +static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len) +{ + ssize_t got; + + /* Need to read some more data off the wire */ + if (sock->saslDecoded == NULL) { + ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession); + char *encoded; + if (VIR_ALLOC_N(encoded, encodedLen) < 0) { + virReportOOMError(); + return -1; + } + encodedLen = virNetSocketReadWire(sock, encoded, encodedLen); + + if (encodedLen <= 0) { + VIR_FREE(encoded); + return encodedLen; + } + + if (virNetSASLSessionDecode(sock->saslSession, + encoded, encodedLen, + &sock->saslDecoded, &sock->saslDecodedLength) < 0) { + VIR_FREE(encoded); + return -1; + } + VIR_FREE(encoded); + + sock->saslDecodedOffset = 0; + } + + /* Some buffered decoded data to return now */ + got = sock->saslDecodedLength - sock->saslDecodedOffset; + + if (len > got) + len = got; + + memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len); + sock->saslDecodedOffset += len; + + if (sock->saslDecodedOffset == sock->saslDecodedLength) { + sock->saslDecoded = NULL; + sock->saslDecodedOffset = sock->saslDecodedLength = 0; + } + + return len; +} + + +static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len) +{ + int ret; + size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession); + + /* SASL doesn't necessarily let us send the whole + buffer at once */ + if (tosend > len) + tosend = len; + + /* Not got any pending encoded data, so we need to encode raw stuff */ + if (sock->saslEncoded == NULL) { + if (virNetSASLSessionEncode(sock->saslSession, + buf, tosend, + &sock->saslEncoded, + &sock->saslEncodedLength) < 0) + return -1; + + sock->saslEncodedOffset = 0; + } + + /* Send some of the encoded stuff out on the wire */ + ret = virNetSocketWriteWire(sock, + sock->saslEncoded + sock->saslEncodedOffset, + sock->saslEncodedLength - sock->saslEncodedOffset); + + if (ret <= 0) + return ret; /* -1 error, 0 == egain */ + + /* Note how much we sent */ + sock->saslEncodedOffset += ret; + + /* Sent all encoded, so update raw buffer to indicate completion */ + if (sock->saslEncodedOffset == sock->saslEncodedLength) { + sock->saslEncoded = NULL; + sock->saslEncodedOffset = sock->saslEncodedLength = 0; + + /* Mark as complete, so caller detects completion */ + return tosend; + } else { + /* Still have stuff pending in saslEncoded buffer. + * Pretend to caller that we didn't send any yet. + * The caller will then retry with same buffer + * shortly, which lets us finish saslEncoded. + */ + return 0; + } +} +#endif + + +ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len) +{ +#if HAVE_SASL + if (sock->saslSession) + return virNetSocketReadSASL(sock, buf, len); + else +#endif + return virNetSocketReadWire(sock, buf, len); +} + +ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len) +{ +#if HAVE_SASL + if (sock->saslSession) + return virNetSocketWriteSASL(sock, buf, len); + else +#endif + return virNetSocketWriteWire(sock, buf, len); +} + + int virNetSocketListen(virNetSocketPtr sock) { if (listen(sock->fd, 30) < 0) { diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h index 218fe8f16f..59ff28824f 100644 --- a/src/rpc/virnetsocket.h +++ b/src/rpc/virnetsocket.h @@ -26,6 +26,10 @@ # include "network.h" # include "command.h" +# include "virnettlscontext.h" +# ifdef HAVE_SASL +# include "virnetsaslcontext.h" +# endif typedef struct _virNetSocket virNetSocket; typedef virNetSocket *virNetSocketPtr; @@ -83,6 +87,13 @@ int virNetSocketSetBlocking(virNetSocketPtr sock, ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len); ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len); +void virNetSocketSetTLSSession(virNetSocketPtr sock, + virNetTLSSessionPtr sess); +# ifdef HAVE_SASL +void virNetSocketSetSASLSession(virNetSocketPtr sock, + virNetSASLSessionPtr sess); +# endif +bool virNetSocketHasCachedData(virNetSocketPtr sock); void virNetSocketFree(virNetSocketPtr sock); const char *virNetSocketLocalAddrString(virNetSocketPtr sock);