commit 54966ee319deae1c4a11f57c2e56e0db7b93ade8
parent 42cf3ffe4a78d79c2fc49a166971671e24cb6024
Author: deurzen <m.deurzen@tum.de>
Date:   Sun, 24 Jan 2021 09:25:02 +0100
adds initial relatext data retrieval
Diffstat:
2 files changed, 83 insertions(+), 18 deletions(-)
diff --git a/mem_forensics/memcheck-gdb.py b/mem_forensics/memcheck-gdb.py
@@ -588,6 +588,12 @@ class RkCheckFunctions(gdb.Command):
     symbols = None
     headers = None
 
+    use_memoization = False
+
+    same_count = 0
+    diff_count = 0
+    skip_count = 0
+
     #Key: symbol, value: tuple (size, code bytes from ELF file)
     code_dict = {}
 
@@ -595,6 +601,8 @@ class RkCheckFunctions(gdb.Command):
     altinstr_dict = {}
     paravirt_dict = {}
 
+    relatext_dict = {}
+
     def __init__(self):
         super(RkCheckFunctions, self).__init__("rk-check-functions", gdb.COMMAND_USER, gdb.COMMAND_DATA)
 
@@ -606,15 +614,19 @@ class RkCheckFunctions(gdb.Command):
             print("no object file has been read in to calculate offsets, please run `rk-load-symbols` first")
             return None
 
-        found = False
-        for root, _, files in os.walk("."):
-            if "xbfunc.gdb" in files:
-                found = True
-                gdb.execute(f'source {os.path.join(root, "xbfunc.gdb")}')
+        md5sum = subprocess.check_output(f"md5sum {file_g}", shell=True).split()[0]
 
-        if not found:
-            print("could not locate the `xbfunc.gdb` file that is required to perform the function check")
-            return None
+        path = None
+        for root, dirs, files in os.walk("."):
+            if "runtime" in dirs:
+                path = os.path.join(root, f"runtime")
+                break
+
+        if path:
+            with open(f"{path}/md5sum") as f:
+                if md5sum.decode(sys.stdout.encoding) == f.readline().strip():
+                    print("using memoized ELF data stored in `runtime/{func,altinstr,paravirt}`")
+                    self.use_memoization = True
 
         self.f = elffile.ELFFile(open(file_g, "rb"))
         self.s = self.f.get_section_by_name(".symtab")
@@ -624,33 +636,40 @@ class RkCheckFunctions(gdb.Command):
         self.fill_code_dict()
         self.fill_altinstr_dict()
         self.fill_paravirt_dict()
+        self.fill_relatext_dict()
         print(" done!")
 
         print("comparing functions...", end='', flush=True)
         self.compare_functions()
         print(" done!")
 
+        print(f"{self.diff_count} functions differ, {self.same_count} are equal, {self.skip_count} skipped")
+
     def fill_code_dict(self):
         for i, symbol in enumerate(self.s.iter_symbols()):
-            # if i < 30195:
-            #     continue
-            if i > 2000:
+            if i < 800:
+                continue
+            if i > 1000:
                 break
 
             if symbol.entry["st_info"]["type"] == "STT_FUNC":
                 name = symbol.name
                 size = symbol.entry["st_size"]
             else:
+                self.skip_count += 1
                 continue
 
             if name is None or ".cold." in name or ".part." in name or ".constprop." in name:
+                self.skip_count += 1
                 continue
 
             if size is None or size == 0:
+                self.skip_count += 1
                 continue
 
             addr = self.get_v_addr(name)
             if addr is None:
+                self.skip_count += 1
                 continue
 
             objdump = subprocess.check_output(f"objdump -z --disassemble={name} {file_g}", shell=True)
@@ -720,19 +739,26 @@ class RkCheckFunctions(gdb.Command):
             else:
                 self.altinstr_dict[key] = [value]
 
-            i = i + alt_instr_sz
+            i += alt_instr_sz
 
     def fill_paravirt_dict(self):
         global file_g
         global v_off_g
 
         # paravirt_patch_site layout (read from elf section .parainstructions, size with padding: 16 bytes):
-        # .quad instr          <-- Adress to instruction = instr + v_off_G
+        # .quad instr          <-- Address of the instruction = instr + v_off_g
         # .byte instrtype
         # .byte len
         # .short clobbers
         # 4 byte padding
 
+        # struct paravirt_patch_site {
+        #     u8   *instr;     /* original instructions */
+        #     u8   instrtype;  /* type of this instruction */
+        #     u8   len;        /* length of original instruction */
+        #     u16  clobbers;   /* what registers you may clobber */
+        # };
+
         sec = self.f.get_section_by_name(".parainstructions")
         data = sec.data()
 
@@ -759,7 +785,33 @@ class RkCheckFunctions(gdb.Command):
             else:
                 self.paravirt_dict[key] = [value]
 
-            i = i + paravirt_patch_site_sz
+            i += paravirt_patch_site_sz
+
+    def fill_relatext_dict(self):
+        global file_g
+        global v_off_g
+
+        # typedef __u64	Elf64_Addr;
+        # typedef __u64	Elf64_Xword;
+        # typedef __s64	Elf64_Sxword;
+        #
+        # typedef struct elf64_rela {
+        #   Elf64_Addr    r_offset;  /* Location at which to apply the action */
+        #   Elf64_Xword   r_info;    /* index and type of relocation */
+        #   Elf64_Sxword  r_addend;  /* Constant addend used to compute value */
+        # } Elf64_Rela;
+
+        sec = self.f.get_section_by_name(".rela.text")
+        data = sec.data()
+
+        for reloc in sec.iter_relocations():
+            addr = reloc['r_offset'] + v_off_g
+            info = reloc['r_info']
+            addend = reloc['r_addend']
+
+            print('offset = %s' % hex(addr))
+            print('info = %s' % hex(info))
+            print('addend = %s' % hex(addend))
 
     def compare_functions(self):
         for name, (size, elf_bytes) in self.code_dict.items():
@@ -785,24 +837,36 @@ class RkCheckFunctions(gdb.Command):
             if len(live_bytes) > 1 and live_bytes[0:2] == "cc":
                 int3_chain = ''.join('c' * len(live_bytes))
                 if live_bytes == int3_chain:
+                    self.skip_count += 1
                     return
 
             if len(live_bytes) > 1 and live_bytes[0:2] == "00":
                 null_chain = ''.join('0' * len(live_bytes))
                 if live_bytes == null_chain:
+                    self.skip_count += 1
                     return
 
-            to_exclude_paravirt = [l for r in self.paravirt_dict[name] for l in list(r)] if name in self.paravirt_dict else []
-            to_exclude_altinstr = [l for r in self.altinstr_dict[name] for l in list(r)] if name in self.altinstr_dict else []
+            to_exclude_paravirt = [l for r in self.paravirt_dict[name]
+                                   for l in list(r)] if name in self.paravirt_dict else []
+
+            to_exclude_altinstr = [l for r in self.altinstr_dict[name]
+                                   for l in list(r)] if name in self.altinstr_dict else []
 
             to_exclude += to_exclude_paravirt + to_exclude_altinstr
+
             if to_exclude:
-                elf_bytes = "".join([elf_byte for i, elf_byte in enumerate(elf_bytes) if i not in to_exclude])
-                live_bytes = "".join([elf_byte for i, elf_byte in enumerate(live_bytes) if i not in to_exclude])
+                elf_bytes = "".join([elf_byte for i, elf_byte in enumerate(elf_bytes)
+                                     if i not in to_exclude])
+
+                live_bytes = "".join([elf_byte for i, elf_byte in enumerate(live_bytes)
+                                      if i not in to_exclude])
 
             if live_bytes != elf_bytes:
+                self.diff_count += 1
                 print(f"function `{name}` compromised, live bytes not equal to ELF bytes")
                 print(f"excluded: {to_exclude}, expected: {elf_bytes}, live: {live_bytes}")
+            else:
+                self.same_count += 1
 
     def get_v_addr(self, symbol):
         try:
diff --git a/runtime/md5sum b/runtime/md5sum
@@ -0,0 +1 @@
+15f3d058efb6b51899fb6a6175b565f9