Parcourir la source

Change vfn constraint system to use a solver

This allows us to specify virtual functions in a way that should
be robust against future patches.

We now fully require msgspec as part of the angr analysis, since
operating tagged unions is far easier that way.
nosoop il y a 10 mois
Parent
commit
f88017cf40
2 fichiers modifiés avec 125 ajouts et 20 suppressions
  1. 45 0
      README.md
  2. 80 20
      src/smgdc/angr/vtable_disamb.py

+ 45 - 0
README.md

@@ -57,3 +57,48 @@ Types of specifications include:
 - `bytesig`: confirms the presence of a byte sequence
 - `vfn`: takes a virtual method symbol and gets the vtable index for it on Linux and Windows
 (guesstimate on the latter)
+
+### Constraint file
+
+> [!IMPORTANT]
+> The constraint file specification is not finalized.  Please ensure that you are reading the
+> documentation for the version of `smgdc` that you are working with.
+
+Linux binaries are sometimes compiled with flags that cause a many-to-one mapping of symbols to
+function addresses (in other words: link time optimization).  While the application does attempt
+to uniquely identify symbols using other markers, there's sometimes insufficient information to
+do so.
+
+A constraint file is used as a last resort for end users to manually identify which symbols map
+to which positions for a given virtual table and its subclasses.  This is done by specifying
+relative ordering constraints for each symbol, allowing for reuse between binary revisions
+(with the assumption that vtables aren't reordered across them).
+
+Example:
+
+```
+[[_ZTV17CBaseCombatWeapon]]
+constraint = "soft"
+symbols = [
+	"_ZN17CBaseCombatWeapon6DeleteEv",
+	"_ZN17CBaseCombatWeapon4KillEv",
+]
+
+[[_ZTV17CBaseCombatWeapon]]
+constraint = "consecutive"
+symbols = [
+	"_ZN17CBaseCombatWeapon27WeaponRangeAttack1ConditionEff",
+	"_ZN17CBaseCombatWeapon27WeaponRangeAttack2ConditionEff",
+	"_ZN17CBaseCombatWeapon27WeaponMeleeAttack1ConditionEff",
+	"_ZN17CBaseCombatWeapon27WeaponMeleeAttack2ConditionEff",
+]
+```
+
+In this `CBaseCombatWeapon` vtable, a *soft order* constraint is applied such that the offset
+of `CBaseCombatWeapon::Delete` is lower than `CBaseCombatWeapon::Kill` — both symbols
+point to the same address, and both are present at two offsets, so this ensures that once
+solved, each symbol will be assigned the correct offset.
+
+A *consecutive order* constraint is also applied such that
+`CBaseCombatWeapon::WeaponRangeAttack1Condition` and subsequent symbols have monotonically
+increasing offsets.

+ 80 - 20
src/smgdc/angr/vtable_disamb.py

@@ -7,7 +7,9 @@ import itertools
 import operator
 from typing import Iterable, NamedTuple
 
+import claripy
 import itanium_demangler as demangler
+import msgspec
 from cle.backends.symbol import Symbol
 from cle.memory import Clemory
 
@@ -79,9 +81,40 @@ VTable = list[Symbol]
 
 # some virtual functions cannot be disambiguated outright, or need to check the resolved
 # state of the rest of the table - for the former we have to disambiguate them out-of-band
-#
-# ideally we would do some linear constraint on ordering, but for now we just spec the offset
-VTableConstraintDict = dict[str, dict[str, int]]
+class BaseConstraint(msgspec.Struct, tag_field="constraint"):
+    def apply(self, solver: claripy.Solver, symtable: dict[str, claripy.ast.bv.BV]) -> None:
+        raise NotImplementedError("BaseConstraint.apply is abstract")
+
+
+class SoftOrderConstraint(BaseConstraint, tag="soft"):
+    # a constraint where symbols must be in ascending order but gaps may be present
+    symbols: list[str]
+
+    def apply(self, solver: claripy.Solver, symtable: dict[str, claripy.ast.bv.BV]) -> None:
+        if not all(x in symtable for x in self.symbols):
+            return
+        solver.add(
+            claripy.And(
+                *(symtable[b] > symtable[a] for a, b in itertools.pairwise(self.symbols))
+            )
+        )
+
+
+class ConsecutiveOrderConstraint(BaseConstraint, tag="consecutive"):
+    # a constraint where symbols must have offsets one higher than the previous
+    symbols: list[str]
+
+    def apply(self, solver: claripy.Solver, symtable: dict[str, claripy.ast.bv.BV]) -> None:
+        if not all(x in symtable for x in self.symbols):
+            return
+        solver.add(
+            claripy.And(
+                *(symtable[b] == symtable[a] + 1 for a, b in itertools.pairwise(self.symbols))
+            )
+        )
+
+
+VTableConstraintDict = dict[str, list[SoftOrderConstraint | ConsecutiveOrderConstraint]]
 
 
 class VtableDisambiguator(angr.Analysis):
@@ -321,18 +354,6 @@ class VtableDisambiguator(angr.Analysis):
 
         return vt_out
 
-    def get_constrained_vfn(
-        self, vt: Symbol, vfnidx: int, candidate_fnsyms: set[Symbol]
-    ) -> Symbol | None:
-        for psym in self.superclass_map[vt]:
-            if psym.name not in self.vtable_constraint:
-                continue
-            vt_const = self.vtable_constraint[psym.name]
-            for fnsym in candidate_fnsyms:
-                if fnsym.name in vt_const and vt_const[fnsym.name] == vfnidx:
-                    return fnsym
-        return None
-
     def get_vfn_from_parent(self, vt: Symbol, vtidx: int) -> int:
         # HACK: some virtual destructors got optimized out and are represented by nullptrs
         #       in that case we try to match functions from the parent
@@ -369,11 +390,6 @@ class VtableDisambiguator(angr.Analysis):
                 fnsyms = set(self.syms_by_addr.get(vptr) or set())
 
                 if len(fnsyms) > 1:
-                    constrained_sym = self.get_constrained_vfn(vt, n, fnsyms)
-                    if constrained_sym:
-                        function_list.append(VTableFunction(table_index, {constrained_sym}))
-                        continue
-
                     # function in vtable is referenced by multiple names; perform disambiguation
                     matched_overload = self.resolve_ambiguous_vfn(
                         n, fnsyms, self.get_possible_vtable_set_candidates(vt, n)
@@ -405,6 +421,50 @@ class VtableDisambiguator(angr.Analysis):
                 )
                 vfn.possible_syms -= disambiguated_functions
 
+        if self.vtable_constraint:
+            solver = claripy.Solver()
+            solver_asts = {}
+            positional_constraints = collections.defaultdict(set)
+            for n, vfn in enumerate(function_list):
+                if len(vfn.possible_syms) == 1:
+                    sym = set(vfn.possible_syms).pop()
+                    # symbolic names for constant values - we don't use these ourselves; they're here
+                    # so user-defined tooling can specify them relative to variables as needed
+                    solver_asts[sym.name] = claripy.BVV(n, 10)
+                    continue
+
+                for sym in vfn.possible_syms:
+                    # constrain variable values to those defined as ambiguous
+                    positional_constraints[sym].add(n)
+
+            for sym, positions in positional_constraints.items():
+                if sym.name not in solver_asts:
+                    solver_asts[sym.name] = claripy.BVS(sym.name, 10)
+                solver.add(claripy.Or(*(solver_asts[sym.name] == n for n in positions)))
+
+            for pvt in self.superclass_map[vt]:
+                constraints = self.vtable_constraint.get(pvt.name)
+                if not constraints:
+                    continue
+                for constraint in constraints:
+                    constraint.apply(solver, solver_asts)
+
+            solutions = solver.batch_eval(solver_asts.values(), 2)
+            if len(solutions) == 1:
+                unique_solution = {v: k for k, v in zip(solver_asts.keys(), solutions[0])}
+                for n, vfn in enumerate(function_list):
+                    if (
+                        len(vfn.possible_syms) > 1
+                        and unique_solution
+                        and n in unique_solution.keys()
+                    ):
+                        resolved_syms = {
+                            sym
+                            for sym in iter(vfn.possible_syms)
+                            if sym.name == unique_solution[n]
+                        }
+                        vfn.possible_syms = resolved_syms
+
         # final uniqueness pass
         # raise exception if we still have multiple possible symbols for a given position
         for n, vfn in enumerate(function_list):