Переглянути джерело

Change vfn constraint system to use a solver

This allows us to specify virtual functions in a way that should
be robust against future patches.

We now fully require msgspec as part of the angr analysis, since
working with tagged unions is far easier that way.
nosoop 10 місяців тому
батько
коміт
f88017cf40
2 змінених файлів з 125 додано та 20 видалено
  1. 45 0
      README.md
  2. 80 20
      src/smgdc/angr/vtable_disamb.py

+ 45 - 0
README.md

@@ -57,3 +57,48 @@ Types of specifications include:
 - `bytesig`: confirms the presence of a byte sequence
 - `vfn`: takes a virtual method symbol and gets the vtable index for it on Linux and Windows
 (guesstimate on the latter)
+
+### Constraint file
+
+> [!IMPORTANT]
+> The constraint file specification is not finalized.  Please ensure that you are reading the
+> documentation for the version of `smgdc` that you are working with.
+
+Linux binaries are sometimes compiled with flags that cause a many-to-one mapping of symbols to
+function addresses (in other words: link time optimization).  While the application does attempt
+to uniquely identify symbols using other markers, there's sometimes insufficient information to
+do so.
+
+A constraint file is used as a last resort for end users to manually identify which symbols map
+to which positions for a given virtual table and its subclasses.  This is done by specifying
+relative ordering constraints for each symbol, allowing for reuse between binary revisions
+(with the assumption that vtables aren't reordered across them).
+
+Example:
+
+```toml
+[[_ZTV17CBaseCombatWeapon]]
+constraint = "soft"
+symbols = [
+	"_ZN17CBaseCombatWeapon6DeleteEv",
+	"_ZN17CBaseCombatWeapon4KillEv",
+]
+
+[[_ZTV17CBaseCombatWeapon]]
+constraint = "consecutive"
+symbols = [
+	"_ZN17CBaseCombatWeapon27WeaponRangeAttack1ConditionEff",
+	"_ZN17CBaseCombatWeapon27WeaponRangeAttack2ConditionEff",
+	"_ZN17CBaseCombatWeapon27WeaponMeleeAttack1ConditionEff",
+	"_ZN17CBaseCombatWeapon27WeaponMeleeAttack2ConditionEff",
+]
+```
+
+In this `CBaseCombatWeapon` vtable, a *soft order* constraint is applied such that the offset
+of `CBaseCombatWeapon::Delete` is lower than that of `CBaseCombatWeapon::Kill` (other entries
+may sit between them) — both symbols point to the same address, and both are present at two
+offsets, so this ensures that once solved, each symbol will be assigned the correct offset.
+
+A *consecutive order* constraint is also applied such that
+`CBaseCombatWeapon::WeaponRangeAttack1Condition` and the symbols listed after it occupy
+adjacent offsets, each exactly one higher than the previous symbol's.

+ 80 - 20
src/smgdc/angr/vtable_disamb.py

@@ -7,7 +7,9 @@ import itertools
 import operator
 from typing import Iterable, NamedTuple
 
+import claripy
 import itanium_demangler as demangler
+import msgspec
 from cle.backends.symbol import Symbol
 from cle.memory import Clemory
 
@@ -79,9 +81,40 @@ VTable = list[Symbol]
 
 # some virtual functions cannot be disambiguated outright, or need to check the resolved
 # state of the rest of the table - for the former we have to disambiguate them out-of-band
-#
-# ideally we would do some linear constraint on ordering, but for now we just spec the offset
-VTableConstraintDict = dict[str, dict[str, int]]
+class BaseConstraint(msgspec.Struct, tag_field="constraint"):
+    def apply(self, solver: claripy.Solver, symtable: dict[str, claripy.ast.bv.BV]) -> None:
+        raise NotImplementedError("BaseConstraint.apply is abstract")
+
+
+class SoftOrderConstraint(BaseConstraint, tag="soft"):
+    # a constraint where symbols must be in ascending order but gaps may be present
+    symbols: list[str]
+
+    def apply(self, solver: claripy.Solver, symtable: dict[str, claripy.ast.bv.BV]) -> None:
+        if not all(x in symtable for x in self.symbols):
+            return
+        solver.add(
+            claripy.And(
+                *(symtable[b] > symtable[a] for a, b in itertools.pairwise(self.symbols))
+            )
+        )
+
+
+class ConsecutiveOrderConstraint(BaseConstraint, tag="consecutive"):
+    # a constraint where symbols must have offsets one higher than the previous
+    symbols: list[str]
+
+    def apply(self, solver: claripy.Solver, symtable: dict[str, claripy.ast.bv.BV]) -> None:
+        if not all(x in symtable for x in self.symbols):
+            return
+        solver.add(
+            claripy.And(
+                *(symtable[b] == symtable[a] + 1 for a, b in itertools.pairwise(self.symbols))
+            )
+        )
+
+
+VTableConstraintDict = dict[str, list[SoftOrderConstraint | ConsecutiveOrderConstraint]]
 
 
 class VtableDisambiguator(angr.Analysis):
@@ -321,18 +354,6 @@ class VtableDisambiguator(angr.Analysis):
 
         return vt_out
 
-    def get_constrained_vfn(
-        self, vt: Symbol, vfnidx: int, candidate_fnsyms: set[Symbol]
-    ) -> Symbol | None:
-        for psym in self.superclass_map[vt]:
-            if psym.name not in self.vtable_constraint:
-                continue
-            vt_const = self.vtable_constraint[psym.name]
-            for fnsym in candidate_fnsyms:
-                if fnsym.name in vt_const and vt_const[fnsym.name] == vfnidx:
-                    return fnsym
-        return None
-
     def get_vfn_from_parent(self, vt: Symbol, vtidx: int) -> int:
         # HACK: some virtual destructors got optimized out and are represented by nullptrs
         #       in that case we try to match functions from the parent
@@ -369,11 +390,6 @@ class VtableDisambiguator(angr.Analysis):
                 fnsyms = set(self.syms_by_addr.get(vptr) or set())
 
                 if len(fnsyms) > 1:
-                    constrained_sym = self.get_constrained_vfn(vt, n, fnsyms)
-                    if constrained_sym:
-                        function_list.append(VTableFunction(table_index, {constrained_sym}))
-                        continue
-
                     # function in vtable is referenced by multiple names; perform disambiguation
                     matched_overload = self.resolve_ambiguous_vfn(
                         n, fnsyms, self.get_possible_vtable_set_candidates(vt, n)
@@ -405,6 +421,50 @@ class VtableDisambiguator(angr.Analysis):
                 )
                 vfn.possible_syms -= disambiguated_functions
 
+        if self.vtable_constraint:
+            solver = claripy.Solver()
+            solver_asts = {}
+            positional_constraints = collections.defaultdict(set)
+            for n, vfn in enumerate(function_list):
+                if len(vfn.possible_syms) == 1:
+                    sym = set(vfn.possible_syms).pop()
+                    # symbolic names for constant values - we don't use these ourselves; they're here
+                    # so user-defined tooling can specify them relative to variables as needed
+                    solver_asts[sym.name] = claripy.BVV(n, 10)
+                    continue
+
+                for sym in vfn.possible_syms:
+                    # constrain variable values to those defined as ambiguous
+                    positional_constraints[sym].add(n)
+
+            for sym, positions in positional_constraints.items():
+                if sym.name not in solver_asts:
+                    solver_asts[sym.name] = claripy.BVS(sym.name, 10)
+                solver.add(claripy.Or(*(solver_asts[sym.name] == n for n in positions)))
+
+            for pvt in self.superclass_map[vt]:
+                constraints = self.vtable_constraint.get(pvt.name)
+                if not constraints:
+                    continue
+                for constraint in constraints:
+                    constraint.apply(solver, solver_asts)
+
+            solutions = solver.batch_eval(solver_asts.values(), 2)
+            if len(solutions) == 1:
+                unique_solution = {v: k for k, v in zip(solver_asts.keys(), solutions[0])}
+                for n, vfn in enumerate(function_list):
+                    if (
+                        len(vfn.possible_syms) > 1
+                        and unique_solution
+                        and n in unique_solution.keys()
+                    ):
+                        resolved_syms = {
+                            sym
+                            for sym in iter(vfn.possible_syms)
+                            if sym.name == unique_solution[n]
+                        }
+                        vfn.possible_syms = resolved_syms
+
         # final uniqueness pass
         # raise exception if we still have multiple possible symbols for a given position
         for n, vfn in enumerate(function_list):