linux-rootkit

Feature-rich interactive rootkit that targets Linux kernel 4.19, accompanied by a dynamic kernel memory analysis GDB plugin for in vivo introspection (e.g., using QEMU)
git clone git://git.deurzen.net/linux-rootkit
Log | Files | Refs

livedm.py (13604B)


      1 #!/usr/bin/python3
      2 
      3 import gdb
      4 import re
      5 import json
      6 from enum import IntEnum
      7 
      8 # { allocator |-> register containing size argument }
      9 break_arg = {
     10     "kmem_cache_alloc_trace": "rdx",
     11     "kmalloc_order": "rdi",
     12     "__kmalloc": "rdi",
     13     "vmalloc": "rdi",
     14     "vzalloc": "rdi",
     15     "vmalloc_user": "rdi",
     16     "vmalloc_node": "rdi",
     17     "vzalloc_node": "rdi",
     18     "vmalloc_exec": "rdi",
     19     "vmalloc_32": "rdi",
     20     "vmalloc_32_user": "rdi",
     21 }
     22 
     23 # when the size is hidden in a struct, things get more complicated
     24 # { allocator |-> (register with struct pointer, struct type, struct member that holds size) }
     25 break_arg_access = {
     26     "kmem_cache_alloc_node": ("rdi", "struct kmem_cache *", "object_size"),
     27 }
     28 
     29 # { type |-> [(access chain, critical value)] }
     30 #
     31 # Make sure each entry in an access chain (apart from the last entry)
     32 # is a pointer, as it is dereferenced to obtain the next field
     33 #
     34 # If `critical_value` is set to None, any changes to the field are reported
     35 watch_write_access_chain = {
     36     "struct task_struct *": [
     37         # (((struct task_struct *)<address>)->real_cred)->uid
     38         (["real_cred", "uid"], 0),
     39     ]
     40 }
     41 
     42 # this is limited by the amount of debug registers...
     43 avail_hw_breakpoints = 4
     44 
     45 # store watchpoints so we can delete them later on (i.e., once the corresponding struct is freed)
     46 watchpoints = {}
     47 n_watchpoints = 0
     48 
     49 # { memory freeing functions |-> register with argument }
     50 free_funcs = {
     51     "kfree": "rdi",
     52     "vfree": "rdi",
     53     "kmem_cache_free": "rsi",
     54 }
     55 
     56 entries = set()
     57 exits = set()
     58 types = {}
     59 
     60 # { address |-> (type, size, call site) }
     61 mem_map = {}
     62 
     63 size_at_entry = None
     64 
     65 class DebugLevel(IntEnum):
     66     __order__ = 'WARN INFO TRACE'
     67     WARN = 0  # warn when critical fields (e.g., task_struct->real_cred.uid) change to suspicious values
     68     INFO = 1  # show watchpoint additions
     69     TRACE = 2 # show every memory allocation
     70 
     71 debug_level = DebugLevel.INFO
     72 
     73 class RkPrintMem(gdb.Command):
     74     """Print currently allocated memory"""
     75 
     76     def __init__(self):
     77         super(RkPrintMem, self).__init__("rk-print-mem", gdb.COMMAND_DATA)
     78 
     79     def invoke(self, arg, from_tty):
     80         global mem_map
     81 
     82         if not mem_map:
     83             return None
     84 
     85         for addr, (type, size, caller) in mem_map.items():
     86             print(f"type: {type[7:]}, size: {size} B, address: {hex(addr)}, call site: {caller}")
     87 
     88 RkPrintMem()
     89 
     90 class RkDebug(gdb.Command):
     91     """Toggle between different modes of memory logging"""
     92 
     93     def __init__(self):
     94         super(RkDebug, self).__init__("rk-debug", gdb.COMMAND_USER)
     95 
     96     def invoke(self, arg, from_tty):
     97         global debug_level
     98         debug_level = DebugLevel((int(debug_level) + 1) % len(list(map(int, DebugLevel))))
     99         print(f"debug level set to {debug_level.name}")
    100 
    101 RkDebug()
    102 
    103 class RkPrintData(gdb.Command):
    104     """Print data of a block in the memory map.\nUsage: rk-data <addr>"""
    105 
    106     def __init__(self):
    107         super(RkPrintData, self).__init__("rk-data", gdb.COMMAND_DATA)
    108 
    109     def invoke(self, arg, from_tty):
    110         global mem_map
    111 
    112         if int(arg, 16) in mem_map:
    113             (type, size, _) = mem_map[int(arg, 16)]
    114 
    115             try:
    116                 data = gdb.execute(f"print *(({type[7:]}){arg})", to_string=True)
    117                 print(f"resolving {arg} to {type}\n")
    118                 print(data)
    119             except:
    120                 print(f"could not resolve {type} at {arg}")
    121                 return
    122         else:
    123             print(f"{arg} does not point to the start of a kernel-allocated portion of the heap")
    124 
    125 RkPrintData()
    126 
    127 # this breakpoint can react to function entry and exit
    128 class EntryExitBreakpoint(gdb.Breakpoint):
    129     def __init__(self, b):
    130         gdb.Breakpoint.__init__(self, b)
    131 
    132     def stop(self):
    133         global avail_hw_breakpoints
    134         global watchpoints
    135         global n_watchpoints
    136         global watch_write_access_chain
    137         global mem_map
    138 
    139         frame = gdb.newest_frame()
    140 
    141         if not frame.is_valid():
    142             return False
    143 
    144         # FRAME_UNWIND_NO_REASON means the stack unwinding was successful
    145         if frame.unwind_stop_reason() != gdb.FRAME_UNWIND_NO_REASON:
    146             return False
    147 
    148         # leverage statically-compiled dictionary to infer type and call site
    149         typeret = self.type_lookup(frame)
    150 
    151         if typeret is None:
    152             return False
    153 
    154         (type, caller) = typeret
    155 
    156         # extract size and return value
    157         extret = self.extract(frame)
    158 
    159         if extret is None:
    160             return False
    161 
    162         (size, address) = extret
    163 
    164         mem_map[address] = (type, size, caller)
    165 
    166         # go over each watched-for type's access chains,
    167         # setting watchpoints on the last accessed field of each chain
    168         #
    169         # we only do this when there are enough HW breakpoints available
    170         if type[7:] in watch_write_access_chain:
    171             access_chains = watch_write_access_chain[type[7:]]
    172             for access_chain, critical_value in access_chains:
    173                 if n_watchpoints + len(access_chain) <= avail_hw_breakpoints:
    174                     watchpoint = WriteWatchpoint(address, type[7:], access_chain, critical_value)
    175 
    176                     if address in watchpoints:
    177                         watchpoints[address].append(watchpoint)
    178                     else:
    179                         watchpoints[address] = [watchpoint]
    180 
    181                     n_watchpoints += len(access_chain)
    182                     if n_watchpoints >= avail_hw_breakpoints:
    183                         break
    184 
    185         if debug_level >= DebugLevel.TRACE:
    186             print("Allocating", (type, size, caller), "at", hex(address))
    187 
    188         return False
    189 
    190     def extract(self, frame):
    191         global break_arg
    192         global entries
    193         global exits
    194         global size_at_entry
    195 
    196         # function entry:
    197         if self.number in entries:
    198             # extract size from correct register
    199             if frame.name() in break_arg:
    200                 size = int(frame.read_register(break_arg[frame.name()]))
    201                 if size > 0:
    202                     size_at_entry = size
    203                     return None
    204 
    205             # extract size from compound argument
    206             elif frame.name() in break_arg_access:
    207                 (reg, type, field) = break_arg_access[frame.name()]
    208                 size = int(gdb.execute(f"p (({type})${reg})->{field}",
    209                                        to_string=True).strip().split(" ")[2])
    210 
    211                 if size > 0:
    212                     size_at_entry = size
    213                     return None
    214 
    215         # function exit:
    216         elif self.number in exits and size_at_entry is not None:
    217             # extract return value, return tuple (size, address)
    218             ret = (size_at_entry, int(frame.read_register('rax')) & (2 ** 64 - 1))
    219             size_at_entry = None
    220             return ret
    221 
    222         return None
    223 
    224     def type_lookup(self, frame):
    225         global types
    226 
    227         f_iter = frame.older()
    228 
    229         # iterate frame-by-frame up the stack
    230         while f_iter is not None and f_iter.is_valid():
    231             sym = f_iter.find_sal()
    232             symtab = sym.symtab
    233 
    234             if symtab is None:
    235                 break
    236 
    237             key = f"{symtab.filename}:{sym.line}"
    238 
    239             if key in types:
    240                 return (types[key], key)
    241 
    242             # https://stackoverflow.com/a/15550907/11069175
    243             # https://stackoverflow.com/questions/41565105/gdb-breakpoint-gets-hit-in-the-wrong-line-number
    244             # in rare cases, our lines don't match up due to optimizations
    245             # therefore, we go one step in each direction (up to 10 times) until we find our type
    246             else:
    247                 for i in range(1, 10):
    248                     key_pos = f"{symtab.filename}:{sym.line + i}"
    249                     key_neg = f"{symtab.filename}:{sym.line - i}"
    250 
    251                     if key_neg in types:
    252                         return (types[key_neg], key_neg)
    253 
    254                     if key_pos in types:
    255                         return (types[key_pos], key_pos)
    256 
    257             f_iter = f_iter.older()
    258 
    259         return None
    260 
    261 class FreeBreakpoint(gdb.Breakpoint):
    262     def __init__(self, b):
    263         gdb.Breakpoint.__init__(self, b)
    264 
    265     def stop(self):
    266         global mem_map
    267         global watchpoints
    268         global n_watchpoints
    269         global free_funcs
    270         global debug_level
    271 
    272         frame = gdb.newest_frame()
    273 
    274         if not frame.is_valid():
    275             return False
    276 
    277         address = int(frame.read_register(free_funcs[frame.name()])) & (2 ** 64 - 1)
    278 
    279         if address is None:
    280             return False
    281 
    282         if address in watchpoints:
    283             for watchpoint in watchpoints[address]:
    284                 if debug_level >= DebugLevel.INFO:
    285                     print("Deleting watchpoint on", watchpoint.current_chain)
    286 
    287                 watchpoint.delete()
    288                 n_watchpoints -= len(watchpoint.access_chain)
    289 
    290             del(watchpoints[address])
    291 
    292         if address in mem_map:
    293             if debug_level >= DebugLevel.TRACE:
    294                 print("Freeing", mem_map[address], "at", hex(address))
    295             mem_map.pop(address)
    296 
    297         return False
    298 
    299 class WriteWatchpoint(gdb.Breakpoint):
    300     address = None
    301     type = None
    302     access_chain = None          # ...(->...)*->[field we watch]
    303     critical_value = None        # value that, when written to watchpoint location, causes alert
    304     previous_value = None        # used to store previous value for comparison
    305     previous_value_print = None  # used for debug output
    306 
    307     def __init__(self, address, type, access_chain, critical_value):
    308         global watchpoints
    309 
    310         self.address = address
    311         self.type = type
    312         self.access_chain = access_chain
    313         self.critical_value = critical_value
    314 
    315         current_chain = f"(({type}){hex(address)})"
    316         for field in access_chain:
    317             current_chain = "(" + current_chain + "->" + field + ")"
    318 
    319         self.previous_value = self.get_value(current_chain)
    320         self.previous_value_print = self.get_value_print(current_chain)
    321 
    322         if debug_level >= DebugLevel.INFO:
    323             print("Setting watchpoint on", current_chain)
    324 
    325         self.current_chain = current_chain
    326         gdb.Breakpoint.__init__(self, current_chain, internal=True, type=gdb.BP_WATCHPOINT)
    327 
    328     def stop(self):
    329         current_chain = f"(({self.type}){hex(self.address)})"
    330         for field in self.access_chain:
    331             current_chain = "(" + current_chain + "->" + field + ")"
    332 
    333         current_value = self.get_value(current_chain)
    334         current_value_print = self.get_value_print(current_chain)
    335 
    336         if self.previous_value is not None and current_value is not None:
    337             if self.previous_value != current_value:
    338                 if debug_level >= DebugLevel.INFO:
    339                     print(current_chain, "changed from", self.previous_value_print,
    340                           "to", current_value_print)
    341 
    342                 if debug_level >= DebugLevel.WARN:
    343                     current_value = int.from_bytes(bytes(current_value), "little")
    344                     if current_value == self.critical_value:
    345                         print(f"WARNING: critical value {self.critical_value} set to {current_chain}")
    346 
    347         self.previous_value = current_value
    348         self.previous_value_print = current_value_print
    349 
    350         return False
    351 
    352     def get_value_print(self, name):
    353         try:
    354             value_print = [line.strip() for line in
    355                            gdb.execute(f"p {name}", to_string=True).strip().split("\n")[1:-1]]
    356 
    357             if len(value_print) > 1:
    358                 return "(" + " ".join(value_print) + ")"
    359             else:
    360                 return value_print[0]
    361         except:
    362             return None
    363 
    364     def get_value(self, name):
    365         try:
    366             size = int(gdb.parse_and_eval(f"sizeof({name})"))
    367             address = int(gdb.execute(f"p &({name})", to_string = True).strip().split(" ")[-1], 16)
    368             return gdb.selected_inferior().read_memory(address, size)
    369         except:
    370             return None
    371 
    372 class Stage3():
    373     breakpoints = []
    374 
    375     dictfile = ".dict"
    376 
    377     def __init__(self):
    378         global break_arg
    379         global entries
    380         global exits
    381         global types
    382 
    383         # system can hang when pagination is on
    384         gdb.execute("set pagination off")
    385 
    386         # for printing structs with rk-data
    387         gdb.execute("set print pretty on")
    388 
    389         # load in pre-compiled type dictionary
    390         with open(self.dictfile, 'r') as dct:
    391             types = json.load(dct)
    392 
    393         for b in (break_arg.keys() | break_arg_access.keys()):
    394             # set breakpoint at function entry, to extract size
    395             b_entry = EntryExitBreakpoint(b)
    396             self.breakpoints.append(b_entry)
    397             entries.add(b_entry.number)
    398 
    399             # lookup offset from function entry to ret{,q}, account for possibility of >1 ret{,q} occurrence
    400             disass = gdb.execute(f"disass {b}", to_string=True).strip().split("\n")
    401             disass = [instr.split("\t") for instr in disass]
    402             instrs = [(instr[0].strip(), instr[1].split(" ")[0].strip()) for instr in disass if len(instr) > 1]
    403             retqs = [int(loc.split("<")[1].split(">")[0]) for (loc, instr) in instrs if instr == "ret" or instr == "retq"]
    404 
    405             # set breakpoints at function exits (ret{,q}), to extract return value
    406             for retq in retqs:
    407                 b_exit = EntryExitBreakpoint(f"*{hex(int(str(gdb.parse_and_eval(b).address).split(' ')[0], 16) + retq)}")
    408                 self.breakpoints.append(b_exit)
    409                 exits.add(b_exit.number)
    410 
    411         for f in free_funcs:
    412             FreeBreakpoint(f)
    413 
    414 Stage3()