Sfoglia il codice sorgente

Move vtable extraction logic to disambiguator

nosoop 10 mesi fa
parent
commit
29f6889a38
2 ha cambiato i file con 58 aggiunte e 31 eliminazioni
  1. 36 1
      src/smgdc/angr/vtable_disamb.py
  2. 22 30
      src/smgdc/vtable.py

+ 36 - 1
src/smgdc/angr/vtable_disamb.py

@@ -2,7 +2,8 @@
 
 import collections
 import itertools
-from typing import Iterable
+import operator
+from typing import Iterable, NamedTuple
 
 import itanium_demangler as demangler
 from cle.backends.symbol import Symbol
@@ -165,6 +166,9 @@ class VtableDisambiguator(angr.Analysis):
         # It's possible that this class inherited a base version of a method that is specialized
         # in a different part of the class hierarchy, so yield lists of subclasses of parents
         # that include the vtable index too.
+        #
+        # FIXME: This can iterate over unrelated classes; properly filter this down to the
+        # subclasses of the base that implemented this.
         max_vtsize = 4 * (vtidx + 2)
         for parent_vtsym in sorted(
             filter(lambda vt: vt.size > max_vtsize, self.superclass_map[vtsym]),
@@ -172,5 +176,36 @@ class VtableDisambiguator(angr.Analysis):
         ):
             yield self.subclass_map[parent_vtsym]
 
+    def get_vfptrs_from_table(self, vtsym: Symbol) -> list[list[int]]:
+        # returns a list of addresses for each vtable present on the class
+        class VTableFunction(NamedTuple):
+            tblidx: int
+            address: int
+
+        table_index = 0
+        function_list: list[VTableFunction] = []
+        vtable_range = enumerate(
+            range(vtsym.rebased_addr + 0x4 * 2, vtsym.rebased_addr + vtsym.size, 4)
+        )
+        for n, addr in vtable_range:
+            # get symbols that map to that address
+            deref = self.loader.fast_memory_load_pointer(addr)
+            fnsyms = set(self.syms_by_addr.get(deref) or set()) if deref else set()
+            if not fnsyms:
+                # vtable boundary; consume typeinfo so it doesn't get added to the list
+                # NOTE: we don't actually care if the indices skip; all that matters is that the
+                # functions are grouped correctly
+                table_index += 1
+                next(vtable_range)
+                continue
+            function_list.append(VTableFunction(table_index, deref))
+
+        return [
+            list(vfn.address for vfn in vtbl)
+            for tblidx, vtbl in itertools.groupby(
+                function_list, key=operator.attrgetter("tblidx")
+            )
+        ]
+
 
 angr.analyses.AnalysesHub.register_default("VtableDisambiguator", VtableDisambiguator)

+ 22 - 30
src/smgdc/vtable.py

@@ -96,40 +96,32 @@ def get_vtables_from_address(bin: "LinuxBinary", vt: Symbol) -> list[VTable]:
         tblidx: int
         possible_syms: set[Symbol]
 
-    table_index = 0
     function_list: list[VTableFunction] = []
-    vtable_range = enumerate(range(vt.rebased_addr + 0x4 * 2, vt.rebased_addr + vt.size, 4))
-    for n, addr in vtable_range:
-        # get symbols that map to that address
-        deref = bin.angr.loader.fast_memory_load_pointer(addr)
-        fnsyms = set(vtda.syms_by_addr.get(deref) or set()) if deref else set()
-        if not fnsyms:
-            # vtable boundary; consume typeinfo so it doesn't get added to the list
-            # NOTE: we don't actually care if the indices skip; all that matters is that the
-            # functions are grouped correctly
-            table_index += 1
-            next(vtable_range)
-            continue
-
-        if len(fnsyms) == 1:
-            function_list.append(VTableFunction(table_index, fnsyms))
-            continue
-        elif len(fnsyms) > 1:
-            # function in vtable is referenced by multiple names; perform disambiguation
-            matched_overload = None
-            for related in vtda.get_possible_vtable_set_candidates(vt, n):
-                matched_overload = vtda.resolve_ambiguous_vfn(n, fnsyms, related)
-                if matched_overload:
-                    break
+    vptr_lists = vtda.get_vfptrs_from_table(vt)
+    for table_index, vptrs in enumerate(vptr_lists):
+        for n, vptr in enumerate(vptrs):
+            # get symbols that map to that address
+            fnsyms = set(vtda.syms_by_addr.get(vptr) or set()) if vptr else set()
+
+            if len(fnsyms) == 1:
+                function_list.append(VTableFunction(table_index, fnsyms))
+                continue
+            elif len(fnsyms) > 1:
+                # function in vtable is referenced by multiple names; perform disambiguation
+                matched_overload = None
+                for related in vtda.get_possible_vtable_set_candidates(vt, n):
+                    matched_overload = vtda.resolve_ambiguous_vfn(n, fnsyms, related)
+                    if matched_overload:
+                        break
 
-            # it's possible that the other function(s) is/are resolveable.
-            # without doing multiple passes and saving the disambiguity somewhere it'll be difficult to match
+                # it's possible that the other function(s) is/are resolveable.
+                # without doing multiple passes and saving the disambiguity somewhere it'll be difficult to match
 
-            if matched_overload:
-                function_list.append(VTableFunction(table_index, {matched_overload}))
-                continue
+                if matched_overload:
+                    function_list.append(VTableFunction(table_index, {matched_overload}))
+                    continue
 
-        function_list.append(VTableFunction(table_index, fnsyms))
+            function_list.append(VTableFunction(table_index, fnsyms))
 
     for n, vfn in enumerate(function_list):
         if n == 0: