From b2f82f643a9c9609058ed877b3d722b3822d484c Mon Sep 17 00:00:00 2001 From: Yu Watanabe Date: Fri, 28 Jan 2022 08:57:05 +0900 Subject: [PATCH] resolve: call dns_stream_take_read_packet() in on_stream_io() As dns_stream_take_read_packet() is called only in on_packet callbacks, and all on_packet callbacks call it. (cherry picked from commit 624f907ea9a42930bffb343dd44fbb0e34746cb0) --- src/resolve/resolved-dns-stream.c | 39 +++++++++++++------------- src/resolve/resolved-dns-stream.h | 6 ++-- src/resolve/resolved-dns-stub.c | 7 ++--- src/resolve/resolved-dns-transaction.c | 8 ++---- src/resolve/resolved-llmnr.c | 6 ++-- src/resolve/test-resolved-stream.c | 4 +-- 6 files changed, 31 insertions(+), 39 deletions(-) diff --git a/src/resolve/resolved-dns-stream.c b/src/resolve/resolved-dns-stream.c index bdf46170d1..1b2db51212 100644 --- a/src/resolve/resolved-dns-stream.c +++ b/src/resolve/resolved-dns-stream.c @@ -281,6 +281,22 @@ static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) { return dns_stream_complete(s, ETIMEDOUT); } +static DnsPacket *dns_stream_take_read_packet(DnsStream *s) { + assert(s); + + if (!s->read_packet) + return NULL; + + if (s->n_read < sizeof(s->read_size)) + return NULL; + + if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) + return NULL; + + s->n_read = 0; + return TAKE_PTR(s->read_packet); +} + static int on_stream_io_impl(DnsStream *s, uint32_t revents) { bool progressed = false; int r; @@ -413,9 +429,10 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) { /* Are we done? If so, call the packet handler and re-enable EPOLLIN for the * event source if necessary. */ - if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) { + _cleanup_(dns_packet_unrefp) DnsPacket *p = dns_stream_take_read_packet(s); + if (p) { assert(s->on_packet); - r = s->on_packet(s); + r = s->on_packet(s, p); if (r < 0) return r; @@ -520,7 +537,7 @@ int dns_stream_new( DnsProtocol protocol, int fd, const union sockaddr_union *tfo_address, - int (on_packet)(DnsStream*), + int (on_packet)(DnsStream*, DnsPacket*), int (complete)(DnsStream*, int), /* optional */ usec_t connect_timeout_usec) { @@ -604,22 +621,6 @@ int dns_stream_write_packet(DnsStream *s, DnsPacket *p) { return dns_stream_update_io(s); } -DnsPacket *dns_stream_take_read_packet(DnsStream *s) { - assert(s); - - if (!s->read_packet) - return NULL; - - if (s->n_read < sizeof(s->read_size)) - return NULL; - - if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) - return NULL; - - s->n_read = 0; - return TAKE_PTR(s->read_packet); -} - void dns_stream_detach(DnsStream *s) { assert(s); diff --git a/src/resolve/resolved-dns-stream.h b/src/resolve/resolved-dns-stream.h index 548b2edc9e..fedbab2da2 100644 --- a/src/resolve/resolved-dns-stream.h +++ b/src/resolve/resolved-dns-stream.h @@ -78,7 +78,7 @@ struct DnsStream { size_t n_written, n_read; OrderedSet *write_queue; - int (*on_packet)(DnsStream *s); + int (*on_packet)(DnsStream *s, DnsPacket *p); int (*complete)(DnsStream *s, int error); LIST_HEAD(DnsTransaction, transactions); /* when used by the transaction logic */ @@ -100,7 +100,7 @@ int dns_stream_new( DnsProtocol protocol, int fd, const union sockaddr_union *tfo_address, - int (on_packet)(DnsStream*), + int (on_packet)(DnsStream*, DnsPacket*), int (complete)(DnsStream*, int), /* optional */ usec_t connect_timeout_usec); #if ENABLE_DNS_OVER_TLS @@ -123,6 +123,4 @@ static inline bool DNS_STREAM_QUEUED(DnsStream *s) { return !!s->write_packet; } -DnsPacket *dns_stream_take_read_packet(DnsStream *s); - void dns_stream_detach(DnsStream *s); diff --git a/src/resolve/resolved-dns-stub.c b/src/resolve/resolved-dns-stub.c index 73fce6798e..9e34161eb3 100644 --- a/src/resolve/resolved-dns-stub.c +++ b/src/resolve/resolved-dns-stub.c @@ -1037,12 +1037,9 @@ static int on_dns_stub_packet_extra(sd_event_source *s, int fd, uint32_t revents return on_dns_stub_packet_internal(s, fd, revents, l->manager, l); } -static int on_dns_stub_stream_packet(DnsStream *s) { - _cleanup_(dns_packet_unrefp) DnsPacket *p = NULL; - +static int on_dns_stub_stream_packet(DnsStream *s, DnsPacket *p) { assert(s); - - p = dns_stream_take_read_packet(s); + assert(s->manager); assert(p); if (dns_packet_validate_query(p) > 0) { diff --git a/src/resolve/resolved-dns-transaction.c b/src/resolve/resolved-dns-transaction.c index 20d257bbf3..f937f9f7b5 100644 --- a/src/resolve/resolved-dns-transaction.c +++ b/src/resolve/resolved-dns-transaction.c @@ -644,14 +644,12 @@ static int on_stream_complete(DnsStream *s, int error) { return 0; } -static int on_stream_packet(DnsStream *s) { - _cleanup_(dns_packet_unrefp) DnsPacket *p = NULL; +static int on_stream_packet(DnsStream *s, DnsPacket *p) { DnsTransaction *t; assert(s); - - /* Take ownership of packet to be able to receive new packets */ - assert_se(p = dns_stream_take_read_packet(s)); + assert(s->manager); + assert(p); t = hashmap_get(s->manager->dns_transactions, UINT_TO_PTR(DNS_PACKET_ID(p))); if (t && t->stream == s) /* Validate that the stream we got this on actually is the stream the diff --git a/src/resolve/resolved-llmnr.c b/src/resolve/resolved-llmnr.c index 150cbab186..b4e551c219 100644 --- a/src/resolve/resolved-llmnr.c +++ b/src/resolve/resolved-llmnr.c @@ -277,13 +277,11 @@ int manager_llmnr_ipv6_udp_fd(Manager *m) { return m->llmnr_ipv6_udp_fd = TAKE_FD(s); } -static int on_llmnr_stream_packet(DnsStream *s) { - _cleanup_(dns_packet_unrefp) DnsPacket *p = NULL; +static int on_llmnr_stream_packet(DnsStream *s, DnsPacket *p) { DnsScope *scope; assert(s); - - p = dns_stream_take_read_packet(s); + assert(s->manager); assert(p); scope = manager_find_scope(s->manager, p); diff --git a/src/resolve/test-resolved-stream.c b/src/resolve/test-resolved-stream.c index 76467629fb..8a01460a0e 100644 --- a/src/resolve/test-resolved-stream.c +++ b/src/resolve/test-resolved-stream.c @@ -194,9 +194,9 @@ static const size_t MAX_RECEIVED_PACKETS = 2; static DnsPacket *received_packets[2] = {}; static size_t n_received_packets = 0; -static int on_stream_packet(DnsStream *stream) { +static int on_stream_packet(DnsStream *stream, DnsPacket *p) { assert_se(n_received_packets < MAX_RECEIVED_PACKETS); - assert_se(received_packets[n_received_packets++] = dns_stream_take_read_packet(stream)); + assert_se(received_packets[n_received_packets++] = dns_packet_ref(p)); return 0; }