1
1
mirror of https://github.com/systemd/systemd-stable.git synced 2024-12-23 17:34:00 +03:00

resolve: call dns_stream_take_read_packet() in on_stream_io()

As dns_stream_take_read_packet() is called only in on_packet callbacks,
and all on_packet callbacks call it.

(cherry picked from commit 624f907ea9)
This commit is contained in:
Yu Watanabe 2022-01-28 08:57:05 +09:00
parent fe4c208c98
commit b2f82f643a
6 changed files with 31 additions and 39 deletions

View File

@ -281,6 +281,22 @@ static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) {
return dns_stream_complete(s, ETIMEDOUT); return dns_stream_complete(s, ETIMEDOUT);
} }
static DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
assert(s);
if (!s->read_packet)
return NULL;
if (s->n_read < sizeof(s->read_size))
return NULL;
if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size))
return NULL;
s->n_read = 0;
return TAKE_PTR(s->read_packet);
}
static int on_stream_io_impl(DnsStream *s, uint32_t revents) { static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
bool progressed = false; bool progressed = false;
int r; int r;
@ -413,9 +429,10 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
/* Are we done? If so, call the packet handler and re-enable EPOLLIN for the /* Are we done? If so, call the packet handler and re-enable EPOLLIN for the
* event source if necessary. */ * event source if necessary. */
if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) { _cleanup_(dns_packet_unrefp) DnsPacket *p = dns_stream_take_read_packet(s);
if (p) {
assert(s->on_packet); assert(s->on_packet);
r = s->on_packet(s); r = s->on_packet(s, p);
if (r < 0) if (r < 0)
return r; return r;
@ -520,7 +537,7 @@ int dns_stream_new(
DnsProtocol protocol, DnsProtocol protocol,
int fd, int fd,
const union sockaddr_union *tfo_address, const union sockaddr_union *tfo_address,
int (on_packet)(DnsStream*), int (on_packet)(DnsStream*, DnsPacket*),
int (complete)(DnsStream*, int), /* optional */ int (complete)(DnsStream*, int), /* optional */
usec_t connect_timeout_usec) { usec_t connect_timeout_usec) {
@ -604,22 +621,6 @@ int dns_stream_write_packet(DnsStream *s, DnsPacket *p) {
return dns_stream_update_io(s); return dns_stream_update_io(s);
} }
DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
assert(s);
if (!s->read_packet)
return NULL;
if (s->n_read < sizeof(s->read_size))
return NULL;
if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size))
return NULL;
s->n_read = 0;
return TAKE_PTR(s->read_packet);
}
void dns_stream_detach(DnsStream *s) { void dns_stream_detach(DnsStream *s) {
assert(s); assert(s);

View File

@ -78,7 +78,7 @@ struct DnsStream {
size_t n_written, n_read; size_t n_written, n_read;
OrderedSet *write_queue; OrderedSet *write_queue;
int (*on_packet)(DnsStream *s); int (*on_packet)(DnsStream *s, DnsPacket *p);
int (*complete)(DnsStream *s, int error); int (*complete)(DnsStream *s, int error);
LIST_HEAD(DnsTransaction, transactions); /* when used by the transaction logic */ LIST_HEAD(DnsTransaction, transactions); /* when used by the transaction logic */
@ -100,7 +100,7 @@ int dns_stream_new(
DnsProtocol protocol, DnsProtocol protocol,
int fd, int fd,
const union sockaddr_union *tfo_address, const union sockaddr_union *tfo_address,
int (on_packet)(DnsStream*), int (on_packet)(DnsStream*, DnsPacket*),
int (complete)(DnsStream*, int), /* optional */ int (complete)(DnsStream*, int), /* optional */
usec_t connect_timeout_usec); usec_t connect_timeout_usec);
#if ENABLE_DNS_OVER_TLS #if ENABLE_DNS_OVER_TLS
@ -123,6 +123,4 @@ static inline bool DNS_STREAM_QUEUED(DnsStream *s) {
return !!s->write_packet; return !!s->write_packet;
} }
DnsPacket *dns_stream_take_read_packet(DnsStream *s);
void dns_stream_detach(DnsStream *s); void dns_stream_detach(DnsStream *s);

View File

@ -1037,12 +1037,9 @@ static int on_dns_stub_packet_extra(sd_event_source *s, int fd, uint32_t revents
return on_dns_stub_packet_internal(s, fd, revents, l->manager, l); return on_dns_stub_packet_internal(s, fd, revents, l->manager, l);
} }
static int on_dns_stub_stream_packet(DnsStream *s) { static int on_dns_stub_stream_packet(DnsStream *s, DnsPacket *p) {
_cleanup_(dns_packet_unrefp) DnsPacket *p = NULL;
assert(s); assert(s);
assert(s->manager);
p = dns_stream_take_read_packet(s);
assert(p); assert(p);
if (dns_packet_validate_query(p) > 0) { if (dns_packet_validate_query(p) > 0) {

View File

@ -644,14 +644,12 @@ static int on_stream_complete(DnsStream *s, int error) {
return 0; return 0;
} }
static int on_stream_packet(DnsStream *s) { static int on_stream_packet(DnsStream *s, DnsPacket *p) {
_cleanup_(dns_packet_unrefp) DnsPacket *p = NULL;
DnsTransaction *t; DnsTransaction *t;
assert(s); assert(s);
assert(s->manager);
/* Take ownership of packet to be able to receive new packets */ assert(p);
assert_se(p = dns_stream_take_read_packet(s));
t = hashmap_get(s->manager->dns_transactions, UINT_TO_PTR(DNS_PACKET_ID(p))); t = hashmap_get(s->manager->dns_transactions, UINT_TO_PTR(DNS_PACKET_ID(p)));
if (t && t->stream == s) /* Validate that the stream we got this on actually is the stream the if (t && t->stream == s) /* Validate that the stream we got this on actually is the stream the

View File

@ -277,13 +277,11 @@ int manager_llmnr_ipv6_udp_fd(Manager *m) {
return m->llmnr_ipv6_udp_fd = TAKE_FD(s); return m->llmnr_ipv6_udp_fd = TAKE_FD(s);
} }
static int on_llmnr_stream_packet(DnsStream *s) { static int on_llmnr_stream_packet(DnsStream *s, DnsPacket *p) {
_cleanup_(dns_packet_unrefp) DnsPacket *p = NULL;
DnsScope *scope; DnsScope *scope;
assert(s); assert(s);
assert(s->manager);
p = dns_stream_take_read_packet(s);
assert(p); assert(p);
scope = manager_find_scope(s->manager, p); scope = manager_find_scope(s->manager, p);

View File

@ -194,9 +194,9 @@ static const size_t MAX_RECEIVED_PACKETS = 2;
static DnsPacket *received_packets[2] = {}; static DnsPacket *received_packets[2] = {};
static size_t n_received_packets = 0; static size_t n_received_packets = 0;
static int on_stream_packet(DnsStream *stream) { static int on_stream_packet(DnsStream *stream, DnsPacket *p) {
assert_se(n_received_packets < MAX_RECEIVED_PACKETS); assert_se(n_received_packets < MAX_RECEIVED_PACKETS);
assert_se(received_packets[n_received_packets++] = dns_stream_take_read_packet(stream)); assert_se(received_packets[n_received_packets++] = dns_packet_ref(p));
return 0; return 0;
} }