diff --git a/src/resolve/resolved-dns-stream.c b/src/resolve/resolved-dns-stream.c index faf5e26ba46..8c6f217ad96 100644 --- a/src/resolve/resolved-dns-stream.c +++ b/src/resolve/resolved-dns-stream.c @@ -280,7 +280,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use #if ENABLE_DNS_OVER_TLS if (s->encrypted) { - r = dnstls_stream_on_io(s); + r = dnstls_stream_on_io(s, revents); if (r == DNSTLS_STREAM_CLOSED) return 0; diff --git a/src/resolve/resolved-dnstls-gnutls.c b/src/resolve/resolved-dnstls-gnutls.c index 5e6a899db81..820e1926fd5 100644 --- a/src/resolve/resolved-dnstls-gnutls.c +++ b/src/resolve/resolved-dnstls-gnutls.c @@ -77,7 +77,7 @@ void dnstls_stream_free(DnsStream *stream) { gnutls_deinit(stream->dnstls_data.session); } -int dnstls_stream_on_io(DnsStream *stream) { +int dnstls_stream_on_io(DnsStream *stream, uint32_t revents) { int r; assert(stream); diff --git a/src/resolve/resolved-dnstls-openssl.c b/src/resolve/resolved-dnstls-openssl.c index d0a1bba7731..5dd77373370 100644 --- a/src/resolve/resolved-dnstls-openssl.c +++ b/src/resolve/resolved-dnstls-openssl.c @@ -13,31 +13,84 @@ DEFINE_TRIVIAL_CLEANUP_FUNC(SSL*, SSL_free); DEFINE_TRIVIAL_CLEANUP_FUNC(BIO*, BIO_free); +static int dnstls_flush_write_buffer(DnsStream *stream) { + ssize_t ss; + + assert(stream); + assert(stream->encrypted); + + if (stream->dnstls_data.write_buffer->length > 0) { + assert(stream->dnstls_data.write_buffer->data); + + struct iovec iov[1]; + iov[0].iov_base = stream->dnstls_data.write_buffer->data; + iov[0].iov_len = stream->dnstls_data.write_buffer->length; + ss = dns_stream_writev(stream, iov, 1, DNS_STREAM_WRITE_TLS_DATA); + if (ss < 0) { + if (ss == -EAGAIN) + stream->dnstls_events |= EPOLLOUT; + + return ss; + } else { + stream->dnstls_data.write_buffer->length -= ss; + stream->dnstls_data.write_buffer->data += ss; + + if (stream->dnstls_data.write_buffer->length > 0) { + stream->dnstls_events |= EPOLLOUT; + return -EAGAIN; + } + } + } + + return 0; +} + int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server) { _cleanup_(SSL_freep) SSL *s = NULL; - _cleanup_(BIO_freep) BIO *b = NULL; + _cleanup_(BIO_freep) BIO *rb = NULL; + _cleanup_(BIO_freep) BIO *wb = NULL; + int r; + int error; assert(stream); assert(server); - b = BIO_new_socket(stream->fd, 0); - if (!b) + rb = BIO_new_socket(stream->fd, 0); + if (!rb) return -ENOMEM; + wb = BIO_new(BIO_s_mem()); + if (!wb) + return -ENOMEM; + + BIO_get_mem_ptr(wb, &stream->dnstls_data.write_buffer); + s = SSL_new(server->dnstls_data.ctx); if (!s) return -ENOMEM; SSL_set_connect_state(s); - SSL_set_bio(s, b, b); - b = NULL; + SSL_set_session(s, server->dnstls_data.session); + SSL_set_bio(s, TAKE_PTR(rb), TAKE_PTR(wb)); - /* DNS-over-TLS using OpenSSL doesn't support TCP Fast Open yet */ - connect(stream->fd, &stream->tfo_address.sa, stream->tfo_salen); - stream->tfo_salen = 0; + stream->dnstls_data.handshake = SSL_do_handshake(s); + if (stream->dnstls_data.handshake <= 0) { + error = SSL_get_error(s, stream->dnstls_data.handshake); + if (!IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) { + char errbuf[256]; + + ERR_error_string_n(error, errbuf, sizeof(errbuf)); + log_debug("Failed to invoke SSL_do_handshake: %s", errbuf); + return -ECONNREFUSED; + } + } stream->encrypted = true; - stream->dnstls_events = EPOLLOUT; + + r = dnstls_flush_write_buffer(stream); + if (r < 0 && r != -EAGAIN) + return r; + stream->dnstls_data.ssl = TAKE_PTR(s); return 0; @@ -51,7 +104,7 @@ void dnstls_stream_free(DnsStream *stream) { SSL_free(stream->dnstls_data.ssl); } -int dnstls_stream_on_io(DnsStream *stream) { +int dnstls_stream_on_io(DnsStream *stream, uint32_t revents) { int r; int error; @@ -59,14 +112,25 @@ int dnstls_stream_on_io(DnsStream *stream) { assert(stream->encrypted); assert(stream->dnstls_data.ssl); + /* Flush write buffer when requested by OpenSSL ss*/ + if ((revents & EPOLLOUT) && (stream->dnstls_events & EPOLLOUT)) { + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; + } + if (stream->dnstls_data.shutdown) { r = SSL_shutdown(stream->dnstls_data.ssl); - if (r == 0) - return -EAGAIN; - else if (r < 0) { + if (r <= 0) { error = SSL_get_error(stream->dnstls_data.ssl, r); - if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) { - stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT; + if (r == 0 || IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) { + if (r < 0) + stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT; + + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; + return -EAGAIN; } else { char errbuf[256]; @@ -76,6 +140,10 @@ int dnstls_stream_on_io(DnsStream *stream) { } } + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; + stream->dnstls_events = 0; stream->dnstls_data.shutdown = false; dns_stream_unref(stream); @@ -86,6 +154,10 @@ int dnstls_stream_on_io(DnsStream *stream) { error = SSL_get_error(stream->dnstls_data.ssl, stream->dnstls_data.handshake); if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) { stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT; + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; + return -EAGAIN; } else { char errbuf[256]; @@ -97,6 +169,9 @@ int dnstls_stream_on_io(DnsStream *stream) { } stream->dnstls_events = 0; + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; } return 0; @@ -111,6 +186,16 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) { assert(stream->encrypted); assert(stream->dnstls_data.ssl); + if (stream->server) { + s = SSL_get1_session(stream->dnstls_data.ssl); + if (s) { + if (stream->server->dnstls_data.session) + SSL_SESSION_free(stream->server->dnstls_data.session); + + stream->server->dnstls_data.session = s; + } + } + if (error == ETIMEDOUT) { r = SSL_shutdown(stream->dnstls_data.ssl); if (r == 0) { @@ -118,11 +203,20 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) { stream->dnstls_data.shutdown = true; dns_stream_ref(stream); } + + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; + return -EAGAIN; } else if (r < 0) { ssl_error = SSL_get_error(stream->dnstls_data.ssl, r); if (IN_SET(ssl_error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) { stream->dnstls_events = ssl_error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT; + r = dnstls_flush_write_buffer(stream); + if (r < 0 && r != -EAGAIN) + return r; + if (!stream->dnstls_data.shutdown) { stream->dnstls_data.shutdown = true; dns_stream_ref(stream); @@ -135,6 +229,11 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) { log_debug("Failed to invoke SSL_shutdown: %s", errbuf); } } + + stream->dnstls_events = 0; + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; } return 0; @@ -155,6 +254,10 @@ ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) { error = SSL_get_error(stream->dnstls_data.ssl, ss); if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) { stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT; + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; + ss = -EAGAIN; } else { char errbuf[256]; @@ -166,6 +269,10 @@ ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) { } stream->dnstls_events = 0; + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; + return ss; } @@ -184,6 +291,12 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) { error = SSL_get_error(stream->dnstls_data.ssl, ss); if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) { stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT; + + /* flush write buffer in cache of renegotiation */ + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; + ss = -EAGAIN; } else { char errbuf[256]; @@ -195,6 +308,12 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) { } stream->dnstls_events = 0; + + /* flush write buffer in cache of renegotiation */ + r = dnstls_flush_write_buffer(stream); + if (r < 0) + return r; + return ss; } @@ -213,4 +332,7 @@ void dnstls_server_free(DnsServer *server) { if (server->dnstls_data.ctx) SSL_CTX_free(server->dnstls_data.ctx); + + if (server->dnstls_data.session) + SSL_SESSION_free(server->dnstls_data.session); } diff --git a/src/resolve/resolved-dnstls-openssl.h b/src/resolve/resolved-dnstls-openssl.h index c92d2b2354a..c57bc1c57c9 100644 --- a/src/resolve/resolved-dnstls-openssl.h +++ b/src/resolve/resolved-dnstls-openssl.h @@ -11,10 +11,12 @@ struct DnsTlsServerData { SSL_CTX *ctx; + SSL_SESSION *session; }; struct DnsTlsStreamData { int handshake; bool shutdown; SSL *ssl; + BUF_MEM *write_buffer; }; diff --git a/src/resolve/resolved-dnstls.h b/src/resolve/resolved-dnstls.h index 52af3e9801c..fdd85eece6a 100644 --- a/src/resolve/resolved-dnstls.h +++ b/src/resolve/resolved-dnstls.h @@ -23,7 +23,7 @@ typedef struct DnsTlsStreamData DnsTlsStreamData; int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server); void dnstls_stream_free(DnsStream *stream); -int dnstls_stream_on_io(DnsStream *stream); +int dnstls_stream_on_io(DnsStream *stream, uint32_t revents); int dnstls_stream_shutdown(DnsStream *stream, int error); ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count); ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count);