#!/usr/bin/env python3
"""Egress firewall that whitelists IPs learned from permitted DNS answers.

The script rewrites /etc/systemd/resolved.conf, installs a default-deny
iptables/ipset policy, then sniffs UDP port 53 traffic through a BPF socket
filter.  For every DNS response whose question name matches a regex in
dns_whitelist.conf (or is a recorded CNAME target of a permitted name), the
A-record IPs are added to the ``whitelist_hosts`` ipset.  Literal IPv4
addresses listed in host_whitelist.conf are loaded into ``static_hosts``.
Both files are re-read whenever their mtime changes.

It writes under /etc and runs iptables/ipset, so it presumably needs root.
"""
import fcntl
import ipaddress
import os
import re
import select
import socket
import sys
import time
from datetime import datetime, timedelta

import dnslib
from bcc import BPF


def run(x):
    """Print *x*, then execute it with the shell.

    *x* is either a single command string or a list of command strings,
    which are executed in order.  Exit statuses are ignored.
    """
    print(x)
    commands = x if isinstance(x, list) else [x]
    for c in commands:
        os.system(c)


class watch_file():
    """Poll a text file's mtime and reload its stripped lines when it changes."""

    def __init__(self, fname):
        print('starting file watcher.')
        self.filename = os.path.abspath(fname)
        # Lines loaded by the most recent check_change() that saw a change,
        # otherwise None.
        self.contents = None
        # Start one second "behind" so the first check_change() loads the file.
        self.cstamp = os.stat(self.filename).st_mtime - 1

    def check_change(self):
        """Reload the file if its mtime changed since the last load.

        Returns the list of stripped lines (possibly empty) when the file was
        reloaded, or False when it is unchanged or currently unreadable.  The
        same list is stored on ``self.contents`` (None when nothing was loaded).
        """
        self.contents = None
        try:
            stamp = os.stat(self.filename).st_mtime
            if stamp == self.cstamp:
                return False
            print('file change detected. loading.')
            with open(self.filename, 'r') as f:
                lines = [x.strip() for x in f]
        except OSError as err:
            # The file can vanish briefly (e.g. an editor's atomic save); do not
            # crash the daemon and leave the host in a DROP-everything state.
            # cstamp is left untouched so the load is retried on the next poll.
            print('cannot read ' + self.filename + ': ' + str(err))
            return False
        self.cstamp = stamp
        self.contents = lines
        return lines


dns_resolver = "DNS=9.9.9.11#dns11.quad9.net"

# systemd-resolved configuration.
# NOTE(review): with DNSOverTLS=yes, resolved talks TCP/853 to the resolver
# above; the OUTPUT chain below only lets that through if the resolver IP is
# in an allowed set (presumably via host_whitelist.conf) -- confirm.
resolved = """
[Resolve]
Domains=~.
DNSSEC=true
DNSOverTLS=yes
MulticastDNS=no
LLMNR=no
Cache=yes
DNSStubListener=yes
"""

with open("/etc/systemd/resolved.conf", "w") as file:
    file.write(resolved + dns_resolver)
print('wrote systemd-resolved config.')
run('systemctl restart systemd-resolved')

# Socket filter: keep the whole packet (-1) for IPv4 UDP traffic to or from
# port 53 and drop (0) everything else.  bcc/proto.h provides cursor_advance()
# and the ethernet_t/ip_t/udp_t views; net/sock.h brings in ETH_P_IP and
# IPPROTO_UDP.
BPF_APP = r'''
#include <uapi/linux/ptrace.h>
#include <net/sock.h>
#include <bcc/proto.h>

int dns_matching(struct __sk_buff *skb)
{
    u8 *cursor = 0;
    struct ethernet_t *ethernet = cursor_advance(cursor, sizeof(*ethernet));
    if (ethernet->type == ETH_P_IP) {
        // Checking that proto is UDP:
        struct ip_t *ip = cursor_advance(cursor, sizeof(*ip));
        if (ip->nextp == IPPROTO_UDP) {
            // Check if the port is 53:
            struct udp_t *udp = cursor_advance(cursor, sizeof(*udp));
            if (udp->dport == 53 || udp->sport == 53) {
                return -1;
            }
        }
    }
    return 0;
}
'''

# Default-deny policy.  Outbound traffic is only allowed to loopback,
# TCP 30500-30600, and the two ipsets maintained by this script.
# NOTE(review): only IPv4 (iptables) rules are installed and AAAA answers are
# ignored below, so IPv6 egress is not restricted here -- confirm it is
# handled elsewhere.
default_rules = '''
iptables -F
iptables -X
ipset flush
ipset destroy whitelist_hosts
ipset destroy static_hosts
ipset create whitelist_hosts hash:ip timeout 3600
ipset create static_hosts hash:ip
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
iptables -A INPUT ! -i lo -d 127.0.0.0/8 -j REJECT -m comment --comment "Drop all traffic to 127 that doesn't use lo"
iptables -A INPUT -s 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A INPUT -i lo -j ACCEPT
iptables -A OUTPUT -p tcp --match multiport --dports 30500:30600 -j ACCEPT
iptables -A OUTPUT -m set --match-set whitelist_hosts dst -j ACCEPT
iptables -A OUTPUT -m set --match-set static_hosts dst -j ACCEPT
iptables -A OUTPUT -j REJECT
iptables -P INPUT DROP
iptables -P FORWARD DROP
iptables -P OUTPUT DROP
'''

for rule in default_rules.splitlines():
    rule = rule.strip()
    if rule:
        run(rule)

bpf = BPF(text=BPF_APP)
function_dns_matching = bpf.load_func("dns_matching", BPF.SOCKET_FILTER)
BPF.attach_raw_socket(function_dns_matching, '')
socket_fd = function_dns_matching.sock
# Make reads blocking; readiness is checked with select() before each read.
fl = fcntl.fcntl(socket_fd, fcntl.F_GETFL)
fcntl.fcntl(socket_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK)

rules = {}
# Permitted CNAMEs: owner name -> {"cname": target name, "exp": datetime}.
cnames = {}
current_cnames = {}
blocked = []

ETH_HLEN = 14
UDP_HLEN = 8
RTYPE_A = 1      # IPv4 address record (AAAA, rtype 28, is not handled)
RTYPE_CNAME = 5  # canonical-name alias record


def threaded(threadedfunc, inputd, threads=4):
    """Apply *threadedfunc* to every item of *inputd* in a thread pool.

    Items that are tuples are unpacked as positional arguments; any other item
    is passed as the single argument.  Results are returned in input order.
    ``threads=None`` sizes the pool to the CPUs this process may run on.
    """
    from multiprocessing.pool import ThreadPool as Pool
    args = [x if isinstance(x, tuple) else (x,) for x in inputd]
    if not args:
        return []
    if threads is None:
        # sched_getaffinity returns a *set* of CPU ids; Pool needs a count.
        threads = len(os.sched_getaffinity(0))
    with Pool(threads) as p:
        return p.starmap(threadedfunc, args)


def regex_match(wl, hostname):
    """Return True if *hostname* fully matches the whitelist regex *wl*."""
    if re.fullmatch(wl, hostname):
        print(' hostname: ' + str(hostname) + ' in whitelist.')
        return True
    return False


def cname_match(c, hostname, cname_table=None):
    """Return True if *hostname* is the unexpired CNAME target recorded for *c*.

    *cname_table* defaults to the module-level ``current_cnames`` for
    backward compatibility; callers should pass the table they maintain.
    """
    table = current_cnames if cname_table is None else cname_table
    entry = table.get(c)
    if entry and hostname == entry["cname"] and entry["exp"] > datetime.now():
        print(' hostname: ' + str(hostname) + ' in cname.')
        return True
    return False


def check_hostname(hostname, current_cnames, allow_list):
    """Return True if *hostname* is permitted.

    A name is permitted if it fully matches any regex in *allow_list*, or is
    the live CNAME target of a name recorded in *current_cnames*.
    """
    # Plain short-circuiting loops: the checks are CPU-bound, so a thread pool
    # per packet only added overhead.
    if any(regex_match(wl, hostname) for wl in allow_list):
        return True
    if any(cname_match(c, hostname, current_cnames) for c in list(current_cnames)):
        return True
    print(' hostname: ' + str(hostname) + ' not permitted.')
    return False


def valid_ip(address):
    """Return True if *address* is a well-formed dotted-quad IPv4 address.

    Stricter than socket.inet_aton, which also accepts shorthand such as
    "127.1" and, on glibc, trailing text after whitespace.  That matters
    because the value is interpolated into a shell command.
    """
    try:
        ipaddress.IPv4Address(address)
    except ValueError:
        return False
    return True


def _load_allow_list(lines):
    """Return the non-blank *lines* that compile as regexes, warning on the rest."""
    # One bad pattern must not make every hostname check raise.
    patterns = []
    for line in lines:
        if not line:
            continue
        try:
            re.compile(line)
        except re.error as err:
            print(' ignoring invalid whitelist pattern ' + repr(line) + ': ' + str(err))
            continue
        patterns.append(line)
    return patterns


def _reload_static_hosts(lines):
    """Replace the ``static_hosts`` ipset with the valid IPv4 addresses in *lines*."""
    run('ipset flush static_hosts')
    commands = []
    for host in lines:
        if valid_ip(host):
            commands.append('ipset add static_hosts ' + host)
            print(host)
    threaded(run, commands)


def _dns_payload(packet):
    """Return the UDP payload of a raw Ethernet + IPv4 + UDP frame."""
    # IHL is the low nibble of the first IPv4 byte, counted in 32-bit words.
    ip_header_length = (packet[ETH_HLEN] & 0x0F) << 2
    return packet[ETH_HLEN + ip_header_length + UDP_HLEN:]


def _prune_expired(cname_table):
    """Drop CNAME entries whose expiry time has passed."""
    now = datetime.now()
    for key in [k for k, v in cname_table.items() if v["exp"] <= now]:
        del cname_table[key]


def _handle_dns_response(dnsrec, cname_table, allow_list):
    """Record permitted CNAMEs and whitelist the A-record IPs of permitted names."""
    if not dnsrec.rr or not dnsrec.questions:
        return
    qname = str(dnsrec.questions[0].qname)
    qname_permitted = None  # evaluated lazily, once per response
    commands = []
    for rr in dnsrec.rr:
        if rr.rtype == RTYPE_CNAME:
            owner = str(rr.rname)
            # Only remember aliases of names that are themselves permitted.
            # Otherwise any CNAME target seen on the wire would become allowed.
            if check_hostname(hostname=owner, current_cnames=cname_table, allow_list=allow_list):
                _prune_expired(cname_table)
                cname_table[owner] = {
                    "cname": str(rr.rdata),
                    "exp": datetime.now() + timedelta(seconds=rr.ttl + 3600),
                }
        elif rr.rtype == RTYPE_A:
            if qname_permitted is None:
                qname_permitted = check_hostname(
                    hostname=qname, current_cnames=cname_table, allow_list=allow_list)
            ip = str(rr.rdata)
            if qname_permitted and valid_ip(ip):
                add_rule = ('ipset -exist add whitelist_hosts ' + ip
                            + ' timeout ' + str(int(rr.ttl) + 3600))
                print(add_rule)
                commands.append(add_rule)
    # Run each command exactly once (previously the list was re-run per record).
    if commands:
        threaded(run, commands)


dns_list = watch_file('dns_whitelist.conf')
host_list = watch_file('host_whitelist.conf')
dns_regex = []  # allow-list patterns; empty until dns_whitelist.conf loads

while True:
    # check_change() returns False when unchanged; an empty list is a real
    # (emptied) file and must clear the previous entries.
    dns_lines = dns_list.check_change()
    if dns_lines is not False:
        dns_regex = _load_allow_list(dns_lines)

    host_lines = host_list.check_change()
    if host_lines is not False:
        _reload_static_hosts(host_lines)

    try:
        # Waiting in select() keeps the ~10 ms idle poll of the config files,
        # but a burst of DNS packets is handled without sleeping between them.
        readable, _, _ = select.select([socket_fd], [], [], 0.01)
        if socket_fd not in readable:
            continue
        try:
            packet = os.read(socket_fd, 2048)
        except KeyboardInterrupt:
            sys.exit(0)
        dnsrec = dnslib.DNSRecord.parse(_dns_payload(packet))
        _handle_dns_response(dnsrec, cnames, dns_regex)
    except Exception as e:
        # Best effort per packet: a malformed or non-DNS payload must not stop
        # the daemon (which would leave the DROP policy with no whitelisting).
        print(e)