diff --git a/dnsf.py b/dnsf.py new file mode 100644 index 0000000..63d0f95 --- /dev/null +++ b/dnsf.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 + +from bcc import BPF +import dnslib +import pprint +import fcntl +import os +import sys,time,re +from datetime import datetime, timedelta +from multiprocessing import Pool + +def run(x): + import os + if type(x) == list: + for c in x: + os.system(c) + else: + os.system(x) + +def threaded(commands): + with Pool(8) as p: + p.map(run, commands) + +class watch_file(): + def __init__(self,fname): + import os + print('starting file watcher.') + self.filename = os.path.abspath(fname) + self.cstamp = os.stat(self.filename).st_mtime-1 + + def check_change(self): + self.contents = None + import os + import sys + import psutil + stamp = os.stat(self.filename).st_mtime + if stamp != self.cstamp: + print('file change detected. loading.') + self.cstamp = stamp + with open(self.filename,'r') as f: + lines = f.readlines() + lines = [x.strip() for x in lines] + self.contents = lines + return lines + else: + return False + + +BPF_APP = r''' +#include +#include +#include + +int dns_matching(struct __sk_buff *skb) { + u8 *cursor = 0; + + // Checking the IP protocol: + struct ethernet_t *ethernet = cursor_advance(cursor, sizeof(*ethernet)); + + if (ethernet->type == ETH_P_IP) { + // Checking the UDP protocol: + struct ip_t *ip = cursor_advance(cursor, sizeof(*ip)); + + if (ip->nextp == IPPROTO_UDP) { + // Check the port 53: + struct udp_t *udp = cursor_advance(cursor, sizeof(*udp)); + + if (udp->dport == 53 || udp->sport == 53) { + return -1; + } + } + } + return 0; +} +''' + +default_rules = ''' +iptables -F +ipset destroy whitelist_hosts +ipset destroy static_hosts +ipset create whitelist_hosts hash:ip timeout 3600 +ipset create static_hosts hash:ip +iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT +iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT +iptables -A INPUT -s 127.0.0.0/8 -j ACCEPT +iptables -A INPUT -i lo -j ACCEPT +iptables -A OUTPUT -m set --match-set whitelist_hosts dst -j ACCEPT +iptables -A OUTPUT -m set --match-set static_hosts dst -j ACCEPT +iptables -A OUTPUT -j REJECT +iptables -P INPUT DROP +iptables -P FORWARD DROP +iptables -P OUTPUT DROP +''' + +for c in default_rules.splitlines(False): + print(c.strip()) + os.system(c.strip()) + +bpf = BPF(text=BPF_APP) +function_dns_matching = bpf.load_func("dns_matching", BPF.SOCKET_FILTER) +BPF.attach_raw_socket(function_dns_matching, '') + +socket_fd = function_dns_matching.sock +fl = fcntl.fcntl(socket_fd, fcntl.F_GETFL) +fcntl.fcntl(socket_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK) + +rules = {} +cnames = {} + + +def check_hostname(hostname,current_cnames,allow_list): + import re + for wl in allow_list: + match = re.fullmatch(wl, hostname) + if match: + print(' hostname: '+str(hostname) + ' in whitelist') + return True + + for c in current_cnames.keys(): + if hostname == current_cnames[c]["cname"] and current_cnames[c]["exp"] > datetime.now(): + print(' hostname: '+str(hostname) + ' in cname') + return True + print(' hostname: '+str(hostname) + ' not permitted.') + return False + + + +import socket + +def valid_ip(address): + try: + socket.inet_aton(address) + return True + except: + return False + +dns_list = watch_file('dns_whitelist.conf') +host_list = watch_file('host_whitelist.conf') + +import time +while True: + time.sleep(.05) + commands = [] + + dns_list.check_change() + if dns_list.contents: + dns_regex = dns_list.contents + + host_list.check_change() + if host_list.contents: + threaded(['ipset flush static_hosts']) + hosts = [x for x in host_list.contents if valid_ip(x)] + for host in hosts: + commands.append('ipset add static_hosts '+host) + print(host) + threaded(commands) + import select + r, w, e = select.select([ socket_fd ], [], [], 0) + if socket_fd in r: + try: + packet_str = os.read(socket_fd, 2048) + except KeyboardInterrupt: + sys.exit(0) + + packet_bytearray = bytearray(packet_str) + + ETH_HLEN = 14 + UDP_HLEN = 8 + + # IP header length + ip_header_length = packet_bytearray[ETH_HLEN] + ip_header_length = ip_header_length & 0x0F + ip_header_length = ip_header_length << 2 + + # Starting the DNS packet + payload_offset = ETH_HLEN + ip_header_length + UDP_HLEN + payload = packet_bytearray[payload_offset:] + dnsrec = dnslib.DNSRecord.parse(payload) + + if dnsrec.rr: + if check_hostname(hostname=str(dnsrec.questions[0].qname),current_cnames=cnames,allow_list=dns_regex): + for i in range(0, len(dnsrec.rr)): + if str(dnsrec.rr[i].rtype) in ['5']: + cnames[str(dnsrec.questions[0].qname)] = {"cname":str(dnsrec.rr[i].rdata),"exp":datetime.now() + timedelta(seconds=dnsrec.rr[i].ttl)} + + if str(dnsrec.rr[i].rtype) in ['1']: + #rtype 28 is v6 + #print(str(dnsrec.rr[i].rtype)) + add_rule = 'ipset -exist add whitelist_hosts '+str(dnsrec.rr[i].rdata)+' timeout '+str(int(dnsrec.rr[i].ttl)+360) + del_rule = 'ipset del whitelist_hosts '+str(dnsrec.rr[i].rdata) + print(add_rule) + commands.append(add_rule) + + threaded(commands) + else: + pass + else: + pass +