Browse Source

Begin migrating vtable helpers to disambiguator

nosoop 10 months ago
parent
commit
27f7fbb2a5
3 changed files with 187 additions and 176 deletions
  1. 175 0
      src/smgdc/angr/vtable_disamb.py
  2. 9 16
      src/smgdc/validate.py
  3. 3 160
      src/smgdc/vtable.py

+ 175 - 0
src/smgdc/angr/vtable_disamb.py

@@ -1,6 +1,7 @@
 #!/usr/bin/python3
 
 import collections
+import dataclasses
 import itertools
 import operator
 from typing import Iterable, NamedTuple
@@ -13,6 +14,9 @@ import angr
 
 from .. import demangler_helpers as dh
 
+# hotfix for demangler
+demangler._is_ctor_or_dtor = demangler.is_ctor_or_dtor
+
 
 def read_cstring(
     mem: Clemory, addr: int, encoding: str = "utf-8", chunk: int = 16, **kwargs
@@ -31,6 +35,34 @@ def read_cstring(
     return ""
 
 
+def _reorder_vfns_windows_estimate(symbols: list[Symbol], start_pos: int) -> list[Symbol]:
+    """Estimate the Windows ordering of one subclass-level slice of Linux vtable symbols.
+
+    Symbols whose demangled method name matches are grouped at the position where that
+    name first occurs, and each group is emitted in reverse of its input order.
+    ``start_pos`` is not used by the body.
+    """
+    groups: dict = {}
+    for position, sym in enumerate(symbols):
+        parsed = demangler.parse(sym.name)
+        # HACK: entries that do not demangle (references to __cxa_pure_virtual) are keyed
+        # by their position, which gives each one its own group and preserves its place
+        group_key = dh.extract_method_fname(parsed) if parsed else (parsed, position)
+        groups.setdefault(group_key, []).append(sym)
+
+    # on windows, overloads are made consecutive and in reverse of declared order;
+    # dict iteration follows insertion order (py3.7+), so groups keep first-seen order
+    reordered = []
+    for members in groups.values():
+        reordered.extend(reversed(members))
+    return reordered
+
+
+VTable = list[Symbol]
+
+
+# some virtual functions cannot be disambiguated outright, or need to check the resolved
+# state of the rest of the table - for the former we have to disambiguate them out-of-band
+#
+# ideally we would do some linear constraint on ordering, but for now we just spec the offset
+VTableConstraintDict = dict[str, dict[str, int]]
+
+
 class VtableDisambiguator(angr.Analysis):
     syms_by_addr: dict[int, set[Symbol]]
 
@@ -44,9 +76,12 @@ class VtableDisambiguator(angr.Analysis):
     # this way we can prune static functions from the candidate pool
     vtable_instances: set[demangler.Node]
 
+    vtable_constraint: VTableConstraintDict
+
     def __init__(self):
         self.loader = self.project.loader
         self.memory = self.loader.memory
+        self.vtable_constraint = {}
         self.analyze()
 
     def analyze(self) -> None:
@@ -230,5 +265,145 @@ class VtableDisambiguator(angr.Analysis):
             )
         ]
 
+    def get_windows_vtables_from(self, vt: Symbol) -> VTable:
+        """Return an estimated Windows-ordered version of the first vtable of ``vt``.
+
+        The first table from ``get_vtables_from_address`` is cut into slices at the
+        (increasing) sizes of the first pointer tables of the classes in
+        ``superclass_map[vt]``; each slice is filtered and then reordered on its own.
+        """
+        # slice boundaries; a new bound is only recorded when it exceeds the last one
+        bounds = [0]
+        for ancestor in reversed(self.superclass_map[vt]):
+            ancestor_first, *_ = self.get_vfptrs_from_table(ancestor)
+
+            # it's possible for a parent table to have more functions than its child
+            # e.g. CAutoGameSystemPerFrame has more functions present than CGameRules
+            if len(ancestor_first) > bounds[-1]:
+                bounds.append(len(ancestor_first))
+
+        first_table, *other_tables = self.get_vtables_from_address(vt)
+
+        # signatures that show up as non-virtual thunks in the remaining tables
+        thunked = {
+            dh.extract_method_signature(dm)
+            for other_sym in itertools.chain.from_iterable(other_tables)
+            if (dm := demangler.parse(other_sym.name)) and dm.kind == "nonvirt_thunk"
+        }
+
+        def belongs_in_slice(sym) -> bool:
+            parsed = demangler.parse(sym.name)
+            if not parsed:
+                # __cxa_pure_virtual returns None here; we still add it to the vtable slice
+                return True
+            # MSVC only provides one dtor, so here we'll use the deleting one (D0)
+            if dh.is_dtor(parsed) and dh.get_dtor_type(parsed) != "deleting":
+                return False
+            # filter MI thunks
+            is_thunked = (
+                not demangler.is_ctor_or_dtor(parsed)
+                and dh.extract_method_signature(parsed) in thunked
+            )
+            return not is_thunked
+
+        windows_table = []
+        for low, high in itertools.pairwise(bounds):
+            # we can only reorder overloads within the class they were initially specified
+            # e.g. CTFPlayer's ChangeTeam cannot be merged with CBaseEntity's
+            slice_syms = [sym for sym in first_table[low:high] if belongs_in_slice(sym)]
+            windows_table.extend(_reorder_vfns_windows_estimate(slice_syms, low))
+
+        return windows_table
+
+    def get_constrained_vfn(
+        self, vt: Symbol, vfnidx: int, candidate_fnsyms: set[Symbol]
+    ) -> Symbol | None:
+        """Return the candidate that ``vtable_constraint`` pins to slot ``vfnidx``, if any.
+
+        Classes in ``superclass_map[vt]`` are checked in order.  For each one that has an
+        entry in ``vtable_constraint`` (keyed by symbol name), the first candidate whose
+        name maps to ``vfnidx`` is returned.  ``None`` means no constraint matched.
+        """
+        for ancestor in self.superclass_map[vt]:
+            if ancestor.name in self.vtable_constraint:
+                pinned_slots = self.vtable_constraint[ancestor.name]
+                pinned = next(
+                    (
+                        candidate
+                        for candidate in candidate_fnsyms
+                        if candidate.name in pinned_slots
+                        and pinned_slots[candidate.name] == vfnidx
+                    ),
+                    None,
+                )
+                if pinned is not None:
+                    return pinned
+        return None
+
+    def get_vtables_from_address(self, vt: Symbol) -> list[VTable]:
+        """Resolve every slot of each vtable present on the class ``vt`` to one symbol.
+
+        Returns one list of symbols per non-empty pointer list produced by
+        ``get_vfptrs_from_table(vt)``, in the same order.
+
+        Raises ``Exception`` when a slot does not end up with exactly one candidate symbol
+        (this includes slots whose address has no known symbol); the position reported in
+        the message is the index into the flattened list of slots across all tables.
+        Raises ``AssertionError`` if the ``__cxa_call_unexpected`` fallback symbol is needed
+        but cannot be found.
+        """
+
+        # per-slot record: which table the slot belongs to and the symbols it may refer to
+        @dataclasses.dataclass
+        class VTableFunction:
+            tblidx: int
+            possible_syms: set[Symbol]
+
+        # symbols settled by resolve_ambiguous_vfn; pruned from unresolved slots afterwards
+        disambiguated_functions = set()
+        function_list: list[VTableFunction] = []
+        vptr_lists = self.get_vfptrs_from_table(vt)
+        for table_index, vptrs in enumerate(vptr_lists):
+            for n, vptr in enumerate(vptrs):
+                # get symbols that map to that address
+                if vptr == 0:
+                    # HACK: some virtual destructors got optimized out and are represented by nullptrs
+                    for parent_vt in self.superclass_map[vt]:
+                        # HACK: in that case we try to match functions from the parent
+                        # NOTE(review): the 0x4 stride and 0x8 offset presumably assume
+                        # 32-bit pointers and a two-word vtable header - confirm
+                        vptr = (
+                            self.loader.fast_memory_load_pointer(
+                                parent_vt.rebased_addr + (0x4 * n) + 0x8
+                            )
+                            or 0
+                        )
+                        if vptr:
+                            break
+
+                    if vptr == 0:
+                        # no parent supplied a pointer either; the __cxa_call_unexpected
+                        # symbol becomes this slot's only candidate
+                        call_unexpected = self.loader.find_symbol("__cxa_call_unexpected")
+                        assert call_unexpected
+                        function_list.append(VTableFunction(table_index, {call_unexpected}))
+                        continue
+
+                # copy the set so later changes never touch self.syms_by_addr
+                fnsyms = set(self.syms_by_addr.get(vptr) or set()) if vptr else set()
+
+                if len(fnsyms) == 1:
+                    function_list.append(VTableFunction(table_index, fnsyms))
+                    continue
+                elif len(fnsyms) > 1:
+                    # out-of-band constraints take priority over automatic disambiguation
+                    constrained_sym = self.get_constrained_vfn(vt, n, fnsyms)
+                    if constrained_sym:
+                        function_list.append(VTableFunction(table_index, {constrained_sym}))
+                        continue
+
+                    # function in vtable is referenced by multiple names; perform disambiguation
+                    matched_overload = self.resolve_ambiguous_vfn(
+                        n, fnsyms, self.get_possible_vtable_set_candidates(vt, n)
+                    )
+
+                    # it's possible that the other function(s) is/are resolvable.
+
+                    if matched_overload:
+                        # within a vtable we expect a non-extern symbol to resolve exactly once,
+                        # so we can eliminate it from candidacy elsewhere
+                        disambiguated_functions.add(matched_overload)
+                        function_list.append(VTableFunction(table_index, {matched_overload}))
+                        continue
+
+                # still unresolved (zero or several candidates); settled in the pass below
+                function_list.append(VTableFunction(table_index, fnsyms))
+
+        # second pass: drop symbols that were disambiguated in other slots from every
+        # slot that is still unresolved; n here indexes the flattened slot list
+        for n, vfn in enumerate(function_list):
+            if len(vfn.possible_syms) == 1:
+                continue
+
+            remaining_syms = vfn.possible_syms - disambiguated_functions
+            if len(remaining_syms) == 1:
+                vfn.possible_syms = remaining_syms
+                continue
+
+            # we should never receive an empty ``VTableFunction.possible_syms``
+            # for now we need to assert that a function address is unambiguous given the context
+            vt_name = demangler.parse(vt.name)
+            candidate_names = set(sym.name for sym in remaining_syms)
+            raise Exception(
+                f"Ambiguity in {vt_name} position {n}; candidates {candidate_names}"
+            )
+
+        # every slot now holds exactly one symbol; regroup the flat list by table index
+        return [
+            list(vfn.possible_syms.pop() for vfn in vtbl)
+            for tblidx, vtbl in itertools.groupby(
+                function_list, key=operator.attrgetter("tblidx")
+            )
+        ]
+
 
 angr.analyses.AnalysesHub.register_default("VtableDisambiguator", VtableDisambiguator)

+ 9 - 16
src/smgdc/validate.py

@@ -19,7 +19,7 @@ import angr
 import msgspec
 
 from . import vtable as vt_helpers
-from .angr.vtable_disamb import VtableDisambiguator
+from .angr.vtable_disamb import VTableConstraintDict, VtableDisambiguator
 from .types import ByteSequence, ByteSignature, Code, IntLiteral
 
 KEY_AS_IS = string.Template("${name}")
@@ -50,12 +50,6 @@ def convert_types(*types):
 # or bytesigs + offsets
 ResultValues = dict[string.Template, typing.Any]
 
-# some virtual functions cannot be disambiguated outright, or need to check the resolved
-# state of the rest of the table - for the former we have to disambiguate them out-of-band
-#
-# ideally we would do some linear constraint on ordering, but for now we just spec the offset
-VTableConstraintDict = dict[str, dict[str, int]]
-
 
 class BaseBinary:
     path: pathlib.Path
@@ -96,21 +90,20 @@ class WindowsBinary(BaseBinary):
 
 
 class LinuxBinary(BaseBinary):
-    vtable_constraint: VTableConstraintDict
+    constraints_file: pathlib.Path
 
     def __init__(self, path: pathlib.Path, cache_path: pathlib.Path | None = None):
         super().__init__(path, cache_path)
-        self.vtable_constraint = {}
-
-        constraints_file = (cache_path or pathlib.Path()) / f"{self.hash}.constraints.toml"
-        if constraints_file.exists():
-            self.vtable_constraint = msgspec.toml.decode(
-                constraints_file.read_bytes(), type=VTableConstraintDict
-            )
+        self.constraints_file = (cache_path or pathlib.Path()) / f"{self.hash}.constraints.toml"
 
     @functools.cached_property
     def vtable_disambiguator(self) -> VtableDisambiguator:
-        return self.angr.analyses.VtableDisambiguator()
+        disamb = self.angr.analyses.VtableDisambiguator()
+        if self.constraints_file.exists():
+            disamb.vtable_constraint = msgspec.toml.decode(
+                self.constraints_file.read_bytes(), type=VTableConstraintDict
+            )
+        return disamb
 
 
 PlatformBinary = WindowsBinary | LinuxBinary

+ 3 - 160
src/smgdc/vtable.py

@@ -3,182 +3,25 @@
 # vtable helpers
 # this should probably be cleaned up and moved somewhere else
 
-import collections
-import dataclasses
-import itertools
-import operator
 import typing
 
-import itanium_demangler as demangler
 from cle.backends.symbol import Symbol
 
-from . import demangler_helpers as dh
-
 if typing.TYPE_CHECKING:
     from .validate import LinuxBinary
 
 VTable = list[Symbol]
 
-# hotfix for demangler
-demangler._is_ctor_or_dtor = demangler.is_ctor_or_dtor
-
-
-def reorder_vfns_windows_estimate(symbols: list[Symbol], start_pos: int) -> list[Symbol]:
-    # reorders a given subclass-level slice of linux symbols to reflect windows ordering
-    name_buckets = collections.defaultdict(list)
-    for n, symbol in enumerate(symbols):
-        # collect overrides into buckets based on function name
-        dmsym = demangler.parse(symbol.name)
-        if dmsym:
-            name_buckets[dh.extract_method_fname(dmsym)].append(symbol)
-        else:
-            # HACK: preserves positions of references to __cxa_pure_virtual
-            name_buckets[(dmsym, n)].append(symbol)
-
-    # on windows, overloads are made consecutive and in reverse of declared order
-    # iteration order is guaranteed as of py3.7+ to be the insertion order,
-    # so this should output symbols otherwise in their original order
-    return list(itertools.chain.from_iterable(reversed(syms) for syms in name_buckets.values()))
-
 
 def get_windows_vtables_from(bin: "LinuxBinary", vt: Symbol) -> VTable:
-    vtda = bin.vtable_disambiguator
-
-    vt_parent_spans = [0]
-    for vt_parent in reversed(vtda.superclass_map[vt]):
-        vt_parent_first, *_ = vtda.get_vfptrs_from_table(vt_parent)
-
-        # it's possible for a parent table to have more functions than its child
-        # e.g. CAutoGameSystemPerFrame has more functions present than CGameRules
-        span = len(vt_parent_first)
-        if span > vt_parent_spans[-1]:
-            vt_parent_spans.append(span)
-
-    vt_first, *vt_others = get_vtables_from_address(bin, vt)
-    thunk_fns = set()
-    for vt_other in vt_others:
-        for sym in vt_other:
-            dmsym = demangler.parse(sym.name)
-            if dmsym and dmsym.kind == "nonvirt_thunk":
-                thunk_fns.add(dh.extract_method_signature(dmsym))
-
-    vt_out = []
-    for vt_low, vt_high in itertools.pairwise(vt_parent_spans):
-        # we can only reorder overloads within the class they were initially specified
-        # e.g. CTFPlayer's ChangeTeam cannot be merged with CBaseEntity's
-        class_vfns = []
-
-        for sym in vt_first[vt_low:vt_high]:
-            # filter MI thunks
-            dmsym = demangler.parse(sym.name)
-            if dmsym:
-                # MSVC only provides one dtor, so here we'll use the deleting one (D0)
-                if dh.is_dtor(dmsym) and dh.get_dtor_type(dmsym) != "deleting":
-                    continue
-                elif (
-                    not demangler.is_ctor_or_dtor(dmsym)
-                    and dh.extract_method_signature(dmsym) in thunk_fns
-                ):
-                    continue
-            else:
-                # __cxa_pure_virtual returns None here; we still add it to the vtable slice
-                pass
-            class_vfns.append(sym)
-        vt_out.extend(reorder_vfns_windows_estimate(class_vfns, vt_low))
-
-    return vt_out
+    return bin.vtable_disambiguator.get_windows_vtables_from(vt)
 
 
 def get_constrained_vfn(
     bin: "LinuxBinary", vt: Symbol, vfnidx: int, candidate_fnsyms: set[Symbol]
 ) -> Symbol | None:
-    for psym in bin.vtable_disambiguator.superclass_map[vt]:
-        if psym.name not in bin.vtable_constraint:
-            continue
-        vt_const = bin.vtable_constraint[psym.name]
-        for fnsym in candidate_fnsyms:
-            if fnsym.name in vt_const and vt_const[fnsym.name] == vfnidx:
-                return fnsym
-    return None
+    return bin.vtable_disambiguator.get_constrained_vfn(vt, vfnidx, candidate_fnsyms)
 
 
 def get_vtables_from_address(bin: "LinuxBinary", vt: Symbol) -> list[VTable]:
-    vtda = bin.vtable_disambiguator
-
-    # returns a list of vtables for each vtable present on the class
-    @dataclasses.dataclass
-    class VTableFunction:
-        tblidx: int
-        possible_syms: set[Symbol]
-
-    disambiguated_functions = set()
-    function_list: list[VTableFunction] = []
-    vptr_lists = vtda.get_vfptrs_from_table(vt)
-    for table_index, vptrs in enumerate(vptr_lists):
-        for n, vptr in enumerate(vptrs):
-            # get symbols that map to that address
-            if vptr == 0:
-                # HACK: some virtual destructors got optimized out and are represented by nullptrs
-                for parent_vt in bin.vtable_disambiguator.superclass_map[vt]:
-                    # HACK: in that case we try to match functions from the parent
-                    vptr = (
-                        bin.angr.loader.fast_memory_load_pointer(
-                            parent_vt.rebased_addr + (0x4 * n) + 0x8
-                        )
-                        or 0
-                    )
-                    if vptr:
-                        break
-
-                if vptr == 0:
-                    call_unexpected = bin.angr.loader.find_symbol("__cxa_call_unexpected")
-                    assert call_unexpected
-                    function_list.append(VTableFunction(table_index, {call_unexpected}))
-                    continue
-
-            fnsyms = set(vtda.syms_by_addr.get(vptr) or set()) if vptr else set()
-
-            if len(fnsyms) == 1:
-                function_list.append(VTableFunction(table_index, fnsyms))
-                continue
-            elif len(fnsyms) > 1:
-                constrained_sym = get_constrained_vfn(bin, vt, n, fnsyms)
-                if constrained_sym:
-                    function_list.append(VTableFunction(table_index, {constrained_sym}))
-                    continue
-
-                # function in vtable is referenced by multiple names; perform disambiguation
-                matched_overload = vtda.resolve_ambiguous_vfn(
-                    n, fnsyms, vtda.get_possible_vtable_set_candidates(vt, n)
-                )
-
-                # it's possible that the other function(s) is/are resolveable.
-
-                if matched_overload:
-                    # within a vtable we expect a non-extern symbol to resolvee exactly once,
-                    # so we can eliminate it from candidacy elsewhere
-                    disambiguated_functions.add(matched_overload)
-                    function_list.append(VTableFunction(table_index, {matched_overload}))
-                    continue
-
-            function_list.append(VTableFunction(table_index, fnsyms))
-
-    for n, vfn in enumerate(function_list):
-        if len(vfn.possible_syms) == 1:
-            continue
-
-        remaining_syms = vfn.possible_syms - disambiguated_functions
-        if len(remaining_syms) == 1:
-            vfn.possible_syms = remaining_syms
-            continue
-
-        # we should never receive an empty ``VTableFunction.possible_syms``
-        # for now we need to assert that a function address is unambiguous given the context
-        vt_name = demangler.parse(vt.name)
-        candidate_names = set(sym.name for sym in remaining_syms)
-        raise Exception(f"Ambiguity in {vt_name} position {n}; candidates {candidate_names}")
-
-    return [
-        list(vfn.possible_syms.pop() for vfn in vtbl)
-        for tblidx, vtbl in itertools.groupby(function_list, key=operator.attrgetter("tblidx"))
-    ]
+    return bin.vtable_disambiguator.get_vtables_from_address(vt)