diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1b94b54 --- /dev/null +++ b/Makefile @@ -0,0 +1,22 @@ +CC=gcc +CFLAGS=-Wall +LDFLAGS= +DEPS=$(wildcard *.h) +SRC=$(wildcard src/*.c) +OBJ=$(patsubst src/%.c, build/%.o, $(SRC)) + +.PHONY: all clean + +all: build/dotp + +clean: + rm $(OBJ) + rm build/dotp + +build/%.o: src/%.c $(DEPS) + mkdir -p build + $(CC) -c -o $@ $< $(CFLAGS) + +build/dotp: $(OBJ) + mkdir -p build + $(CC) -o $@ $^ $(LDFLAGS) -lev -lnftables diff --git a/src/main.c b/src/main.c new file mode 100644 index 0000000..bca310a --- /dev/null +++ b/src/main.c @@ -0,0 +1,764 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define MAX_MESSAGE_SIZE 0x200 + +#define DNS_FLAG_TC 0x200 +#define DNS_TYPE_A 1 +#define DNS_CLASS_IN 0x0001 + +#define NAT_TTL 60 + +typedef struct __attribute__((packed)) dns_header { + uint16_t id; + uint16_t flags; + uint16_t qd_count; + uint16_t an_count; + uint16_t ns_count; + uint16_t ar_count; +} dns_header_t; + +typedef struct __attribute__((packed)) dns_answer_header { + uint16_t an_type; + uint16_t an_class; + uint32_t an_ttl; + uint16_t rd_len; +} dns_answer_header_t; + +typedef struct domain_set { + struct domain_set *next; + struct domain_set *head; + + uint8_t label_len; + char const *label; +} domain_set_t; + +typedef struct domain_name { + struct domain_name *next; + ssize_t ptr; + domain_set_t *match; +} domain_name_t; + +typedef struct domain_msg { + domain_set_t *match_root; + + ssize_t len; + uint8_t *raw; + + domain_name_t *name_head; +} domain_msg_t; + +typedef struct ip_pool ip_pool_t; + +typedef struct ip_nat { + ev_timer expire; + ip_pool_t *pool; + struct ip_nat *fake_next, *real_next; + uint32_t fake, real; + int dst_handle, src_handle; +} ip_nat_t; + +typedef struct ip_pool { + uint32_t pf; + uint32_t pf_mask; + + // Indexed by fake + size_t size; + ip_nat_t **fake; + ip_nat_t **real; +} ip_pool_t; + +typedef struct server_ctx { + ev_io io; + + domain_set_t *match_root; + ip_pool_t *ip_pool; + int fd; + struct sockaddr_in upstream_addr; +} server_ctx_t; + +typedef struct client_ctx { + ev_io io; + ev_timer timer; + + domain_set_t *match_root; + ip_pool_t *ip_pool; + int server_fd; + socklen_t client_addr_len; + struct sockaddr_storage client_addr; + int fd; + domain_msg_t msg; +} client_ctx_t; + +static struct nft_ctx *nft_ctx; + +char *malloc_sprintf(const char *fmt, ...) { + char *buffer = NULL; + + va_list args, args_copy; + va_start(args, fmt); + va_copy(args_copy, args); + + int len = vsnprintf(NULL, 0, fmt, args); + if (len < 0) { + goto finish; + } + + buffer = (char *)malloc(len + 1); + if (!buffer) { + goto finish; + } + + vsnprintf(buffer, len + 1, fmt, args_copy); + +finish: + va_end(args_copy); + va_end(args); + return buffer; +} + +static uint32_t ip_hash(uint32_t addr) { + return addr ^ ((addr << 3) | (addr >> 29)) ^ ((addr << 7) | (addr >> 25)) ^ + ((addr << 13) | (addr >> 19)); +} + +static void ip_pool_init(ip_pool_t *pool, uint32_t pf, uint32_t pf_mask) { + pool->pf = pf; + pool->pf_mask = pf_mask; + + pool->size = 0x1000; + + pool->fake = (ip_nat_t **)malloc(pool->size * sizeof(ip_nat_t *)); + pool->real = (ip_nat_t **)malloc(pool->size * sizeof(ip_nat_t *)); + + bzero(pool->fake, pool->size * sizeof(ip_nat_t *)); + bzero(pool->real, pool->size * sizeof(ip_nat_t *)); +} + +static void nat_free_chain(ip_nat_t *nat) { + if (nat) { + nat_free_chain(nat->fake_next); + free(nat); + } +} + +static void ip_pool_fini(ip_pool_t *pool) { + for (size_t i = 0; i < pool->size; i++) { + nat_free_chain(pool->fake[i]); + } + free(pool->fake); + free(pool->real); +} + +static void nat_expire(EV_P_ ev_timer *w, int revents) { + ip_nat_t *nat = (ip_nat_t *)w; + ip_pool_t *pool = nat->pool; + + ev_timer_stop(EV_A, w); + + char *cmd = malloc_sprintf("delete rule ip nat postrouting handle %d", + nat->src_handle); + nft_run_cmd_from_buffer(nft_ctx, cmd); + free(cmd); + + cmd = malloc_sprintf("delete rule ip nat prerouting handle %d", + nat->dst_handle); + nft_run_cmd_from_buffer(nft_ctx, cmd); + free(cmd); + + uint32_t mask = pool->size - 1; + + ip_nat_t **p_other = &pool->real[nat->real & mask]; + ip_nat_t *other = *p_other; + + for (; other; other = other->real_next) { + if (nat == other) { + *p_other = nat->real_next; + break; + } + p_other = &other->real_next; + other = *p_other; + } + + p_other = &pool->fake[nat->fake & mask]; + other = *p_other; + + for (; other; other = other->fake_next) { + if (nat == other) { + *p_other = nat->fake_next; + break; + } + p_other = &other->fake_next; + other = *p_other; + } + + free(nat); +} + +static ip_nat_t *find_nat(EV_P_ ip_pool_t *pool, uint32_t real_addr) { + size_t mask = pool->size - 1; + + uint32_t h = ip_hash(real_addr); + uint32_t fake_addr; + for (ip_nat_t *nat = pool->real[real_addr & mask]; nat; + nat = nat->real_next) { + if (nat->real == real_addr) { + ev_timer_again(EV_A, &nat->expire); + return nat; + } + } + + int ok = 0; + for (; h != 0;) { + fake_addr = pool->pf | (h & ~pool->pf_mask); + ok = 1; + for (ip_nat_t *nat = pool->fake[fake_addr & mask]; nat; + nat = nat->fake_next) { + if (nat->fake == fake_addr) { + ok = 0; + break; + } + } + if (ok) { + break; + } + h++; + } + if (h == 0) { + return NULL; + } else { + ip_nat_t *nat = (ip_nat_t *)malloc(sizeof(ip_nat_t)); + + nat->pool = pool; + + ev_init(&nat->expire, nat_expire); + nat->expire.repeat = NAT_TTL; + ev_timer_again(EV_A, &nat->expire); + + nat->fake = fake_addr; + nat->real = real_addr; + + nat->fake_next = pool->fake[fake_addr & mask]; + pool->fake[fake_addr & mask] = nat; + + nat->real_next = pool->real[real_addr & mask]; + pool->real[real_addr & mask] = nat; + + char real_ip[16], fake_ip[16]; + uint32_t af_fake = htonl(fake_addr); + uint32_t af_real = htonl(real_addr); + + inet_ntop(AF_INET, &af_fake, fake_ip, 16); + inet_ntop(AF_INET, &af_real, real_ip, 16); + + nft_ctx_buffer_output(nft_ctx); + char *cmd = malloc_sprintf( + "add rule ip nat prerouting ip daddr %s dnat to %s", fake_ip, real_ip); + nft_run_cmd_from_buffer(nft_ctx, cmd); + char *echo_fmt = malloc_sprintf("%s # handle %%d", cmd); + sscanf(nft_ctx_get_output_buffer(nft_ctx), echo_fmt, &nat->dst_handle); + free(echo_fmt); + free(cmd); + nft_ctx_unbuffer_output(nft_ctx); + + nft_ctx_buffer_output(nft_ctx); + cmd = malloc_sprintf("add rule ip nat postrouting ip saddr %s snat to %s", + real_ip, fake_ip); + nft_run_cmd_from_buffer(nft_ctx, cmd); + echo_fmt = malloc_sprintf("%s # handle %%d", cmd); + sscanf(nft_ctx_get_output_buffer(nft_ctx), echo_fmt, &nat->src_handle); + free(echo_fmt); + free(cmd); + nft_ctx_unbuffer_output(nft_ctx); + + return nat; + } +} + +static void free_domain_names(domain_name_t *head) { + if (head) { + free_domain_names(head->next); + free(head); + } +} + +static ssize_t read_domain_name(domain_msg_t *msg, ssize_t ptr, + domain_name_t **p_name) { + domain_name_t *name = NULL; + + for (name = msg->name_head; name; name = name->next) { + if (name->ptr == ptr) { + *p_name = name; + return 0; + } + } + if (ptr >= msg->len) { + *p_name = NULL; + return -1; + } + uint8_t label_len = *(msg->raw + ptr); + if ((label_len & 0xc0) == 0) { + name = (domain_name_t *)malloc(sizeof(domain_name_t)); + name->ptr = ptr; + ptr++; + + if (label_len == 0) { + name->match = msg->match_root; + } else { + if (ptr + label_len > msg->len) { + goto label_fail; + } + domain_name_t *parent; + ssize_t new_ptr = read_domain_name(msg, ptr + label_len, &parent); + if (!parent) { + goto label_fail; + } + if (parent->match && parent->match->head) { + domain_set_t *child; + for (child = parent->match->head; child; child = child->next) { + if (child->label_len == label_len && + !memcmp(child->label, msg->raw + ptr, label_len)) { + break; + } + } + name->match = child; + } else { + name->match = parent->match; + } + ptr = new_ptr; + } + + name->next = msg->name_head; + msg->name_head = name; + *p_name = name; + return ptr; + + label_fail: + free(name); + *p_name = NULL; + return -1; + + } else if ((label_len & 0xc0) == 0xc0) { + ptr++; + if (ptr >= msg->len) { + *p_name = NULL; + goto ptr_fail; + } + ssize_t new_ptr = ((label_len & 0x3f) << 8) | *(msg->raw + ptr); + if (new_ptr >= msg->len) { + *p_name = NULL; + goto ptr_fail; + } + ptr++; + if (read_domain_name(msg, new_ptr, p_name) < 0) { + goto ptr_fail; + } else { + return ptr; + } + + ptr_fail: + return -1; + } else { + *p_name = NULL; + return -1; + } +} + +static void client_read(EV_P_ ev_io *w, int revents) { + client_ctx_t *client_ctx = (client_ctx_t *)w; + + ev_timer_stop(EV_A, &client_ctx->timer); + ev_io_stop(EV_A, &client_ctx->io); + + domain_msg_t *msg = &client_ctx->msg; + msg->match_root = client_ctx->match_root; + msg->name_head = NULL; + + msg->raw = malloc(MAX_MESSAGE_SIZE); + + msg->len = recvfrom(client_ctx->fd, msg->raw, MAX_MESSAGE_SIZE, MSG_TRUNC, + NULL, NULL); + if (msg->len < sizeof(dns_header_t)) { + goto fail; + } + + dns_header_t *header = (dns_header_t *)msg->raw; + if (msg->len > MAX_MESSAGE_SIZE) { + header->flags = htons(ntohs(header->flags) | DNS_FLAG_TC); + goto send; + } + + ssize_t ptr = sizeof(dns_header_t); + + for (uint16_t i = 0; i < ntohs(header->qd_count); i++) { + domain_name_t *name; + ssize_t new_ptr = read_domain_name(msg, ptr, &name); + if (new_ptr < 0) { + goto fail; + } + ptr = new_ptr; + // Skip type and class + ptr += 4; + } + + for (uint16_t i = 0; i < ntohs(header->an_count); i++) { + domain_name_t *name; + ssize_t new_ptr = read_domain_name(msg, ptr, &name); + if (new_ptr < 0) { + goto fail; + } + + ptr = new_ptr; + + dns_answer_header_t *an_header = (dns_answer_header_t *)(msg->raw + ptr); + ptr += sizeof(dns_answer_header_t); + + if (name->match && ntohs(an_header->an_type) == DNS_TYPE_A && + ntohs(an_header->an_class) == DNS_CLASS_IN && + ntohs(an_header->rd_len) == 4) { + // Replace answer with fake ip + + uint32_t *p_addr = (uint32_t *)(msg->raw + ptr); + + uint32_t real_addr = ntohl(*p_addr); + ip_nat_t *nat = find_nat(EV_A, client_ctx->ip_pool, real_addr); + + if (nat) { + *p_addr = htonl(nat->fake); + an_header->an_ttl = htonl(NAT_TTL); + } + } + + ptr += ntohs(an_header->rd_len); + } + +send: + sendto(client_ctx->server_fd, msg->raw, msg->len, 0, + (struct sockaddr *)&client_ctx->client_addr, + client_ctx->client_addr_len); + +fail: + free_domain_names(msg->name_head); + free(msg->raw); + + close(client_ctx->fd); + free(client_ctx); +} + +static void client_timeout(EV_P_ ev_timer *w, int revents) { + client_ctx_t *client_ctx = + (client_ctx_t *)((uint8_t *)w - offsetof(client_ctx_t, timer)); + + ev_io_stop(EV_A_ & client_ctx->io); + + close(client_ctx->fd); + free(client_ctx); +} + +static void server_read(EV_P_ ev_io *w, int revents) { + server_ctx_t *server_ctx = (server_ctx_t *)w; + + uint8_t *msg = malloc(MAX_MESSAGE_SIZE); + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + ssize_t len = recvfrom(server_ctx->fd, msg, MAX_MESSAGE_SIZE, 0, + (struct sockaddr *)&addr, &addr_len); + + if (len < 0) { + perror("Failed to read request"); + goto finish_read; + } else if (len < sizeof(dns_header_t)) { + fprintf(stderr, "Invalid request: Packet too short.\n"); + goto finish_read; + } + + client_ctx_t *client_ctx = malloc(sizeof(client_ctx_t)); + + client_ctx->ip_pool = server_ctx->ip_pool; + + client_ctx->match_root = server_ctx->match_root; + client_ctx->server_fd = server_ctx->fd; + + client_ctx->client_addr_len = addr_len; + memcpy(&client_ctx->client_addr, &addr, addr_len); + + client_ctx->fd = socket(AF_INET, SOCK_DGRAM, 0); + if (client_ctx->fd < 0) { + free(client_ctx); + perror("Failed to create client socket"); + goto finish_read; + } + + ev_io_init(&client_ctx->io, &client_read, client_ctx->fd, EV_READ); + ev_timer_init(&client_ctx->timer, &client_timeout, 5., 0.); + + ev_io_start(EV_A, &client_ctx->io); + + sendto(client_ctx->fd, msg, len, 0, + (struct sockaddr *)&server_ctx->upstream_addr, + sizeof(server_ctx->upstream_addr)); + + ev_timer_start(EV_A, &client_ctx->timer); + +finish_read: + free(msg); +} + +static domain_set_t *domain_set_new(char const *label, size_t len) { + domain_set_t *s = (domain_set_t *)malloc(sizeof(domain_set_t)); + s->next = NULL; + s->head = NULL; + s->label = label; + s->label_len = len; + return s; +} + +static void domain_set_fini(domain_set_t *s) { + domain_set_t *c_next; + for (domain_set_t *c = s->head; c; c = c_next) { + c_next = c->next; + domain_set_fini(c); + } + free(s); +} + +static domain_set_t *parse_domain_option(domain_set_t *root, + char const *domain) { + char const *p = domain; + if (*p == 0 || *p == '.' || *p == '-') { + return NULL; + } + for (p = domain; *p && *p != '.'; p++) { + if (!(isalnum(*p) || *p == '-')) { + return NULL; + } + } + if (*(p - 1) == '-') { + return NULL; + } + size_t label_len = p - domain; + if (label_len > 63) { + return NULL; + } + domain_set_t *parent; + if (*p == '.') { + p++; + parent = parse_domain_option(root, p); + } else { + parent = root; + } + if (!parent) { + return NULL; + } + for (domain_set_t *c = parent->head; c; c = c->next) { + if (c->label_len == label_len && !memcmp(c->label, domain, label_len)) { + return c; + } + } + domain_set_t *s = domain_set_new(domain, label_len); + s->next = parent->head; + parent->head = s; + return s; +} + +static void terminate(EV_P_ ev_signal *w, int revents) { + ev_break(EV_A, EVBREAK_ALL); +} + +#define OPT_UPSTREAM_HOST 0x101 +#define OPT_UPSTREAM_PORT 0x102 + +int main(int argc, char *const *argv) { + domain_set_t *domain_set = domain_set_new("", 0); + ip_pool_t ip_pool; + + struct sockaddr_in listen_addr, upstream_addr; + memset(&listen_addr, 0, sizeof(struct sockaddr_in)); + memset(&upstream_addr, 0, sizeof(struct sockaddr_in)); + + uint16_t listen_port = 53; + uint16_t upstream_port = 53; + listen_addr.sin_family = AF_INET; + upstream_addr.sin_family = AF_INET; + + const char *options = "h:p:d:x:"; + struct option long_options[] = { + {"host", required_argument, NULL, 'h'}, + {"port", required_argument, NULL, 'p'}, + {"domain", required_argument, NULL, 'd'}, + {"prefix", required_argument, NULL, 'x'}, + {"upstream-host", required_argument, NULL, OPT_UPSTREAM_HOST}, + {"upstream-port", required_argument, NULL, OPT_UPSTREAM_PORT}, + {NULL, 0, NULL, 0}}; + int option_index; + + int o; + opterr = 0; + + int listen_host_set = 0; + int prefix_set = 0; + int upstream_host_set = 0; + + while ((o = getopt_long(argc, argv, options, long_options, &option_index)) != + -1) { + switch (o) { + case 'h': + if (inet_pton(AF_INET, optarg, &listen_addr.sin_addr) != 1) { + goto fail; + } else { + listen_host_set = 1; + } + break; + case 'p': + if (sscanf(optarg, "%hu", &listen_port) == -1) { + goto fail; + } + break; + case 'd': + if (!parse_domain_option(domain_set, optarg)) { + goto fail; + } + break; + case 'x': { + char *sep = strchr(optarg, '/'); + if (!sep) { + goto fail; + } + *sep = 0; + + uint32_t af_addr; + unsigned pf_len; + if (inet_pton(AF_INET, optarg, &af_addr) != 1 || + sscanf(sep + 1, "%u", &pf_len) == -1 || pf_len > 30) { + goto fail; + } + uint32_t pf = ntohl(af_addr); + uint32_t pf_mask = ~((1 << (32 - pf_len)) - 1); + if (pf & ~pf_mask) { + goto fail; + } + ip_pool_init(&ip_pool, pf, pf_mask); + prefix_set = 1; + break; + } + case OPT_UPSTREAM_HOST: + if (inet_pton(AF_INET, optarg, &upstream_addr.sin_addr) < 0) { + goto fail; + } + upstream_host_set = 1; + break; + case OPT_UPSTREAM_PORT: + if (sscanf(optarg, "%hu", &upstream_port) == -1) { + goto fail; + } + break; + fail: + opterr = 1; + fprintf(stderr, "Failed to parse option: %s", optarg); + break; + case '?': + fprintf(stderr, "Unrecognized option: %s\n", optarg); + opterr = 1; + fprintf(stderr, + "Usage: %s -h LISTEN_HOST -p LISTEN_PORT\n" + " -x FAKE_IP_PREFIX [-d DOMAIN]\n" + " --upstream-host UPSTREAM_HOST --upstream-port " + "UPSTREAM_PORT\n", + argv[0]); + break; + } + } + + if (!(listen_host_set && prefix_set && upstream_host_set)) { + fprintf(stderr, + "LISTEN_ADDR, FAKE_IP_PREFIX, UPSTREAM_HOST must be set.\n"); + opterr = 1; + } + + if (opterr) { + domain_set_fini(domain_set); + return 1; + } + + listen_addr.sin_port = htons(listen_port); + upstream_addr.sin_port = htons(upstream_port); + + nft_ctx = nft_ctx_new(NFT_CTX_DEFAULT); + if (!nft_ctx) { + perror("Failed to create nftables context."); + } + nft_ctx_output_set_flags(nft_ctx, + NFT_CTX_OUTPUT_HANDLE | NFT_CTX_OUTPUT_ECHO); + + char *cmd = malloc_sprintf("add table ip nat"); + nft_run_cmd_from_buffer(nft_ctx, cmd); + free(cmd); + + cmd = malloc_sprintf("add chain ip nat prerouting" + "{ type nat hook prerouting priority dstnat; }"); + nft_run_cmd_from_buffer(nft_ctx, cmd); + free(cmd); + + cmd = malloc_sprintf("add chain ip nat postrouting" + "{ type nat hook postrouting priority srcnat; }"); + nft_run_cmd_from_buffer(nft_ctx, cmd); + free(cmd); + + struct ev_loop *loop = ev_default_loop(0); + + int listen_fd = socket(AF_INET, SOCK_DGRAM, 0); + + if (bind(listen_fd, (struct sockaddr *)&listen_addr, sizeof(listen_addr)) == + 0) { + server_ctx_t server_ctx; + + server_ctx.upstream_addr = upstream_addr; + + server_ctx.match_root = domain_set; + server_ctx.ip_pool = &ip_pool; + server_ctx.fd = listen_fd; + + ev_io_init(&server_ctx.io, &server_read, listen_fd, EV_READ); + + ev_signal sig_int, sig_term; + ev_signal_init(&sig_int, &terminate, SIGINT); + ev_signal_init(&sig_term, &terminate, SIGTERM); + + ev_signal_start(loop, &sig_int); + ev_signal_start(loop, &sig_term); + + ev_io_start(loop, &server_ctx.io); + + ev_loop(loop, 0); + } else { + perror("Failed to bind"); + } + + close(listen_fd); + + ip_pool_fini(&ip_pool); + + cmd = malloc_sprintf("delete table ip nat"); + nft_run_cmd_from_buffer(nft_ctx, cmd); + free(cmd); + + nft_ctx_free(nft_ctx); + + domain_set_fini(domain_set); + + return 0; +}