Update dnsf.py

multi threaded checking of rules.
This commit is contained in:
2025-10-31 08:31:07 -06:00
parent 3f38868381
commit f37b1f0bb1

170
dnsf.py
View File

@@ -2,14 +2,15 @@
from bcc import BPF from bcc import BPF
import dnslib import dnslib
import pprint
import fcntl import fcntl
import os import os
import sys,time,re import sys,time,re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from multiprocessing import Pool import select
import socket
def run(x): def run(x):
print(x)
import os import os
if type(x) == list: if type(x) == list:
for c in x: for c in x:
@@ -17,22 +18,14 @@ def run(x):
else: else:
os.system(x) os.system(x)
def threaded(commands):
with Pool(8) as p:
p.map(run, commands)
class watch_file(): class watch_file():
def __init__(self,fname): def __init__(self,fname):
import os
print('starting file watcher.') print('starting file watcher.')
self.filename = os.path.abspath(fname) self.filename = os.path.abspath(fname)
self.cstamp = os.stat(self.filename).st_mtime-1 self.cstamp = os.stat(self.filename).st_mtime-1
def check_change(self): def check_change(self):
self.contents = None self.contents = None
import os
import sys
import psutil
stamp = os.stat(self.filename).st_mtime stamp = os.stat(self.filename).st_mtime
if stamp != self.cstamp: if stamp != self.cstamp:
print('file change detected. loading.') print('file change detected. loading.')
@@ -45,6 +38,24 @@ class watch_file():
else: else:
return False return False
#systemd-resolved
resolved = """
[Resolve]
DNS=9.9.9.11#dns11.quad9.net
Domains=~.
DNSSEC=true
DNSOverTLS=yes
MulticastDNS=no
LLMNR=no
Cache=yes
DNSStubListener=yes
"""
with open("/etc/systemd/resolved.conf", "w") as file:
file.write(resolved)
print('wrote systemd-resolved config.')
run('systemctl restart systemd-resolved')
BPF_APP = r''' BPF_APP = r'''
#include <linux/if_ether.h> #include <linux/if_ether.h>
@@ -53,7 +64,6 @@ BPF_APP = r'''
int dns_matching(struct __sk_buff *skb) { int dns_matching(struct __sk_buff *skb) {
u8 *cursor = 0; u8 *cursor = 0;
// Checking the IP protocol: // Checking the IP protocol:
struct ethernet_t *ethernet = cursor_advance(cursor, sizeof(*ethernet)); struct ethernet_t *ethernet = cursor_advance(cursor, sizeof(*ethernet));
@@ -75,24 +85,32 @@ int dns_matching(struct __sk_buff *skb) {
''' '''
default_rules = ''' default_rules = '''
iptables -F
iptables -X iptables -X
ipset flush
ipset destroy whitelist_hosts ipset destroy whitelist_hosts
ipset destroy static_hosts ipset destroy static_hosts
ipset create whitelist_hosts hash:ip timeout 3600 ipset create whitelist_hosts hash:ip timeout 3600
ipset create static_hosts hash:ip ipset create static_hosts hash:ip
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT iptables -A INPUT ! -i lo -d 127.0.0.0/8 -j REJECT -m comment --comment "Drop all traffic to 127 that doesn't use lo"
iptables -A INPUT -s 127.0.0.0/8 -j ACCEPT iptables -A INPUT -s 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A INPUT -i lo -j ACCEPT iptables -A INPUT -i lo -j ACCEPT
iptables -A OUTPUT -p tcp --match multiport --dports 30500:30600 -j ACCEPT
iptables -A OUTPUT -m set --match-set whitelist_hosts dst -j ACCEPT iptables -A OUTPUT -m set --match-set whitelist_hosts dst -j ACCEPT
iptables -A OUTPUT -m set --match-set static_hosts dst -j ACCEPT iptables -A OUTPUT -m set --match-set static_hosts dst -j ACCEPT
iptables -A OUTPUT -j REJECT iptables -A OUTPUT -j REJECT
iptables -P INPUT DROP iptables -P INPUT DROP
iptables -P FORWARD DROP iptables -P FORWARD DROP
iptables -P OUTPUT DROP iptables -P OUTPUT DROP
systemd-resolve --flush-caches mkdir -p /etc/iptables/
iptables-save > /etc/iptables/rules.v4
''' '''
#iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
#iptables -A INPUT -s 127.0.0.0/8 -j ACCEPT
for c in default_rules.splitlines(False): for c in default_rules.splitlines(False):
print(c.strip()) print(c.strip())
os.system(c.strip()) os.system(c.strip())
@@ -107,26 +125,44 @@ fcntl.fcntl(socket_fd, fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
rules = {} rules = {}
cnames = {} cnames = {}
current_cnames = {}
blocked = [] blocked = []
def check_hostname(hostname,current_cnames,allow_list):
for wl in allow_list:
match = re.fullmatch(wl, hostname)
if match:
print('<check> hostname: '+str(hostname) + ' in whitelist')
return True
for c in current_cnames.keys():
def threaded(threadedfunc, inputd, threads=4):
inputd = [x if isinstance(x,tuple) else tuple([x]) for x in inputd]
import os
from multiprocessing.pool import ThreadPool as Pool
if threads == None:
threads = os.sched_getaffinity(0)
with Pool(threads) as p:
result = p.starmap(threadedfunc, inputd)
return result
def regex_match(wl,hostname):
import re
match = re.fullmatch(wl, hostname)
if match:
print('<check> hostname: '+str(hostname) + ' in whitelist.')
return True
def cname_match(c,hostname):
import datetime
if c in current_cnames:
if hostname == current_cnames[c]["cname"] and current_cnames[c]["exp"] > datetime.now(): if hostname == current_cnames[c]["cname"] and current_cnames[c]["exp"] > datetime.now():
print('<check> hostname: '+str(hostname) + ' in cname') print('<check> hostname: '+str(hostname) + ' in cname.')
return True return True
def check_hostname(hostname,current_cnames,allow_list):
import itertools
if any(threaded(threadedfunc=regex_match, inputd=zip(allow_list,itertools.repeat(hostname)))):
return True
if any(threaded(threadedfunc=cname_match, inputd=zip(current_cnames.keys(),itertools.repeat(hostname)))):
return True
print('<check> hostname: '+str(hostname) + ' not permitted.') print('<check> hostname: '+str(hostname) + ' not permitted.')
blocked.append(hostname)
return False return False
import socket
def valid_ip(address): def valid_ip(address):
try: try:
socket.inet_aton(address) socket.inet_aton(address)
@@ -138,66 +174,70 @@ dns_list = watch_file('dns_whitelist.conf')
host_list = watch_file('host_whitelist.conf') host_list = watch_file('host_whitelist.conf')
while True: while True:
time.sleep(.05) time.sleep(.01)
commands = [] commands = []
#dnsl = dns_list.contents
dns_list.check_change() dns_list.check_change()
if dns_list.contents: if dns_list.contents:
#diff = set(dnsl) - set(dns_list.contents)
dns_regex = dns_list.contents dns_regex = dns_list.contents
for block in blocked: #for block in dns_list.contents:
add_rule = 'nslookup '+str(block) # add_rule = 'nslookup '+str(block)
commands.append(add_rule) # commands.append(add_rule)
threaded(commands) #threaded(commands)
blocked = [] blocked = []
commands = [] commands = []
host_list.check_change() host_list.check_change()
if host_list.contents: if host_list.contents:
threaded(['ipset flush static_hosts']) threaded(run,['ipset flush static_hosts'])
hosts = [x for x in host_list.contents if valid_ip(x)] hosts = [x for x in host_list.contents if valid_ip(x)]
for host in hosts: for host in hosts:
commands.append('ipset add static_hosts '+host) commands.append('ipset add static_hosts '+host)
print(host) print(host)
threaded(commands) threaded(run,commands)
import select try:
r, w, e = select.select([ socket_fd ], [], [], 0) r, w, e = select.select([ socket_fd ], [], [], 0)
if socket_fd in r: if socket_fd in r:
try: try:
packet_str = os.read(socket_fd, 2048) packet_str = os.read(socket_fd, 2048)
except KeyboardInterrupt: except KeyboardInterrupt:
sys.exit(0) sys.exit(0)
packet_bytearray = bytearray(packet_str) packet_bytearray = bytearray(packet_str)
ETH_HLEN = 14 ETH_HLEN = 14
UDP_HLEN = 8 UDP_HLEN = 8
# IP header length # IP header length
ip_header_length = packet_bytearray[ETH_HLEN] ip_header_length = packet_bytearray[ETH_HLEN]
ip_header_length = ip_header_length & 0x0F ip_header_length = ip_header_length & 0x0F
ip_header_length = ip_header_length << 2 ip_header_length = ip_header_length << 2
# Starting the DNS packet # Starting the DNS packet
payload_offset = ETH_HLEN + ip_header_length + UDP_HLEN payload_offset = ETH_HLEN + ip_header_length + UDP_HLEN
payload = packet_bytearray[payload_offset:] payload = packet_bytearray[payload_offset:]
dnsrec = dnslib.DNSRecord.parse(payload) dnsrec = dnslib.DNSRecord.parse(payload)
if dnsrec.rr: if dnsrec.rr:
if True: if True:
for i in range(0, len(dnsrec.rr)): for i in range(0, len(dnsrec.rr)):
if str(dnsrec.rr[i].rtype) in ['5']: if str(dnsrec.rr[i].rtype) in ['5']:
cnames[str(dnsrec.questions[0].qname)] = {"cname":str(dnsrec.rr[i].rdata),"exp":datetime.now() + timedelta(seconds=dnsrec.rr[i].ttl)} cnames[str(dnsrec.questions[0].qname)] = {"cname":str(dnsrec.rr[i].rdata),"exp":datetime.now() + timedelta(seconds=dnsrec.rr[i].ttl+ 3600)}
if str(dnsrec.rr[i].rtype) in ['1']: if str(dnsrec.rr[i].rtype) in ['1']:
#rtype 28 is v6 #rtype 28 is v6
#print(str(dnsrec.rr[i].rtype)) #print(str(dnsrec.rr[i].rtype))
if check_hostname(hostname=str(dnsrec.questions[0].qname),current_cnames=cnames,allow_list=dns_regex): if check_hostname(hostname=str(dnsrec.questions[0].qname),current_cnames=cnames,allow_list=dns_regex):
add_rule = 'ipset -exist add whitelist_hosts '+str(dnsrec.rr[i].rdata)+' timeout '+str(int(dnsrec.rr[i].ttl)+360) add_rule = 'ipset -exist add whitelist_hosts '+str(dnsrec.rr[i].rdata)+' timeout '+str(int(dnsrec.rr[i].ttl)+3600)
print(add_rule) print(add_rule)
commands.append(add_rule) commands.append(add_rule)
threaded(commands) threaded(run,commands)
else:
pass
else: else:
pass pass
else: except Exception as e:
pass print(e)