Implement DNS server with fake IP address replacements.

The IPv4 addresses in the answer will be hash-mapped to addresses in the pool.
This commit is contained in:
2025-04-03 10:06:12 +08:00
parent f409f72cec
commit 7a3b18b848
2 changed files with 786 additions and 0 deletions

22
Makefile Normal file
View File

@@ -0,0 +1,22 @@
CC=gcc
CFLAGS=-Wall
LDFLAGS=
DEPS=$(wildcard *.h)
SRC=$(wildcard src/*.c)
OBJ=$(patsubst src/%.c, build/%.o, $(SRC))
.PHONY: all clean
all: build/dotp
clean:
rm $(OBJ)
rm build/dotp
build/%.o: src/%.c $(DEPS)
mkdir -p build
$(CC) -c -o $@ $< $(CFLAGS)
build/dotp: $(OBJ)
mkdir -p build
$(CC) -o $@ $^ $(LDFLAGS) -lev -lnftables

764
src/main.c Normal file
View File

@@ -0,0 +1,764 @@
#include <arpa/inet.h>
#include <ctype.h>
#include <ev.h>
#include <getopt.h>
#include <malloc.h>
#include <netinet/in.h>
#include <nftables/libnftables.h>
#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <strings.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#define MAX_MESSAGE_SIZE 0x200
#define DNS_FLAG_TC 0x200
#define DNS_TYPE_A 1
#define DNS_CLASS_IN 0x0001
#define NAT_TTL 60
typedef struct __attribute__((packed)) dns_header {
uint16_t id;
uint16_t flags;
uint16_t qd_count;
uint16_t an_count;
uint16_t ns_count;
uint16_t ar_count;
} dns_header_t;
typedef struct __attribute__((packed)) dns_answer_header {
uint16_t an_type;
uint16_t an_class;
uint32_t an_ttl;
uint16_t rd_len;
} dns_answer_header_t;
typedef struct domain_set {
struct domain_set *next;
struct domain_set *head;
uint8_t label_len;
char const *label;
} domain_set_t;
typedef struct domain_name {
struct domain_name *next;
ssize_t ptr;
domain_set_t *match;
} domain_name_t;
typedef struct domain_msg {
domain_set_t *match_root;
ssize_t len;
uint8_t *raw;
domain_name_t *name_head;
} domain_msg_t;
typedef struct ip_pool ip_pool_t;
typedef struct ip_nat {
ev_timer expire;
ip_pool_t *pool;
struct ip_nat *fake_next, *real_next;
uint32_t fake, real;
int dst_handle, src_handle;
} ip_nat_t;
typedef struct ip_pool {
uint32_t pf;
uint32_t pf_mask;
// Indexed by fake
size_t size;
ip_nat_t **fake;
ip_nat_t **real;
} ip_pool_t;
typedef struct server_ctx {
ev_io io;
domain_set_t *match_root;
ip_pool_t *ip_pool;
int fd;
struct sockaddr_in upstream_addr;
} server_ctx_t;
typedef struct client_ctx {
ev_io io;
ev_timer timer;
domain_set_t *match_root;
ip_pool_t *ip_pool;
int server_fd;
socklen_t client_addr_len;
struct sockaddr_storage client_addr;
int fd;
domain_msg_t msg;
} client_ctx_t;
static struct nft_ctx *nft_ctx;
char *malloc_sprintf(const char *fmt, ...) {
char *buffer = NULL;
va_list args, args_copy;
va_start(args, fmt);
va_copy(args_copy, args);
int len = vsnprintf(NULL, 0, fmt, args);
if (len < 0) {
goto finish;
}
buffer = (char *)malloc(len + 1);
if (!buffer) {
goto finish;
}
vsnprintf(buffer, len + 1, fmt, args_copy);
finish:
va_end(args_copy);
va_end(args);
return buffer;
}
static uint32_t ip_hash(uint32_t addr) {
return addr ^ ((addr << 3) | (addr >> 29)) ^ ((addr << 7) | (addr >> 25)) ^
((addr << 13) | (addr >> 19));
}
static void ip_pool_init(ip_pool_t *pool, uint32_t pf, uint32_t pf_mask) {
pool->pf = pf;
pool->pf_mask = pf_mask;
pool->size = 0x1000;
pool->fake = (ip_nat_t **)malloc(pool->size * sizeof(ip_nat_t *));
pool->real = (ip_nat_t **)malloc(pool->size * sizeof(ip_nat_t *));
bzero(pool->fake, pool->size * sizeof(ip_nat_t *));
bzero(pool->real, pool->size * sizeof(ip_nat_t *));
}
static void nat_free_chain(ip_nat_t *nat) {
if (nat) {
nat_free_chain(nat->fake_next);
free(nat);
}
}
static void ip_pool_fini(ip_pool_t *pool) {
for (size_t i = 0; i < pool->size; i++) {
nat_free_chain(pool->fake[i]);
}
free(pool->fake);
free(pool->real);
}
static void nat_expire(EV_P_ ev_timer *w, int revents) {
ip_nat_t *nat = (ip_nat_t *)w;
ip_pool_t *pool = nat->pool;
ev_timer_stop(EV_A, w);
char *cmd = malloc_sprintf("delete rule ip nat postrouting handle %d",
nat->src_handle);
nft_run_cmd_from_buffer(nft_ctx, cmd);
free(cmd);
cmd = malloc_sprintf("delete rule ip nat prerouting handle %d",
nat->dst_handle);
nft_run_cmd_from_buffer(nft_ctx, cmd);
free(cmd);
uint32_t mask = pool->size - 1;
ip_nat_t **p_other = &pool->real[nat->real & mask];
ip_nat_t *other = *p_other;
for (; other; other = other->real_next) {
if (nat == other) {
*p_other = nat->real_next;
break;
}
p_other = &other->real_next;
other = *p_other;
}
p_other = &pool->fake[nat->fake & mask];
other = *p_other;
for (; other; other = other->fake_next) {
if (nat == other) {
*p_other = nat->fake_next;
break;
}
p_other = &other->fake_next;
other = *p_other;
}
free(nat);
}
static ip_nat_t *find_nat(EV_P_ ip_pool_t *pool, uint32_t real_addr) {
size_t mask = pool->size - 1;
uint32_t h = ip_hash(real_addr);
uint32_t fake_addr;
for (ip_nat_t *nat = pool->real[real_addr & mask]; nat;
nat = nat->real_next) {
if (nat->real == real_addr) {
ev_timer_again(EV_A, &nat->expire);
return nat;
}
}
int ok = 0;
for (; h != 0;) {
fake_addr = pool->pf | (h & ~pool->pf_mask);
ok = 1;
for (ip_nat_t *nat = pool->fake[fake_addr & mask]; nat;
nat = nat->fake_next) {
if (nat->fake == fake_addr) {
ok = 0;
break;
}
}
if (ok) {
break;
}
h++;
}
if (h == 0) {
return NULL;
} else {
ip_nat_t *nat = (ip_nat_t *)malloc(sizeof(ip_nat_t));
nat->pool = pool;
ev_init(&nat->expire, nat_expire);
nat->expire.repeat = NAT_TTL;
ev_timer_again(EV_A, &nat->expire);
nat->fake = fake_addr;
nat->real = real_addr;
nat->fake_next = pool->fake[fake_addr & mask];
pool->fake[fake_addr & mask] = nat;
nat->real_next = pool->real[real_addr & mask];
pool->real[real_addr & mask] = nat;
char real_ip[16], fake_ip[16];
uint32_t af_fake = htonl(fake_addr);
uint32_t af_real = htonl(real_addr);
inet_ntop(AF_INET, &af_fake, fake_ip, 16);
inet_ntop(AF_INET, &af_real, real_ip, 16);
nft_ctx_buffer_output(nft_ctx);
char *cmd = malloc_sprintf(
"add rule ip nat prerouting ip daddr %s dnat to %s", fake_ip, real_ip);
nft_run_cmd_from_buffer(nft_ctx, cmd);
char *echo_fmt = malloc_sprintf("%s # handle %%d", cmd);
sscanf(nft_ctx_get_output_buffer(nft_ctx), echo_fmt, &nat->dst_handle);
free(echo_fmt);
free(cmd);
nft_ctx_unbuffer_output(nft_ctx);
nft_ctx_buffer_output(nft_ctx);
cmd = malloc_sprintf("add rule ip nat postrouting ip saddr %s snat to %s",
real_ip, fake_ip);
nft_run_cmd_from_buffer(nft_ctx, cmd);
echo_fmt = malloc_sprintf("%s # handle %%d", cmd);
sscanf(nft_ctx_get_output_buffer(nft_ctx), echo_fmt, &nat->src_handle);
free(echo_fmt);
free(cmd);
nft_ctx_unbuffer_output(nft_ctx);
return nat;
}
}
static void free_domain_names(domain_name_t *head) {
if (head) {
free_domain_names(head->next);
free(head);
}
}
static ssize_t read_domain_name(domain_msg_t *msg, ssize_t ptr,
domain_name_t **p_name) {
domain_name_t *name = NULL;
for (name = msg->name_head; name; name = name->next) {
if (name->ptr == ptr) {
*p_name = name;
return 0;
}
}
if (ptr >= msg->len) {
*p_name = NULL;
return -1;
}
uint8_t label_len = *(msg->raw + ptr);
if ((label_len & 0xc0) == 0) {
name = (domain_name_t *)malloc(sizeof(domain_name_t));
name->ptr = ptr;
ptr++;
if (label_len == 0) {
name->match = msg->match_root;
} else {
if (ptr + label_len > msg->len) {
goto label_fail;
}
domain_name_t *parent;
ssize_t new_ptr = read_domain_name(msg, ptr + label_len, &parent);
if (!parent) {
goto label_fail;
}
if (parent->match && parent->match->head) {
domain_set_t *child;
for (child = parent->match->head; child; child = child->next) {
if (child->label_len == label_len &&
!memcmp(child->label, msg->raw + ptr, label_len)) {
break;
}
}
name->match = child;
} else {
name->match = parent->match;
}
ptr = new_ptr;
}
name->next = msg->name_head;
msg->name_head = name;
*p_name = name;
return ptr;
label_fail:
free(name);
*p_name = NULL;
return -1;
} else if ((label_len & 0xc0) == 0xc0) {
ptr++;
if (ptr >= msg->len) {
*p_name = NULL;
goto ptr_fail;
}
ssize_t new_ptr = ((label_len & 0x3f) << 8) | *(msg->raw + ptr);
if (new_ptr >= msg->len) {
*p_name = NULL;
goto ptr_fail;
}
ptr++;
if (read_domain_name(msg, new_ptr, p_name) < 0) {
goto ptr_fail;
} else {
return ptr;
}
ptr_fail:
return -1;
} else {
*p_name = NULL;
return -1;
}
}
static void client_read(EV_P_ ev_io *w, int revents) {
client_ctx_t *client_ctx = (client_ctx_t *)w;
ev_timer_stop(EV_A, &client_ctx->timer);
ev_io_stop(EV_A, &client_ctx->io);
domain_msg_t *msg = &client_ctx->msg;
msg->match_root = client_ctx->match_root;
msg->name_head = NULL;
msg->raw = malloc(MAX_MESSAGE_SIZE);
msg->len = recvfrom(client_ctx->fd, msg->raw, MAX_MESSAGE_SIZE, MSG_TRUNC,
NULL, NULL);
if (msg->len < sizeof(dns_header_t)) {
goto fail;
}
dns_header_t *header = (dns_header_t *)msg->raw;
if (msg->len > MAX_MESSAGE_SIZE) {
header->flags = htons(ntohs(header->flags) | DNS_FLAG_TC);
goto send;
}
ssize_t ptr = sizeof(dns_header_t);
for (uint16_t i = 0; i < ntohs(header->qd_count); i++) {
domain_name_t *name;
ssize_t new_ptr = read_domain_name(msg, ptr, &name);
if (new_ptr < 0) {
goto fail;
}
ptr = new_ptr;
// Skip type and class
ptr += 4;
}
for (uint16_t i = 0; i < ntohs(header->an_count); i++) {
domain_name_t *name;
ssize_t new_ptr = read_domain_name(msg, ptr, &name);
if (new_ptr < 0) {
goto fail;
}
ptr = new_ptr;
dns_answer_header_t *an_header = (dns_answer_header_t *)(msg->raw + ptr);
ptr += sizeof(dns_answer_header_t);
if (name->match && ntohs(an_header->an_type) == DNS_TYPE_A &&
ntohs(an_header->an_class) == DNS_CLASS_IN &&
ntohs(an_header->rd_len) == 4) {
// Replace answer with fake ip
uint32_t *p_addr = (uint32_t *)(msg->raw + ptr);
uint32_t real_addr = ntohl(*p_addr);
ip_nat_t *nat = find_nat(EV_A, client_ctx->ip_pool, real_addr);
if (nat) {
*p_addr = htonl(nat->fake);
an_header->an_ttl = htonl(NAT_TTL);
}
}
ptr += ntohs(an_header->rd_len);
}
send:
sendto(client_ctx->server_fd, msg->raw, msg->len, 0,
(struct sockaddr *)&client_ctx->client_addr,
client_ctx->client_addr_len);
fail:
free_domain_names(msg->name_head);
free(msg->raw);
close(client_ctx->fd);
free(client_ctx);
}
static void client_timeout(EV_P_ ev_timer *w, int revents) {
client_ctx_t *client_ctx =
(client_ctx_t *)((uint8_t *)w - offsetof(client_ctx_t, timer));
ev_io_stop(EV_A_ & client_ctx->io);
close(client_ctx->fd);
free(client_ctx);
}
static void server_read(EV_P_ ev_io *w, int revents) {
server_ctx_t *server_ctx = (server_ctx_t *)w;
uint8_t *msg = malloc(MAX_MESSAGE_SIZE);
struct sockaddr_storage addr;
socklen_t addr_len = sizeof(addr);
ssize_t len = recvfrom(server_ctx->fd, msg, MAX_MESSAGE_SIZE, 0,
(struct sockaddr *)&addr, &addr_len);
if (len < 0) {
perror("Failed to read request");
goto finish_read;
} else if (len < sizeof(dns_header_t)) {
fprintf(stderr, "Invalid request: Packet too short.\n");
goto finish_read;
}
client_ctx_t *client_ctx = malloc(sizeof(client_ctx_t));
client_ctx->ip_pool = server_ctx->ip_pool;
client_ctx->match_root = server_ctx->match_root;
client_ctx->server_fd = server_ctx->fd;
client_ctx->client_addr_len = addr_len;
memcpy(&client_ctx->client_addr, &addr, addr_len);
client_ctx->fd = socket(AF_INET, SOCK_DGRAM, 0);
if (client_ctx->fd < 0) {
free(client_ctx);
perror("Failed to create client socket");
goto finish_read;
}
ev_io_init(&client_ctx->io, &client_read, client_ctx->fd, EV_READ);
ev_timer_init(&client_ctx->timer, &client_timeout, 5., 0.);
ev_io_start(EV_A, &client_ctx->io);
sendto(client_ctx->fd, msg, len, 0,
(struct sockaddr *)&server_ctx->upstream_addr,
sizeof(server_ctx->upstream_addr));
ev_timer_start(EV_A, &client_ctx->timer);
finish_read:
free(msg);
}
static domain_set_t *domain_set_new(char const *label, size_t len) {
domain_set_t *s = (domain_set_t *)malloc(sizeof(domain_set_t));
s->next = NULL;
s->head = NULL;
s->label = label;
s->label_len = len;
return s;
}
static void domain_set_fini(domain_set_t *s) {
domain_set_t *c_next;
for (domain_set_t *c = s->head; c; c = c_next) {
c_next = c->next;
domain_set_fini(c);
}
free(s);
}
static domain_set_t *parse_domain_option(domain_set_t *root,
char const *domain) {
char const *p = domain;
if (*p == 0 || *p == '.' || *p == '-') {
return NULL;
}
for (p = domain; *p && *p != '.'; p++) {
if (!(isalnum(*p) || *p == '-')) {
return NULL;
}
}
if (*(p - 1) == '-') {
return NULL;
}
size_t label_len = p - domain;
if (label_len > 63) {
return NULL;
}
domain_set_t *parent;
if (*p == '.') {
p++;
parent = parse_domain_option(root, p);
} else {
parent = root;
}
if (!parent) {
return NULL;
}
for (domain_set_t *c = parent->head; c; c = c->next) {
if (c->label_len == label_len && !memcmp(c->label, domain, label_len)) {
return c;
}
}
domain_set_t *s = domain_set_new(domain, label_len);
s->next = parent->head;
parent->head = s;
return s;
}
static void terminate(EV_P_ ev_signal *w, int revents) {
ev_break(EV_A, EVBREAK_ALL);
}
#define OPT_UPSTREAM_HOST 0x101
#define OPT_UPSTREAM_PORT 0x102
int main(int argc, char *const *argv) {
domain_set_t *domain_set = domain_set_new("", 0);
ip_pool_t ip_pool;
struct sockaddr_in listen_addr, upstream_addr;
memset(&listen_addr, 0, sizeof(struct sockaddr_in));
memset(&upstream_addr, 0, sizeof(struct sockaddr_in));
uint16_t listen_port = 53;
uint16_t upstream_port = 53;
listen_addr.sin_family = AF_INET;
upstream_addr.sin_family = AF_INET;
const char *options = "h:p:d:x:";
struct option long_options[] = {
{"host", required_argument, NULL, 'h'},
{"port", required_argument, NULL, 'p'},
{"domain", required_argument, NULL, 'd'},
{"prefix", required_argument, NULL, 'x'},
{"upstream-host", required_argument, NULL, OPT_UPSTREAM_HOST},
{"upstream-port", required_argument, NULL, OPT_UPSTREAM_PORT},
{NULL, 0, NULL, 0}};
int option_index;
int o;
opterr = 0;
int listen_host_set = 0;
int prefix_set = 0;
int upstream_host_set = 0;
while ((o = getopt_long(argc, argv, options, long_options, &option_index)) !=
-1) {
switch (o) {
case 'h':
if (inet_pton(AF_INET, optarg, &listen_addr.sin_addr) != 1) {
goto fail;
} else {
listen_host_set = 1;
}
break;
case 'p':
if (sscanf(optarg, "%hu", &listen_port) == -1) {
goto fail;
}
break;
case 'd':
if (!parse_domain_option(domain_set, optarg)) {
goto fail;
}
break;
case 'x': {
char *sep = strchr(optarg, '/');
if (!sep) {
goto fail;
}
*sep = 0;
uint32_t af_addr;
unsigned pf_len;
if (inet_pton(AF_INET, optarg, &af_addr) != 1 ||
sscanf(sep + 1, "%u", &pf_len) == -1 || pf_len > 30) {
goto fail;
}
uint32_t pf = ntohl(af_addr);
uint32_t pf_mask = ~((1 << (32 - pf_len)) - 1);
if (pf & ~pf_mask) {
goto fail;
}
ip_pool_init(&ip_pool, pf, pf_mask);
prefix_set = 1;
break;
}
case OPT_UPSTREAM_HOST:
if (inet_pton(AF_INET, optarg, &upstream_addr.sin_addr) < 0) {
goto fail;
}
upstream_host_set = 1;
break;
case OPT_UPSTREAM_PORT:
if (sscanf(optarg, "%hu", &upstream_port) == -1) {
goto fail;
}
break;
fail:
opterr = 1;
fprintf(stderr, "Failed to parse option: %s", optarg);
break;
case '?':
fprintf(stderr, "Unrecognized option: %s\n", optarg);
opterr = 1;
fprintf(stderr,
"Usage: %s -h LISTEN_HOST -p LISTEN_PORT\n"
" -x FAKE_IP_PREFIX [-d DOMAIN]\n"
" --upstream-host UPSTREAM_HOST --upstream-port "
"UPSTREAM_PORT\n",
argv[0]);
break;
}
}
if (!(listen_host_set && prefix_set && upstream_host_set)) {
fprintf(stderr,
"LISTEN_ADDR, FAKE_IP_PREFIX, UPSTREAM_HOST must be set.\n");
opterr = 1;
}
if (opterr) {
domain_set_fini(domain_set);
return 1;
}
listen_addr.sin_port = htons(listen_port);
upstream_addr.sin_port = htons(upstream_port);
nft_ctx = nft_ctx_new(NFT_CTX_DEFAULT);
if (!nft_ctx) {
perror("Failed to create nftables context.");
}
nft_ctx_output_set_flags(nft_ctx,
NFT_CTX_OUTPUT_HANDLE | NFT_CTX_OUTPUT_ECHO);
char *cmd = malloc_sprintf("add table ip nat");
nft_run_cmd_from_buffer(nft_ctx, cmd);
free(cmd);
cmd = malloc_sprintf("add chain ip nat prerouting"
"{ type nat hook prerouting priority dstnat; }");
nft_run_cmd_from_buffer(nft_ctx, cmd);
free(cmd);
cmd = malloc_sprintf("add chain ip nat postrouting"
"{ type nat hook postrouting priority srcnat; }");
nft_run_cmd_from_buffer(nft_ctx, cmd);
free(cmd);
struct ev_loop *loop = ev_default_loop(0);
int listen_fd = socket(AF_INET, SOCK_DGRAM, 0);
if (bind(listen_fd, (struct sockaddr *)&listen_addr, sizeof(listen_addr)) ==
0) {
server_ctx_t server_ctx;
server_ctx.upstream_addr = upstream_addr;
server_ctx.match_root = domain_set;
server_ctx.ip_pool = &ip_pool;
server_ctx.fd = listen_fd;
ev_io_init(&server_ctx.io, &server_read, listen_fd, EV_READ);
ev_signal sig_int, sig_term;
ev_signal_init(&sig_int, &terminate, SIGINT);
ev_signal_init(&sig_term, &terminate, SIGTERM);
ev_signal_start(loop, &sig_int);
ev_signal_start(loop, &sig_term);
ev_io_start(loop, &server_ctx.io);
ev_loop(loop, 0);
} else {
perror("Failed to bind");
}
close(listen_fd);
ip_pool_fini(&ip_pool);
cmd = malloc_sprintf("delete table ip nat");
nft_run_cmd_from_buffer(nft_ctx, cmd);
free(cmd);
nft_ctx_free(nft_ctx);
domain_set_fini(domain_set);
return 0;
}