Browse Source

Only take deleting dtors when translating to Windows

nosoop 10 months ago
parent
commit
c0a7f893c1
3 changed files with 69 additions and 12 deletions
  1. 30 0
      src/smgdc/demangler_helpers.py
  2. 22 12
      src/smgdc/vtable.py
  3. 17 0
      tests/test_demangler.py

+ 30 - 0
src/smgdc/demangler_helpers.py

@@ -8,6 +8,36 @@
 import itanium_demangler as demangler
 
 
+def is_dtor(node: demangler.Node) -> bool:  # True when the parsed function's name node is of kind "dtor"
+    def _extract(node: demangler.Node) -> bool:
+        match node:
+            case node if node.kind == "qual_name":
+                return _extract(node.value[-1])  # recurse into the last element of the qual_name's value
+            case node if node.kind == "dtor":
+                return True
+        return False  # any other node kind is not a destructor
+
+    if node.kind == "func":
+        return _extract(node.name)
+
+    raise ValueError(f"{node} is not a function")  # only nodes of kind "func" are accepted
+
+
+def get_dtor_type(node: demangler.Node) -> str:  # dtor variant of a parsed function, e.g. "deleting" or "base"
+    def _extract(node: demangler.Node) -> str:
+        match node:
+            case node if node.kind == "qual_name":
+                return _extract(node.value[-1])  # recurse into the last element of the qual_name's value
+            case node if node.kind == "dtor":
+                return node.value  # the "dtor" node's value is the variant that callers compare against
+        raise ValueError(f"Unexpected node {node!r}")  # unlike is_dtor, a non-dtor name is an error here
+
+    if node.kind == "func":
+        return _extract(node.name)
+
+    raise ValueError(f"{node} is not a function")  # only nodes of kind "func" are accepted
+
+
 def extract_method_classname(node: demangler.Node) -> tuple[demangler.Node, ...]:
     def _extract(node: demangler.Node):
         match node:

+ 22 - 12
src/smgdc/vtable.py

@@ -49,7 +49,7 @@ def get_windows_vtables_from(bin: "LinuxBinary", vt: Symbol) -> VTable:
     vtda = bin.vtable_disambiguator
     vt_typeinfo = bin.angr.loader.memory.unpack_word(vt.rebased_addr + 0x4)
 
-    vt_parent_spans = [1]
+    vt_parent_spans = [0]
     for typeinfo_ptr, name in reversed(list(vtda.dump_class_parents(vt_typeinfo))):
         vt_parent = bin.angr.loader.find_symbol(f"_ZTV{name}")
         if not vt_parent:
@@ -80,7 +80,10 @@ def get_windows_vtables_from(bin: "LinuxBinary", vt: Symbol) -> VTable:
             # filter MI thunks
             dmsym = demangler.parse(sym.name)
             if dmsym:
-                if (
+                # MSVC only provides one dtor, so here we'll use the deleting one (D0)
+                if dh.is_dtor(dmsym) and dh.get_dtor_type(dmsym) != "deleting":
+                    continue
+                elif (
                     not demangler.is_ctor_or_dtor(dmsym)
                     and dh.extract_method_signature(dmsym) in thunk_fns
                 ):
@@ -124,11 +127,22 @@ def get_vtables_from_address(bin: "LinuxBinary", vt: Symbol) -> list[VTable]:
             # get symbols that map to that address
             if vptr == 0:
                 # HACK: some virtual destructors got optimized out and are represented by nullptrs
-                # TODO: it might be better to pull destructors from the parent
-                call_unexpected = bin.angr.loader.find_symbol("__cxa_call_unexpected")
-                assert call_unexpected
-                function_list.append(VTableFunction(table_index, {call_unexpected}))
-                continue
+                for parent_vt in bin.vtable_disambiguator.superclass_map[vt]:
+                    # HACK: in that case we try to match functions from the parent
+                    vptr = (
+                        bin.angr.loader.fast_memory_load_pointer(
+                            parent_vt.rebased_addr + (0x4 * n) + 0x8
+                        )
+                        or 0
+                    )
+                    if vptr:
+                        break
+
+                if vptr == 0:
+                    call_unexpected = bin.angr.loader.find_symbol("__cxa_call_unexpected")
+                    assert call_unexpected
+                    function_list.append(VTableFunction(table_index, {call_unexpected}))
+                    continue
 
             fnsyms = set(vtda.syms_by_addr.get(vptr) or set()) if vptr else set()
 
@@ -160,11 +174,7 @@ def get_vtables_from_address(bin: "LinuxBinary", vt: Symbol) -> list[VTable]:
             function_list.append(VTableFunction(table_index, fnsyms))
 
     for n, vfn in enumerate(function_list):
-        if n == 0:
-            # HACK: skip duplicated references to destructor
-            # we should be doing this at the disambiguation stage
-            continue
-        elif len(vfn.possible_syms) == 1:
+        if len(vfn.possible_syms) == 1:
             continue
 
         remaining_syms = vfn.possible_syms - disambiguated_functions

+ 17 - 0
tests/test_demangler.py

@@ -155,3 +155,20 @@ def test_function_overloaded(fnsyms: Iterable[str], expected: bool):
     node_syms = map(dh.extract_method_fname, map(dm.parse, fnsyms))
     for a, b in itertools.permutations(node_syms, 2):
         assert a == b
+
+
+@pytest.mark.parametrize(
+    "sym,expected",
+    [
+        (
+            "_ZN12CTFGameRulesD0Ev",  # D0 symbol
+            "deleting",
+        ),
+        (
+            "_ZN12CTFGameRulesD2Ev",  # D2 symbol
+            "base",
+        ),
+    ],
+)
+def test_function_dtor(sym: str, expected: str):  # get_dtor_type must report the expected variant for each symbol
+    assert dh.get_dtor_type(dm.parse(sym)) == expected