livedm.py (13604B)
#!/usr/bin/python3

import gdb
import re
import json
from enum import IntEnum

# Allocators whose size argument is passed directly in a register.
# Maps: allocator name -> register that holds the size argument.
# kmem_cache_alloc_trace is the odd one out (size in rdx); every other
# allocator listed here receives its size in rdi.
break_arg = {"kmem_cache_alloc_trace": "rdx"}
break_arg.update(dict.fromkeys(
    (
        "kmalloc_order",
        "__kmalloc",
        "vmalloc",
        "vzalloc",
        "vmalloc_user",
        "vmalloc_node",
        "vzalloc_node",
        "vmalloc_exec",
        "vmalloc_32",
        "vmalloc_32_user",
    ),
    "rdi",
))

# Allocators whose size is not an argument itself but a member of a struct
# that an argument points to.
# Maps: allocator name -> (register holding the struct pointer,
#                          C type the register is cast to,
#                          struct member that holds the size)
break_arg_access = dict(
    kmem_cache_alloc_node=("rdi", "struct kmem_cache *", "object_size"),
)

# Fields to watch for writes, keyed by the (pointer) type of the allocated object.
# Maps: type -> list of (access chain, critical value)
#
# Every element of an access chain except the last must be a pointer member,
# because it is dereferenced to reach the next element; the last element is
# the field that is actually watched.
#
# A critical value of None never triggers a WARNING; changes to the field are
# then only visible through the regular "changed from ... to ..." messages
# (printed at INFO level and above).
watch_write_access_chain = {
    # watches (((struct task_struct *)<address>)->real_cred)->uid and warns
    # when it is set to 0 (presumably: escalation to root)
    "struct task_struct *": [(["real_cred", "uid"], 0)],
}

# The watchpoint budget below is limited by the number of hardware debug registers...
avail_hw_breakpoints = 4  # one per x86 debug address register

# store watchpoints so we can delete them later on (i.e., once the corresponding struct is freed)
# { object address |-> [WriteWatchpoint, ...] }
watchpoints = {}
# number of hardware debug registers currently consumed by watchpoints
n_watchpoints = 0

# { memory freeing functions |-> register with argument }
free_funcs = {
    "kfree": "rdi",
    "vfree": "rdi",
    "kmem_cache_free": "rsi",
}

# gdb breakpoint numbers of the allocator entry / exit breakpoints
entries = set()
exits = set()

# { "<source file>:<line>" |-> type string }, loaded from the pre-compiled dictionary
types = {}

# { address |-> (type, size, call site) }
mem_map = {}

# size recorded at the most recent allocator entry; consumed by the next exit
size_at_entry = None

# reinterprets a (possibly sign-extended) 64-bit register value as unsigned
_U64_MASK = 2 ** 64 - 1


class DebugLevel(IntEnum):
    __order__ = 'WARN INFO TRACE'
    WARN = 0  # warn when critical fields (e.g., task_struct->real_cred.uid) change to suspicious values
    INFO = 1  # show watchpoint additions
    TRACE = 2  # show every memory allocation


debug_level = DebugLevel.INFO


# rk-print-mem: list every allocation currently tracked in mem_map.
class RkPrintMem(gdb.Command):
    """Print currently allocated memory"""

    def __init__(self):
        super(RkPrintMem, self).__init__("rk-print-mem", gdb.COMMAND_DATA)

    def invoke(self, arg, from_tty):
        # iterating an empty map prints nothing, so no explicit emptiness check is needed
        for addr, (typ, size, caller) in mem_map.items():
            print(f"type: {typ[7:]}, size: {size} B, address: {hex(addr)}, call site: {caller}")


RkPrintMem()


# rk-debug: cycle the verbosity WARN -> INFO -> TRACE -> WARN.
class RkDebug(gdb.Command):
    """Toggle between different modes of memory logging"""

    def __init__(self):
        super(RkDebug, self).__init__("rk-debug", gdb.COMMAND_USER)

    def invoke(self, arg, from_tty):
        global debug_level
        debug_level = DebugLevel((int(debug_level) + 1) % len(DebugLevel))
        print(f"debug level set to {debug_level.name}")


RkDebug()


# rk-data <addr>: pretty-print a tracked allocation as its recorded type.
class RkPrintData(gdb.Command):
    """Print data of a block in the memory map.\nUsage: rk-data <addr>"""

    def __init__(self):
        super(RkPrintData, self).__init__("rk-data", gdb.COMMAND_DATA)

    def invoke(self, arg, from_tty):
        # a malformed / missing address used to escape as an uncaught ValueError
        try:
            addr = int(arg, 16)
        except ValueError:
            print("Usage: rk-data <addr>")
            return

        if addr not in mem_map:
            print(f"{arg} does not point to the start of a kernel-allocated portion of the heap")
            return

        typ, _, _ = mem_map[addr]
        try:
            data = gdb.execute(f"print *(({typ[7:]}){arg})", to_string=True)
        except gdb.error:
            print(f"could not resolve {typ} at {arg}")
            return

        print(f"resolving {arg} to {typ}\n")
        print(data)


RkPrintData()


class EntryExitBreakpoint(gdb.Breakpoint):
    """Breakpoint placed on allocator entries (to read the size) and on their
    return instructions (to read the returned address)."""

    # source lines searched on either side of a call site when the exact
    # line is missing from the type dictionary
    LINE_FUZZ = 10

    def __init__(self, b):
        gdb.Breakpoint.__init__(self, b)

    def stop(self):
        global n_watchpoints

        frame = gdb.newest_frame()

        if not frame.is_valid():
            return False

        # FRAME_UNWIND_NO_REASON means the stack unwinding was successful
        if frame.unwind_stop_reason() != gdb.FRAME_UNWIND_NO_REASON:
            return False

        # Extract first: entry hits return here without the (expensive) stack
        # walk, and every exit consumes size_at_entry, so a size cannot go
        # stale when the type lookup below fails.
        extret = self.extract(frame)
        if extret is None:
            return False

        (size, address) = extret

        # NULL means the allocation failed -- nothing to track or watch
        if address == 0:
            return False

        # leverage statically-compiled dictionary to infer type and call site
        typeret = self.type_lookup(frame)
        if typeret is None:
            return False

        (typ, caller) = typeret

        mem_map[address] = (typ, size, caller)

        # go over each watched-for type's access chains,
        # setting watchpoints on the last accessed field of each chain
        #
        # we only do this when there are enough HW breakpoints available;
        # a chain of length k is budgeted as k debug registers
        for access_chain, critical_value in watch_write_access_chain.get(typ[7:], ()):
            if n_watchpoints + len(access_chain) > avail_hw_breakpoints:
                continue

            # an exception escaping stop() makes gdb halt the target, so a
            # watchpoint gdb refuses to set is reported and skipped instead
            try:
                watchpoint = WriteWatchpoint(address, typ[7:], access_chain, critical_value)
            except gdb.error as err:
                if debug_level >= DebugLevel.INFO:
                    print(f"Could not set watchpoint on {typ[7:]} at {hex(address)}: {err}")
                continue

            watchpoints.setdefault(address, []).append(watchpoint)
            n_watchpoints += len(access_chain)
            if n_watchpoints >= avail_hw_breakpoints:
                break

        if debug_level >= DebugLevel.TRACE:
            print("Allocating", (typ, size, caller), "at", hex(address))

        return False

    def extract(self, frame):
        """At an entry breakpoint, remember the requested size and return None.
        At an exit breakpoint, return (size, address) for the pending
        allocation, or None if no size was recorded."""
        global size_at_entry

        # function entry:
        if self.number in entries:
            name = frame.name()
            size = None

            if name in break_arg:
                # size passed directly in a register
                size = int(frame.read_register(break_arg[name]))
            elif name in break_arg_access:
                # size stored in a member of a struct the register points to;
                # evaluated directly instead of parsing `print` output
                (reg, typ, field) = break_arg_access[name]
                size = int(gdb.parse_and_eval(f"(({typ})${reg})->{field}"))

            if size is not None and size > 0:
                size_at_entry = size
            return None

        # function exit: pair the recorded size with the return value in rax
        if self.number in exits and size_at_entry is not None:
            ret = (size_at_entry, int(frame.read_register('rax')) & _U64_MASK)
            size_at_entry = None
            return ret

        return None

    def type_lookup(self, frame):
        """Walk up the stack from the allocator's caller and return
        (type, "file:line") for the first frame whose source location is in
        `types`, or None."""
        f_iter = frame.older()

        # iterate frame-by-frame up the stack
        while f_iter is not None and f_iter.is_valid():
            sal = f_iter.find_sal()
            symtab = sal.symtab

            if symtab is None:
                break

            filename = symtab.filename
            key = f"{filename}:{sal.line}"

            if key in types:
                return (types[key], key)

            # https://stackoverflow.com/a/15550907/11069175
            # https://stackoverflow.com/questions/41565105/gdb-breakpoint-gets-hit-in-the-wrong-line-number
            # in rare cases, our lines don't match up due to optimizations
            # therefore, we go one step in each direction (up to LINE_FUZZ
            # times), checking the line before the line after at each distance
            for i in range(1, self.LINE_FUZZ + 1):
                for near_key in (f"{filename}:{sal.line - i}", f"{filename}:{sal.line + i}"):
                    if near_key in types:
                        return (types[near_key], near_key)

            f_iter = f_iter.older()

        return None


class FreeBreakpoint(gdb.Breakpoint):
    """Breakpoint on a freeing function: drops the freed block from mem_map
    and deletes any watchpoints armed on it."""

    def __init__(self, b):
        gdb.Breakpoint.__init__(self, b)

    def stop(self):
        global n_watchpoints

        frame = gdb.newest_frame()

        if not frame.is_valid():
            return False

        # guard against a frame name we have no register mapping for,
        # instead of raising KeyError inside stop()
        reg = free_funcs.get(frame.name())
        if reg is None:
            return False

        address = int(frame.read_register(reg)) & _U64_MASK

        # freeing NULL is a no-op
        if address == 0:
            return False

        for watchpoint in watchpoints.pop(address, ()):
            if debug_level >= DebugLevel.INFO:
                print("Deleting watchpoint on", watchpoint.current_chain)

            watchpoint.delete()
            n_watchpoints -= len(watchpoint.access_chain)

        if address in mem_map:
            if debug_level >= DebugLevel.TRACE:
                print("Freeing", mem_map[address], "at", hex(address))
            del mem_map[address]

        return False


class WriteWatchpoint(gdb.Breakpoint):
    """Internal gdb watchpoint on the last field of an access chain rooted at
    an allocated object; logs changes and warns on the critical value."""

    # NOTE: gdb.Breakpoint exposes a read-only `type` attribute; this class
    # level default is what allows `self.type` to be assigned below -- keep it.
    address = None
    type = None
    access_chain = None  # ...(->...)*->[field we watch]
    critical_value = None  # value that, when written to watchpoint location, causes alert
    previous_value = None  # raw bytes of the field, used for comparison
    previous_value_print = None  # used for debug output

    def __init__(self, address, type, access_chain, critical_value):
        self.address = address
        self.type = type
        self.access_chain = access_chain
        self.critical_value = critical_value
        self.current_chain = self.build_chain(type, address, access_chain)

        self.previous_value = self.get_value(self.current_chain)
        self.previous_value_print = self.get_value_print(self.current_chain)

        if debug_level >= DebugLevel.INFO:
            print("Setting watchpoint on", self.current_chain)

        gdb.Breakpoint.__init__(self, self.current_chain, internal=True, type=gdb.BP_WATCHPOINT)

    @staticmethod
    def build_chain(type, address, access_chain):
        """Build the C expression `((((T)addr)->f1)->f2)...` for the chain."""
        chain = f"(({type}){hex(address)})"
        for field in access_chain:
            chain = f"({chain}->{field})"
        return chain

    def stop(self):
        current_value = self.get_value(self.current_chain)
        current_value_print = self.get_value_print(self.current_chain)

        if (self.previous_value is not None and current_value is not None
                and self.previous_value != current_value):
            if debug_level >= DebugLevel.INFO:
                print(self.current_chain, "changed from", self.previous_value_print,
                      "to", current_value_print)

            # WARN is the lowest level, so this check always runs; a None
            # critical value never matches
            if debug_level >= DebugLevel.WARN:
                if int.from_bytes(current_value, "little") == self.critical_value:
                    print(f"WARNING: critical value {self.critical_value} set to {self.current_chain}")

        # store the raw bytes (never the decoded int) so the next comparison
        # is bytes-vs-bytes; storing the int made every later hit look "changed"
        self.previous_value = current_value
        self.previous_value_print = current_value_print

        return False

    def get_value_print(self, name):
        """Return gdb's rendering of `name` on one line, or None on error."""
        try:
            lines = gdb.execute(f"p {name}", to_string=True).strip().split("\n")
        except gdb.error:
            return None

        if len(lines) < 3:
            # scalar / single-line value: "$N = <value>"
            return lines[0].split(" = ", 1)[-1]

        # pretty-printed aggregate: drop the "$N = {" and "}" lines
        fields = [line.strip() for line in lines[1:-1]]
        if len(fields) > 1:
            return "(" + " ".join(fields) + ")"
        return fields[0]

    def get_value(self, name):
        """Return the raw bytes of the lvalue `name`, or None if unreadable."""
        try:
            value = gdb.parse_and_eval(name)
            if value.address is None:
                return None
            # take the address from the Value itself rather than parsing
            # `p &(...)`, whose output may end in a "<symbol+off>" suffix
            address = int(value.address)
            size = value.type.sizeof
            return bytes(gdb.selected_inferior().read_memory(address, size))
        except gdb.error:
            return None


class Stage3():
    """Installs allocator entry/exit breakpoints and free breakpoints."""

    breakpoints = []

    dictfile = ".dict"

    # prefixes that may precede a return in disassembly ("repz retq", "bnd ret")
    RET_PREFIXES = ("rep", "repz", "bnd")
    RET_MNEMONICS = ("ret", "retq")

    def __init__(self):
        global types

        # system can hang when pagination is on
        gdb.execute("set pagination off")

        # for printing structs with rk-data
        gdb.execute("set print pretty on")

        # load in pre-compiled type dictionary
        with open(self.dictfile, 'r') as dct:
            types = json.load(dct)

        for b in (break_arg.keys() | break_arg_access.keys()):
            # set breakpoint at function entry, to extract size
            b_entry = EntryExitBreakpoint(b)
            self.breakpoints.append(b_entry)
            entries.add(b_entry.number)

            # set breakpoints at every return instruction, to extract return value
            base = int(gdb.parse_and_eval(b).address)
            for offset in self._ret_offsets(b):
                b_exit = EntryExitBreakpoint(f"*{hex(base + offset)}")
                self.breakpoints.append(b_exit)
                exits.add(b_exit.number)

        for f in free_funcs:
            FreeBreakpoint(f)

    @classmethod
    def _ret_offsets(cls, func):
        """Return byte offsets (from `func`'s start) of its return instructions.

        NOTE: exits that leave via a jump (tail calls, return thunks) are not
        detected -- only ret/retq mnemonics are matched.
        """
        offsets = []
        for line in gdb.execute(f"disass {func}", to_string=True).strip().split("\n"):
            fields = line.split("\t")
            # instruction lines look like "   0x... <+N>:\tmnemonic operands"
            if len(fields) < 2 or "<" not in fields[0]:
                continue
            tokens = fields[1].split()
            while tokens and tokens[0] in cls.RET_PREFIXES:
                tokens = tokens[1:]
            if tokens and tokens[0] in cls.RET_MNEMONICS:
                offsets.append(int(fields[0].split("<")[1].split(">")[0]))
        return offsets


Stage3()