Update dnsf.py

multi threaded checking of rules.
This commit is contained in:
2025-10-31 08:31:07 -06:00
parent 3f38868381
commit f37b1f0bb1

170
dnsf.py
View File

@@ -2,14 +2,15 @@
from bcc import BPF
import dnslib
import pprint
import fcntl
import os
import sys,time,re
from datetime import datetime, timedelta
from multiprocessing import Pool
import select
import socket
def run(x):
print(x)
import os
if type(x) == list:
for c in x:
@@ -17,22 +18,14 @@ def run(x):
else:
os.system(x)
def threaded(commands):
with Pool(8) as p:
p.map(run, commands)
class watch_file():
def __init__(self,fname):
import os
print('starting file watcher.')
self.filename = os.path.abspath(fname)
self.cstamp = os.stat(self.filename).st_mtime-1
def check_change(self):
self.contents = None
import os
import sys
import psutil
stamp = os.stat(self.filename).st_mtime
if stamp != self.cstamp:
print('file change detected. loading.')
@@ -45,6 +38,24 @@ class watch_file():
else:
return False
#systemd-resolved
resolved = """
[Resolve]
DNS=9.9.9.11#dns11.quad9.net
Domains=~.
DNSSEC=true
DNSOverTLS=yes
MulticastDNS=no
LLMNR=no
Cache=yes
DNSStubListener=yes
"""
with open("/etc/systemd/resolved.conf", "w") as file:
file.write(resolved)
print('wrote systemd-resolved config.')
run('systemctl restart systemd-resolved')
BPF_APP = r'''
#include <linux/if_ether.h>
@@ -53,7 +64,6 @@ BPF_APP = r'''
int dns_matching(struct __sk_buff *skb) {
u8 *cursor = 0;
// Checking the IP protocol:
struct ethernet_t *ethernet = cursor_advance(cursor, sizeof(*ethernet));
@@ -75,24 +85,32 @@ int dns_matching(struct __sk_buff *skb) {
'''
default_rules = '''
iptables -F
iptables -X
ipset flush
ipset destroy whitelist_hosts
ipset destroy static_hosts
ipset create whitelist_hosts hash:ip timeout 3600
ipset create static_hosts hash:ip
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A INPUT ! -i lo -d 127.0.0.0/8 -j REJECT -m comment --comment "Drop all traffic to 127 that doesn't use lo"
iptables -A INPUT -s 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A INPUT -i lo -j ACCEPT
iptables -A OUTPUT -p tcp --match multiport --dports 30500:30600 -j ACCEPT
iptables -A OUTPUT -m set --match-set whitelist_hosts dst -j ACCEPT
iptables -A OUTPUT -m set --match-set static_hosts dst -j ACCEPT
iptables -A OUTPUT -j REJECT
iptables -P INPUT DROP
iptables -P FORWARD DROP
iptables -P OUTPUT DROP
systemd-resolve --flush-caches
mkdir -p /etc/iptables/
iptables-save > /etc/iptables/rules.v4
'''
#iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
#iptables -A INPUT -s 127.0.0.0/8 -j ACCEPT
for c in default_rules.splitlines(False):
print(c.strip())
os.system(c.strip())
@@ -107,26 +125,44 @@ fcntl.fcntl(socket_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
rules = {}
cnames = {}
current_cnames = {}
blocked = []
def check_hostname(hostname,current_cnames,allow_list):
for wl in allow_list:
match = re.fullmatch(wl, hostname)
if match:
print('<check> hostname: '+str(hostname) + ' in whitelist')
return True
for c in current_cnames.keys():
def threaded(threadedfunc, inputd, threads=4):
inputd = [x if isinstance(x,tuple) else tuple([x]) for x in inputd]
import os
from multiprocessing.pool import ThreadPool as Pool
if threads == None:
threads = os.sched_getaffinity(0)
with Pool(threads) as p:
result = p.starmap(threadedfunc, inputd)
return result
def regex_match(wl,hostname):
import re
match = re.fullmatch(wl, hostname)
if match:
print('<check> hostname: '+str(hostname) + ' in whitelist.')
return True
def cname_match(c,hostname):
import datetime
if c in current_cnames:
if hostname == current_cnames[c]["cname"] and current_cnames[c]["exp"] > datetime.now():
print('<check> hostname: '+str(hostname) + ' in cname')
print('<check> hostname: '+str(hostname) + ' in cname.')
return True
def check_hostname(hostname,current_cnames,allow_list):
import itertools
if any(threaded(threadedfunc=regex_match, inputd=zip(allow_list,itertools.repeat(hostname)))):
return True
if any(threaded(threadedfunc=cname_match, inputd=zip(current_cnames.keys(),itertools.repeat(hostname)))):
return True
print('<check> hostname: '+str(hostname) + ' not permitted.')
blocked.append(hostname)
return False
import socket
def valid_ip(address):
try:
socket.inet_aton(address)
@@ -138,66 +174,70 @@ dns_list = watch_file('dns_whitelist.conf')
host_list = watch_file('host_whitelist.conf')
while True:
time.sleep(.05)
time.sleep(.01)
commands = []
#dnsl = dns_list.contents
dns_list.check_change()
if dns_list.contents:
#diff = set(dnsl) - set(dns_list.contents)
dns_regex = dns_list.contents
for block in blocked:
add_rule = 'nslookup '+str(block)
commands.append(add_rule)
threaded(commands)
#for block in dns_list.contents:
# add_rule = 'nslookup '+str(block)
# commands.append(add_rule)
#threaded(commands)
blocked = []
commands = []
host_list.check_change()
if host_list.contents:
threaded(['ipset flush static_hosts'])
threaded(run,['ipset flush static_hosts'])
hosts = [x for x in host_list.contents if valid_ip(x)]
for host in hosts:
commands.append('ipset add static_hosts '+host)
print(host)
threaded(commands)
import select
r, w, e = select.select([ socket_fd ], [], [], 0)
if socket_fd in r:
try:
packet_str = os.read(socket_fd, 2048)
except KeyboardInterrupt:
sys.exit(0)
threaded(run,commands)
try:
r, w, e = select.select([ socket_fd ], [], [], 0)
if socket_fd in r:
try:
packet_str = os.read(socket_fd, 2048)
except KeyboardInterrupt:
sys.exit(0)
packet_bytearray = bytearray(packet_str)
packet_bytearray = bytearray(packet_str)
ETH_HLEN = 14
UDP_HLEN = 8
ETH_HLEN = 14
UDP_HLEN = 8
# IP header length
ip_header_length = packet_bytearray[ETH_HLEN]
ip_header_length = ip_header_length & 0x0F
ip_header_length = ip_header_length << 2
# IP header length
ip_header_length = packet_bytearray[ETH_HLEN]
ip_header_length = ip_header_length & 0x0F
ip_header_length = ip_header_length << 2
# Starting the DNS packet
payload_offset = ETH_HLEN + ip_header_length + UDP_HLEN
payload = packet_bytearray[payload_offset:]
dnsrec = dnslib.DNSRecord.parse(payload)
if dnsrec.rr:
if True:
for i in range(0, len(dnsrec.rr)):
if str(dnsrec.rr[i].rtype) in ['5']:
cnames[str(dnsrec.questions[0].qname)] = {"cname":str(dnsrec.rr[i].rdata),"exp":datetime.now() + timedelta(seconds=dnsrec.rr[i].ttl)}
# Starting the DNS packet
payload_offset = ETH_HLEN + ip_header_length + UDP_HLEN
payload = packet_bytearray[payload_offset:]
dnsrec = dnslib.DNSRecord.parse(payload)
if dnsrec.rr:
if True:
for i in range(0, len(dnsrec.rr)):
if str(dnsrec.rr[i].rtype) in ['5']:
cnames[str(dnsrec.questions[0].qname)] = {"cname":str(dnsrec.rr[i].rdata),"exp":datetime.now() + timedelta(seconds=dnsrec.rr[i].ttl+ 3600)}
if str(dnsrec.rr[i].rtype) in ['1']:
#rtype 28 is v6
#print(str(dnsrec.rr[i].rtype))
if check_hostname(hostname=str(dnsrec.questions[0].qname),current_cnames=cnames,allow_list=dns_regex):
add_rule = 'ipset -exist add whitelist_hosts '+str(dnsrec.rr[i].rdata)+' timeout '+str(int(dnsrec.rr[i].ttl)+360)
print(add_rule)
commands.append(add_rule)
threaded(commands)
if str(dnsrec.rr[i].rtype) in ['1']:
#rtype 28 is v6
#print(str(dnsrec.rr[i].rtype))
if check_hostname(hostname=str(dnsrec.questions[0].qname),current_cnames=cnames,allow_list=dns_regex):
add_rule = 'ipset -exist add whitelist_hosts '+str(dnsrec.rr[i].rdata)+' timeout '+str(int(dnsrec.rr[i].ttl)+3600)
print(add_rule)
commands.append(add_rule)
threaded(run,commands)
else:
pass
else:
pass
else:
pass
except Exception as e:
print(e)