From f37b1f0bb1fde1daf4c6c466d6462e1bb223c4d6 Mon Sep 17 00:00:00 2001 From: matt Date: Fri, 31 Oct 2025 08:31:07 -0600 Subject: [PATCH] Update dnsf.py multi threaded checking of rules. --- dnsf.py | 192 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 116 insertions(+), 76 deletions(-) diff --git a/dnsf.py b/dnsf.py index 8be89ca..37fe55b 100644 --- a/dnsf.py +++ b/dnsf.py @@ -2,14 +2,15 @@ from bcc import BPF import dnslib -import pprint import fcntl import os import sys,time,re from datetime import datetime, timedelta -from multiprocessing import Pool +import select +import socket def run(x): + print(x) import os if type(x) == list: for c in x: @@ -17,22 +18,14 @@ def run(x): else: os.system(x) -def threaded(commands): - with Pool(8) as p: - p.map(run, commands) - class watch_file(): def __init__(self,fname): - import os print('starting file watcher.') self.filename = os.path.abspath(fname) self.cstamp = os.stat(self.filename).st_mtime-1 def check_change(self): self.contents = None - import os - import sys - import psutil stamp = os.stat(self.filename).st_mtime if stamp != self.cstamp: print('file change detected. loading.') @@ -45,6 +38,24 @@ class watch_file(): else: return False +#systemd-resolved + +resolved = """ +[Resolve] +DNS=9.9.9.11#dns11.quad9.net +Domains=~. +DNSSEC=true +DNSOverTLS=yes +MulticastDNS=no +LLMNR=no +Cache=yes +DNSStubListener=yes +""" + +with open("/etc/systemd/resolved.conf", "w") as file: + file.write(resolved) +print('wrote systemd-resolved config.') +run('systemctl restart systemd-resolved') BPF_APP = r''' #include @@ -53,7 +64,6 @@ BPF_APP = r''' int dns_matching(struct __sk_buff *skb) { u8 *cursor = 0; - // Checking the IP protocol: struct ethernet_t *ethernet = cursor_advance(cursor, sizeof(*ethernet)); @@ -75,24 +85,32 @@ int dns_matching(struct __sk_buff *skb) { ''' default_rules = ''' +iptables -F iptables -X +ipset flush ipset destroy whitelist_hosts ipset destroy static_hosts ipset create whitelist_hosts hash:ip timeout 3600 ipset create static_hosts hash:ip iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT -iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT +iptables -A INPUT ! -i lo -d 127.0.0.0/8 -j REJECT -m comment --comment "Drop all traffic to 127 that doesn't use lo" iptables -A INPUT -s 127.0.0.0/8 -j ACCEPT +iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT iptables -A INPUT -i lo -j ACCEPT +iptables -A OUTPUT -p tcp --match multiport --dports 30500:30600 -j ACCEPT iptables -A OUTPUT -m set --match-set whitelist_hosts dst -j ACCEPT iptables -A OUTPUT -m set --match-set static_hosts dst -j ACCEPT iptables -A OUTPUT -j REJECT iptables -P INPUT DROP iptables -P FORWARD DROP iptables -P OUTPUT DROP -systemd-resolve --flush-caches +mkdir -p /etc/iptables/ +iptables-save > /etc/iptables/rules.v4 ''' +#iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT +#iptables -A INPUT -s 127.0.0.0/8 -j ACCEPT + for c in default_rules.splitlines(False): print(c.strip()) os.system(c.strip()) @@ -107,25 +125,43 @@ fcntl.fcntl(socket_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK) rules = {} cnames = {} +current_cnames = {} blocked = [] -def check_hostname(hostname,current_cnames,allow_list): - for wl in allow_list: - match = re.fullmatch(wl, hostname) - if match: - print(' hostname: '+str(hostname) + ' in whitelist') - return True - - for c in current_cnames.keys(): - if hostname == current_cnames[c]["cname"] and current_cnames[c]["exp"] > datetime.now(): - print(' hostname: '+str(hostname) + ' in cname') - return True - - print(' hostname: '+str(hostname) + ' not permitted.') - blocked.append(hostname) - return False -import socket + +def threaded(threadedfunc, inputd, threads=4): + inputd = [x if isinstance(x,tuple) else tuple([x]) for x in inputd] + import os + from multiprocessing.pool import ThreadPool as Pool + if threads == None: + threads = os.sched_getaffinity(0) + with Pool(threads) as p: + result = p.starmap(threadedfunc, inputd) + return result + +def regex_match(wl,hostname): + import re + match = re.fullmatch(wl, hostname) + if match: + print(' hostname: '+str(hostname) + ' in whitelist.') + return True + +def cname_match(c,hostname): + import datetime + if c in current_cnames: + if hostname == current_cnames[c]["cname"] and current_cnames[c]["exp"] > datetime.now(): + print(' hostname: '+str(hostname) + ' in cname.') + return True + +def check_hostname(hostname,current_cnames,allow_list): + import itertools + if any(threaded(threadedfunc=regex_match, inputd=zip(allow_list,itertools.repeat(hostname)))): + return True + if any(threaded(threadedfunc=cname_match, inputd=zip(current_cnames.keys(),itertools.repeat(hostname)))): + return True + print(' hostname: '+str(hostname) + ' not permitted.') + return False def valid_ip(address): try: @@ -138,66 +174,70 @@ dns_list = watch_file('dns_whitelist.conf') host_list = watch_file('host_whitelist.conf') while True: - time.sleep(.05) + time.sleep(.01) commands = [] - + #dnsl = dns_list.contents dns_list.check_change() if dns_list.contents: + + #diff = set(dnsl) - set(dns_list.contents) dns_regex = dns_list.contents - for block in blocked: - add_rule = 'nslookup '+str(block) - commands.append(add_rule) - threaded(commands) + #for block in dns_list.contents: + # add_rule = 'nslookup '+str(block) + # commands.append(add_rule) + #threaded(commands) blocked = [] commands = [] host_list.check_change() if host_list.contents: - threaded(['ipset flush static_hosts']) + threaded(run,['ipset flush static_hosts']) hosts = [x for x in host_list.contents if valid_ip(x)] for host in hosts: commands.append('ipset add static_hosts '+host) print(host) - threaded(commands) - import select - r, w, e = select.select([ socket_fd ], [], [], 0) - if socket_fd in r: - try: - packet_str = os.read(socket_fd, 2048) - except KeyboardInterrupt: - sys.exit(0) - - packet_bytearray = bytearray(packet_str) - - ETH_HLEN = 14 - UDP_HLEN = 8 - - # IP header length - ip_header_length = packet_bytearray[ETH_HLEN] - ip_header_length = ip_header_length & 0x0F - ip_header_length = ip_header_length << 2 - - # Starting the DNS packet - payload_offset = ETH_HLEN + ip_header_length + UDP_HLEN - payload = packet_bytearray[payload_offset:] - dnsrec = dnslib.DNSRecord.parse(payload) - if dnsrec.rr: - if True: - for i in range(0, len(dnsrec.rr)): - if str(dnsrec.rr[i].rtype) in ['5']: - cnames[str(dnsrec.questions[0].qname)] = {"cname":str(dnsrec.rr[i].rdata),"exp":datetime.now() + timedelta(seconds=dnsrec.rr[i].ttl)} - - if str(dnsrec.rr[i].rtype) in ['1']: - #rtype 28 is v6 - #print(str(dnsrec.rr[i].rtype)) - if check_hostname(hostname=str(dnsrec.questions[0].qname),current_cnames=cnames,allow_list=dns_regex): - add_rule = 'ipset -exist add whitelist_hosts '+str(dnsrec.rr[i].rdata)+' timeout '+str(int(dnsrec.rr[i].ttl)+360) - print(add_rule) - commands.append(add_rule) - threaded(commands) + threaded(run,commands) + try: + r, w, e = select.select([ socket_fd ], [], [], 0) + if socket_fd in r: + try: + packet_str = os.read(socket_fd, 2048) + except KeyboardInterrupt: + sys.exit(0) + + packet_bytearray = bytearray(packet_str) + + ETH_HLEN = 14 + UDP_HLEN = 8 + + # IP header length + ip_header_length = packet_bytearray[ETH_HLEN] + ip_header_length = ip_header_length & 0x0F + ip_header_length = ip_header_length << 2 + + # Starting the DNS packet + payload_offset = ETH_HLEN + ip_header_length + UDP_HLEN + payload = packet_bytearray[payload_offset:] + dnsrec = dnslib.DNSRecord.parse(payload) + if dnsrec.rr: + if True: + for i in range(0, len(dnsrec.rr)): + if str(dnsrec.rr[i].rtype) in ['5']: + cnames[str(dnsrec.questions[0].qname)] = {"cname":str(dnsrec.rr[i].rdata),"exp":datetime.now() + timedelta(seconds=dnsrec.rr[i].ttl+ 3600)} + + if str(dnsrec.rr[i].rtype) in ['1']: + #rtype 28 is v6 + #print(str(dnsrec.rr[i].rtype)) + if check_hostname(hostname=str(dnsrec.questions[0].qname),current_cnames=cnames,allow_list=dns_regex): + add_rule = 'ipset -exist add whitelist_hosts '+str(dnsrec.rr[i].rdata)+' timeout '+str(int(dnsrec.rr[i].ttl)+3600) + print(add_rule) + commands.append(add_rule) + threaded(run,commands) + else: + pass else: pass - else: - pass - + except Exception as e: + print(e) + \ No newline at end of file