Optimise lokkup_domain()

This commit is contained in:
Simon Kelley
2021-06-17 21:30:40 +01:00
parent 0276e0805b
commit 6860cf932b
5 changed files with 67 additions and 51 deletions

View File

@@ -564,7 +564,7 @@ struct randfd_list {
struct server { struct server {
int flags; u16 flags, domain_len;
char *domain; char *domain;
struct server *next; struct server *next;
int serial, arrayposn; int serial, arrayposn;
@@ -583,23 +583,23 @@ struct server {
#endif #endif
}; };
/* First three fields must match struct server in next three definitions.. */ /* First four fields must match struct server in next three definitions.. */
struct serv_addr4 { struct serv_addr4 {
int flags; u16 flags, domain_len;
char *domain; char *domain;
struct server *next; struct server *next;
struct in_addr addr; struct in_addr addr;
}; };
struct serv_addr6 { struct serv_addr6 {
int flags; u16 flags, domain_len;
char *domain; char *domain;
struct server *next; struct server *next;
struct in6_addr addr; struct in6_addr addr;
}; };
struct serv_local { struct serv_local {
int flags; u16 flags, domain_len;
char *domain; char *domain;
struct server *next; struct server *next;
}; };
@@ -1381,7 +1381,7 @@ void set_option_bool(unsigned int opt);
void reset_option_bool(unsigned int opt); void reset_option_bool(unsigned int opt);
struct hostsfile *expand_filelist(struct hostsfile *list); struct hostsfile *expand_filelist(struct hostsfile *list);
char *parse_server(char *arg, union mysockaddr *addr, char *parse_server(char *arg, union mysockaddr *addr,
union mysockaddr *source_addr, char *interface, int *flags); union mysockaddr *source_addr, char *interface, u16 *flags);
int option_read_dynfile(char *file, int flags); int option_read_dynfile(char *file, int flags);
/* forward.c */ /* forward.c */

View File

@@ -79,18 +79,18 @@ void build_server_array(void)
*/ */
int lookup_domain(char *qdomain, int flags, int *lowout, int *highout) int lookup_domain(char *qdomain, int flags, int *lowout, int *highout)
{ {
int rc, nodots, leading_dot = 1; int rc, crop_query, nodots, leading_dot = 1;
ssize_t qlen, maxlen; ssize_t qlen;
int try, high, low = 0; int try, high, low = 0;
int nlow = 0, nhigh = 0; int nlow = 0, nhigh = 0;
char *cp; char *cp;
int compares = 0;
/* may be no configured servers. */ /* may be no configured servers. */
if (daemon->serverarraysz == 0) if (daemon->serverarraysz == 0)
return 0; return 0;
maxlen = strlen(daemon->serverarray[0]->domain);
/* find query length and presence of '.' */ /* find query length and presence of '.' */
for (cp = qdomain, nodots = 1, qlen = 0; *cp; qlen++, cp++) for (cp = qdomain, nodots = 1, qlen = 0; *cp; qlen++, cp++)
if (*cp == '.') if (*cp == '.')
@@ -101,13 +101,8 @@ int lookup_domain(char *qdomain, int flags, int *lowout, int *highout)
if (qlen == 0 || flags & F_DNSSECOK) if (qlen == 0 || flags & F_DNSSECOK)
nodots = 0; nodots = 0;
/* No point trying to match more than the largest server domain */ /* account for leading dot */
if (qlen > maxlen) qlen++;
{
qdomain += qlen - maxlen;
qlen = maxlen;
leading_dot = 0;
}
/* Search shorter and shorter RHS substrings for a match */ /* Search shorter and shorter RHS substrings for a match */
while (qlen >= 0) while (qlen >= 0)
@@ -115,18 +110,23 @@ int lookup_domain(char *qdomain, int flags, int *lowout, int *highout)
/* Note that when we chop off a character, all the possible matches /* Note that when we chop off a character, all the possible matches
MUST be at a larger index than the nearest failing match with one more MUST be at a larger index than the nearest failing match with one more
character, since the array is sorted longest to smallest. Hence character, since the array is sorted longest to smallest. Hence
we don't reset low here. */ we don't reset low to zero here, we can go further below and crop the
search string to the size of the largest remaining server
when this match fails. */
high = daemon->serverarraysz; high = daemon->serverarraysz;
crop_query = 1;
/* binary search */ /* binary search */
do while (1)
{ {
try = (low + high)/2; try = (low + high)/2;
compares++;
if ((rc = order(qdomain, leading_dot, qlen, daemon->serverarray[try])) == 0) if ((rc = order(qdomain, leading_dot, qlen, daemon->serverarray[try])) == 0)
break; break;
if (rc < 0) if (rc < 0)
{ {
if (high == try) if (high == try)
break; break;
@@ -138,19 +138,14 @@ int lookup_domain(char *qdomain, int flags, int *lowout, int *highout)
break; break;
low = try; low = try;
} }
} };
while (low != high);
if (rc == 0) if (rc == 0)
{ {
/* We've matched a setting which says to use servers without a domain. /* We've matched a setting which says to use servers without a domain.
Continue the search with empty query (the last character gets stripped Continue the search with empty query */
by the loop. */
if (daemon->serverarray[try]->flags & SERV_USE_RESOLV) if (daemon->serverarray[try]->flags & SERV_USE_RESOLV)
{ crop_query = qlen;
qdomain += qlen - 1;
qlen = 1;
}
else else
{ {
/* We have a match, but it may only be (say) an IPv6 address, and /* We have a match, but it may only be (say) an IPv6 address, and
@@ -160,27 +155,50 @@ int lookup_domain(char *qdomain, int flags, int *lowout, int *highout)
break; break;
} }
} }
if (leading_dot)
leading_dot = 0;
else else
{ {
qlen--; /* try now points to the last domain that sorts before the query, so
qdomain++; we know that a substring of the query shorter than it is required to match, so
find the largest domain that's shorter than try. Note that just going to
try+1 is not optimal, consider searching bbb in (aaa,ccc,bb). try will point
to aaa, since ccc sorts after bbb, but the first domain that has a chance to
match is dd. So find the length of the first domain later than try which is
is shorter than it. */
ssize_t len, old = daemon->serverarray[try]->domain_len;
while (++try != daemon->serverarraysz)
{
if (old != (len = daemon->serverarray[try]->domain_len))
{
/* crop_query must be at least one always. */
if (qlen != len)
crop_query = qlen - len;
break;
}
}
} }
qlen -= crop_query;
if (leading_dot)
{
leading_dot = 0;
crop_query--;
}
qdomain += crop_query;
} }
printf("compares: %d\n", compares);
/* domain has no dots, and we have at least one server configured to handle such, /* domain has no dots, and we have at least one server configured to handle such,
These servers always sort to the very end of the array. These servers always sort to the very end of the array.
A configured server eg server=/lan/ will take precdence. */ A configured server eg server=/lan/ will take precdence. */
if (nodots && if (nodots &&
(daemon->serverarray[daemon->serverarraysz-1]->flags & SERV_FOR_NODOTS) && (daemon->serverarray[daemon->serverarraysz-1]->flags & SERV_FOR_NODOTS) &&
(nlow == nhigh || strlen(daemon->serverarray[nlow]->domain) == 0)) (nlow == nhigh || daemon->serverarray[nlow]->domain_len == 0))
filter_servers(daemon->serverarraysz-1, flags, &nlow, &nhigh); filter_servers(daemon->serverarraysz-1, flags, &nlow, &nhigh);
/* F_DOMAINSRV returns only domain-specific servers, so if we got to a /* F_DOMAINSRV returns only domain-specific servers, so if we got to a
general server, return empty set. */ general server, return empty set. */
if (nlow != nhigh && (flags & F_DOMAINSRV) && strlen(daemon->serverarray[nlow]->domain) == 0) if (nlow != nhigh && (flags & F_DOMAINSRV) && daemon->serverarray[nlow]->domain_len == 0)
nlow = nhigh; nlow = nhigh;
if (lowout) if (lowout)
@@ -382,10 +400,7 @@ static int order(char *qdomain, int leading_dot, size_t qlen, struct server *ser
if (serv->flags & SERV_FOR_NODOTS) if (serv->flags & SERV_FOR_NODOTS)
return -1; return -1;
if (leading_dot) dlen = serv->domain_len;
qlen++;
dlen = strlen(serv->domain);
if (qlen < dlen) if (qlen < dlen)
return 1; return 1;
@@ -401,14 +416,13 @@ static int order(char *qdomain, int leading_dot, size_t qlen, struct server *ser
static int order_servers(struct server *s1, struct server *s2) static int order_servers(struct server *s1, struct server *s2)
{ {
size_t dlen = strlen(s1->domain); /* need full comparison of dotless servers in
order_qsort() and filter_servers() */
/* need full comparison of dotless servers in if (s1->flags & SERV_FOR_NODOTS)
order_qsort() and filter_servers() */
if (s1->flags & SERV_FOR_NODOTS)
return (s2->flags & SERV_FOR_NODOTS) ? 0 : 1; return (s2->flags & SERV_FOR_NODOTS) ? 0 : 1;
return order(s1->domain, 0, dlen, s2); return order(s1->domain, 0, s1->domain_len, s2);
} }
static int order_qsort(const void *a, const void *b) static int order_qsort(const void *a, const void *b)

View File

@@ -153,9 +153,8 @@ static int domain_no_rebind(char *domain)
struct server *serv; struct server *serv;
int dlen = (int)strlen(domain); int dlen = (int)strlen(domain);
/* flags is misused to hold length of domain. */
for (serv = daemon->no_rebind; serv; serv = serv->next) for (serv = daemon->no_rebind; serv; serv = serv->next)
if (dlen >= serv->flags && strcmp(serv->domain, &domain[dlen - serv->flags]) == 0) if (dlen >= serv->domain_len && strcmp(serv->domain, &domain[dlen - serv->flags]) == 0)
return 1; return 1;
return 0; return 0;

View File

@@ -1617,7 +1617,7 @@ void add_update_server(int flags,
serv->flags = flags; serv->flags = flags;
serv->domain = domain_str; serv->domain = domain_str;
serv->domain_len = strlen(domain_str);
if (!(flags & SERV_IS_LOCAL)) if (!(flags & SERV_IS_LOCAL))
{ {

View File

@@ -812,7 +812,7 @@ static char *parse_mysockaddr(char *arg, union mysockaddr *addr)
return NULL; return NULL;
} }
char *parse_server(char *arg, union mysockaddr *addr, union mysockaddr *source_addr, char *interface, int *flags) char *parse_server(char *arg, union mysockaddr *addr, union mysockaddr *source_addr, char *interface, u16 *flags)
{ {
int source_port = 0, serv_port = NAMESERVER_PORT; int source_port = 0, serv_port = NAMESERVER_PORT;
char *portno, *source; char *portno, *source;
@@ -2617,7 +2617,7 @@ static int one_opt(int option, char *arg, char *errstr, char *gen_err, int comma
comma = split_chr(arg, '/'); comma = split_chr(arg, '/');
new = opt_malloc(sizeof(struct serv_local)); new = opt_malloc(sizeof(struct serv_local));
new->domain = opt_string_alloc(arg); new->domain = opt_string_alloc(arg);
new->flags = strlen(arg); new->domain_len = strlen(arg);
new->next = daemon->no_rebind; new->next = daemon->no_rebind;
daemon->no_rebind = new; daemon->no_rebind = new;
arg = comma; arg = comma;
@@ -2634,7 +2634,7 @@ static int one_opt(int option, char *arg, char *errstr, char *gen_err, int comma
size_t size; size_t size;
char *lastdomain = NULL, *domain = ""; char *lastdomain = NULL, *domain = "";
char *alloc_domain; char *alloc_domain;
int flags = 0; u16 flags = 0;
char *err; char *err;
struct in_addr addr4; struct in_addr addr4;
struct in6_addr addr6; struct in6_addr addr6;
@@ -2730,6 +2730,7 @@ static int one_opt(int option, char *arg, char *errstr, char *gen_err, int comma
} }
new->domain = alloc_domain; new->domain = alloc_domain;
new->domain_len = strlen(alloc_domain);
/* server=//1.2.3.4 is special. */ /* server=//1.2.3.4 is special. */
if (strlen(domain) == 0 && lastdomain) if (strlen(domain) == 0 && lastdomain)
@@ -2751,6 +2752,8 @@ static int one_opt(int option, char *arg, char *errstr, char *gen_err, int comma
new = opt_malloc(size); new = opt_malloc(size);
memcpy(new, last, size); memcpy(new, last, size);
new->domain = alloc_domain; new->domain = alloc_domain;
new->domain_len = strlen(alloc_domain);
if (flags & (SERV_USE_RESOLV | SERV_LITERAL_ADDRESS)) if (flags & (SERV_USE_RESOLV | SERV_LITERAL_ADDRESS))
{ {
new->next = daemon->local_domains; new->next = daemon->local_domains;