Remove query hashing in UDP DNS code path.

Now we're always saving the query, this is no longer necessary,
and allows the removal of a lot of quite hairy code.

Much more code removal to come, once the TCP code path is also purged.
This commit is contained in:
Simon Kelley
2024-11-02 15:33:37 +00:00
parent 5d6399b71c
commit ed6d29a784
2 changed files with 91 additions and 119 deletions

View File

@@ -774,9 +774,8 @@ struct dyndir {
#define FREC_DO_QUESTION 64
#define FREC_ADDED_PHEADER 128
#define FREC_TEST_PKTSZ 256
#define FREC_HAS_EXTRADATA 512
#define FREC_HAS_PHEADER 1024
#define FREC_GONE_TO_TCP 2048
#define FREC_HAS_PHEADER 512
#define FREC_GONE_TO_TCP 1024
#define HASH_SIZE 32 /* SHA-256 digest size */
@@ -796,8 +795,7 @@ struct frec {
time_t time;
u32 forward_timestamp;
int forward_delay;
unsigned char *hash[HASH_SIZE];
struct blockdata *stash; /* Saved reply, whilst we validate */
struct blockdata *stash; /* saved query or saved reply, whilst we validate */
size_t stash_len;
#ifdef HAVE_DNSSEC
int uid, class, work_counter, validate_counter;

View File

@@ -17,10 +17,8 @@
#include "dnsmasq.h"
static struct frec *get_new_frec(time_t now, struct server *serv, int force);
static struct frec *lookup_frec(unsigned short id, int fd, void *hash, int *firstp, int *lastp);
static struct frec *lookup_frec_by_query(void *hash, unsigned int flags, unsigned int flagmask);
static struct frec *lookup_frec(char *target, int class, int rrtype, int id, int flags, int flagmask);
#ifdef HAVE_DNSSEC
static struct frec *lookup_frec_dnssec(char *target, int class, int flags, struct dns_header *header);
static int tcp_key_recurse(time_t now, int status, struct dns_header *header, size_t n,
int class, char *name, char *keyname, struct server *server,
int have_mark, unsigned int mark, int *keycount, int *validatecount);
@@ -175,20 +173,22 @@ static int forward_query(int udpfd, union mysockaddr *udpaddr,
unsigned int fwd_flags = 0;
int is_dnssec = forward && (forward->flags & (FREC_DNSKEY_QUERY | FREC_DS_QUERY));
struct server *master;
void *hash = hash_questions(header, plen, daemon->namebuff);
unsigned int gotname = extract_request(header, plen, daemon->namebuff, NULL);
unsigned char *oph = find_pseudoheader(header, plen, NULL, NULL, NULL, NULL);
unsigned int gotname;
int old_src = 0, old_reply = 0;
int first, last, start = 0;
int cacheable, forwarded = 0;
size_t edns0_len;
unsigned char *pheader;
unsigned char *pheader, *oph;
int ede = EDE_UNSET;
(void)do_bit;
unsigned short rrtype;
if (saved_question)
blockdata_retrieve(saved_question, plen, header);
gotname = extract_request(header, plen, daemon->namebuff, &rrtype);
oph = find_pseudoheader(header, plen, NULL, NULL, NULL, NULL);
if (header->hb4 & HB4_CD)
fwd_flags |= FREC_CHECKING_DISABLED;
if (ad_reqd)
@@ -211,9 +211,9 @@ static int forward_query(int udpfd, union mysockaddr *udpaddr,
old_src = 1;
old_reply = 1;
}
else if ((forward = lookup_frec_by_query(hash, fwd_flags,
FREC_CHECKING_DISABLED | FREC_AD_QUESTION | FREC_DO_QUESTION |
FREC_HAS_PHEADER | FREC_DNSKEY_QUERY | FREC_DS_QUERY | FREC_NO_CACHE)))
else if (gotname && (forward = lookup_frec(daemon->namebuff, C_IN, (int)rrtype, fwd_flags, -1,
FREC_CHECKING_DISABLED | FREC_AD_QUESTION | FREC_DO_QUESTION |
FREC_HAS_PHEADER | FREC_DNSKEY_QUERY | FREC_DS_QUERY | FREC_NO_CACHE)))
{
struct frec_src *src;
@@ -272,7 +272,10 @@ static int forward_query(int udpfd, union mysockaddr *udpaddr,
it's safe to wait for the reply from the first without
forwarding the second. */
if (difftime(now, forward->time) < 2)
return 0;
{
blockdata_free(saved_question);
return 0;
}
}
}
@@ -280,8 +283,8 @@ static int forward_query(int udpfd, union mysockaddr *udpaddr,
if (!forward)
{
/* If the query is malformed, we can't forward it because
we can't get a reliable hash to recognise the answer. */
if (!hash)
we can't recognise the answer. */
if (!gotname)
{
flags = 0;
ede = EDE_INVALID_DATA;
@@ -351,8 +354,8 @@ static int forward_query(int udpfd, union mysockaddr *udpaddr,
}
else
forward->stash = blockdata_alloc((char *)header, plen);
forward->stash_len = plen;
forward->stash_len = plen;
forward->frec_src.log_id = daemon->log_id;
forward->frec_src.source = *udpaddr;
forward->frec_src.orig_id = ntohs(header->id);
@@ -362,7 +365,6 @@ static int forward_query(int udpfd, union mysockaddr *udpaddr,
forward->frec_src.fd = udpfd;
forward->new_id = get_id();
header->id = htons(forward->new_id);
memcpy(forward->hash, hash, HASH_SIZE);
forward->forwardall = 0;
forward->flags = fwd_flags;
if (domain_no_rebind(daemon->namebuff))
@@ -623,7 +625,7 @@ int fast_retry(time_t now)
u32 millis = dnsmasq_milliseconds();
for (f = daemon->frec_list; f; f = f->next)
if (f->sentto && f->stash && difftime(now, f->time) < daemon->fast_retry_timeout)
if (f->sentto && difftime(now, f->time) < daemon->fast_retry_timeout)
{
#ifdef HAVE_DNSSEC
if (f->blocking_query || (f->flags & FREC_GONE_TO_TCP))
@@ -1002,7 +1004,7 @@ static void dnssec_validate(struct frec *forward, struct dns_header *header,
/* validate routines leave name of required record in daemon->keyname */
unsigned int flags = STAT_ISEQUAL(status, STAT_NEED_KEY) ? FREC_DNSKEY_QUERY : FREC_DS_QUERY;
if ((new = lookup_frec_dnssec(daemon->keyname, forward->class, flags, header)))
if ((new = lookup_frec(daemon->keyname, forward->class, -1, -1, flags, flags)))
{
/* This is tricky; it detects loops in the dependency
graph for DNSSEC validation, say validating A requires DS B
@@ -1039,7 +1041,6 @@ static void dnssec_validate(struct frec *forward, struct dns_header *header,
else
{
struct server *server;
void *hash;
size_t nn;
int serverind, fd;
struct randfd_list *rfds = NULL;
@@ -1051,7 +1052,6 @@ static void dnssec_validate(struct frec *forward, struct dns_header *header,
(nn = dnssec_generate_query(header, ((unsigned char *) header) + server->edns_pktsz,
daemon->keyname, forward->class,
STAT_ISEQUAL(status, STAT_NEED_KEY) ? T_DNSKEY : T_DS, server->edns_pktsz)) &&
(hash = hash_questions(header, nn, daemon->namebuff)) &&
(fd = allocate_rfd(&rfds, server)) != -1 &&
(new = get_new_frec(now, server, 1)))
{
@@ -1065,7 +1065,7 @@ static void dnssec_validate(struct frec *forward, struct dns_header *header,
new->sentto = server;
new->rfds = rfds;
new->frec_src.next = NULL;
new->flags &= ~(FREC_DNSKEY_QUERY | FREC_DS_QUERY | FREC_HAS_EXTRADATA);
new->flags &= ~(FREC_DNSKEY_QUERY | FREC_DS_QUERY);
new->flags |= flags;
new->forwardall = 0;
@@ -1080,7 +1080,6 @@ static void dnssec_validate(struct frec *forward, struct dns_header *header,
forward->stash_len = plen;
forward->stash = stash;
memcpy(new->hash, hash, HASH_SIZE);
new->new_id = get_id();
header->id = htons(new->new_id);
/* Save query for retransmission and de-dup */
@@ -1169,8 +1168,9 @@ void reply_query(int fd, time_t now)
socklen_t addrlen = sizeof(serveraddr);
ssize_t n = recvfrom(fd, daemon->packet, daemon->packet_buff_sz, 0, &serveraddr.sa, &addrlen);
struct server *server;
void *hash;
int first, last, c;
int first, last, serv, c, class, rrtype;
unsigned char *p;
struct randfd_list *fdl;
/* packet buffer overwritten */
daemon->srv_save = NULL;
@@ -1181,14 +1181,43 @@ void reply_query(int fd, time_t now)
header = (struct dns_header *)daemon->packet;
if (n < (int)sizeof(struct dns_header) || !(header->hb3 & HB3_QR))
if (n < (int)sizeof(struct dns_header) || !(header->hb3 & HB3_QR) || ntohs(header->qdcount) != 1)
return;
hash = hash_questions(header, n, daemon->namebuff);
p = (unsigned char *)(header+1);
if (!extract_name(header, n, &p, daemon->namebuff, 1, 4))
return; /* bad packet */
GETSHORT(rrtype, p);
GETSHORT(class, p);
if (!(forward = lookup_frec(ntohs(header->id), fd, hash, &first, &last)))
if (!(forward = lookup_frec(daemon->namebuff, class, rrtype, ntohs(header->id), 0, 0)))
return;
filter_servers(forward->sentto->arrayposn, F_SERVER, &first, &last);
/* Check that this arrived on the file descriptor we expected. */
/* sent from random port */
for (fdl = forward->rfds; fdl; fdl = fdl->next)
if (fdl->rfd->fd == fd)
break;
if (!fdl)
{
/* Sent to upstream from socket associated with a server.
Note we have to iterate over all the possible servers, since they may
have different bound sockets. */
for (serv = first; serv != last; serv++)
{
server = daemon->serverarray[serv];
if (server->sfd && server->sfd->fd == fd)
break;
if (serv == last)
return;
}
}
/* spoof check: answer must come from known server, also
we may have sent the same query to multiple servers from
the same local socket, and would like to know which one has answered. */
@@ -1224,16 +1253,11 @@ void reply_query(int fd, time_t now)
check_for_ignored_address(header, n))
return;
/* Note: if we send extra options in the EDNS0 header, we can't recreate
the query from the reply. */
if ((RCODE(header) == REFUSED || RCODE(header) == SERVFAIL) &&
forward->forwardall == 0 &&
!(forward->flags & FREC_HAS_EXTRADATA))
if ((RCODE(header) == REFUSED || RCODE(header) == SERVFAIL) && forward->forwardall == 0)
/* for broken servers, attempt to send to another one. */
{
unsigned char *udpsz;
unsigned short udp_size = PACKETSZ; /* default if no EDNS0 */
size_t nn = 0;
#ifdef HAVE_DNSSEC
/* The query MAY have got a good answer, and be awaiting
@@ -1245,18 +1269,14 @@ void reply_query(int fd, time_t now)
/* Get the saved query back. */
blockdata_retrieve(forward->stash, forward->stash_len, (void *)header);
nn = forward->stash_len;
/* UDP size already set in saved query. */
if (find_pseudoheader(header, (size_t)n, NULL, &udpsz, NULL, NULL))
if (find_pseudoheader(header, (size_t)forward->stash_len, NULL, &udpsz, NULL, NULL))
GETSHORT(udp_size, udpsz);
if (nn)
{
forward_query(-1, NULL, NULL, 0, header, nn, ((char *) header) + udp_size, now, forward,
forward->flags & FREC_AD_QUESTION, forward->flags & FREC_DO_QUESTION, 0, NULL);
return;
}
forward_query(-1, NULL, NULL, 0, header, forward->stash_len, ((char *) header) + udp_size, now, forward,
forward->flags & FREC_AD_QUESTION, forward->flags & FREC_DO_QUESTION, 0, NULL);
return;
}
/* If the answer is an error, keep the forward record in place in case
@@ -3059,85 +3079,39 @@ static void query_full(time_t now, char *domain)
}
}
static struct frec *lookup_frec(unsigned short id, int fd, void *hash, int *firstp, int *lastp)
static struct frec *lookup_frec(char *target, int class, int rrtype, int id, int flags, int flagmask)
{
struct frec *f;
struct server *s;
int first, last;
struct randfd_list *fdl;
struct dns_header *header;
if (hash)
for (f = daemon->frec_list; f; f = f->next)
if (f->sentto && f->new_id == id &&
(memcmp(hash, f->hash, HASH_SIZE) == 0))
{
filter_servers(f->sentto->arrayposn, F_SERVER, firstp, lastp);
for (f = daemon->frec_list; f; f = f->next)
if (f->sentto &&
(f->flags & flagmask) == flags &&
(f->new_id == id || id == -1) &&
(header = blockdata_retrieve(f->stash, f->stash_len, NULL)))
{
unsigned char *p = (unsigned char *)(header+1);
int hclass, hrrtype;
/* sent from random port */
for (fdl = f->rfds; fdl; fdl = fdl->next)
if (fdl->rfd->fd == fd)
return f;
if (extract_name(header, f->stash_len, &p, target, 0, 4) != 1)
continue;
/* Sent to upstream from socket associated with a server.
Note we have to iterate over all the possible servers, since they may
have different bound sockets. */
for (first = *firstp, last = *lastp; first != last; first++)
{
s = daemon->serverarray[first];
if (s->sfd && s->sfd->fd == fd)
return f;
}
}
GETSHORT(hrrtype, p);
GETSHORT(hclass, p);
return NULL;
}
/* type checked by flags for DNSSEC queries. */
if (rrtype != -1 && rrtype != hrrtype)
continue;
static struct frec *lookup_frec_by_query(void *hash, unsigned int flags, unsigned int flagmask)
{
struct frec *f;
if (class != hclass)
continue;
if (hash)
for (f = daemon->frec_list; f; f = f->next)
if (f->sentto &&
(f->flags & flagmask) == flags &&
memcmp(hash, f->hash, HASH_SIZE) == 0)
return f;
}
return NULL;
}
#ifdef HAVE_DNSSEC
/* DNSSEC frecs have the complete query in the block stash.
Search for an existing query using that. */
static struct frec *lookup_frec_dnssec(char *target, int class, int flags, struct dns_header *header)
{
struct frec *f;
for (f = daemon->frec_list; f; f = f->next)
if (f->sentto &&
(f->flags & flags) &&
blockdata_retrieve(f->stash, f->stash_len, (void *)header))
{
unsigned char *p = (unsigned char *)(header+1);
int hclass;
if (extract_name(header, f->stash_len, &p, target, 0, 4) != 1)
continue;
p += 2; /* type, known from flags */
GETSHORT(hclass, p);
if (class != hclass)
continue;
return f;
}
return NULL;
}
#endif
/* Send query packet again, if we can. */
void resend_query()
{