Browse Source

Add configurable virtual function 'constraints'

Helps disambiguate the last few deduplicated function addresses.
nosoop 10 months ago
parent
commit
137f30011e
2 changed files with 36 additions and 2 deletions
  1. 18 2
      src/smgdc/validate.py
  2. 18 0
      src/smgdc/vtable.py

+ 18 - 2
src/smgdc/validate.py

@@ -48,10 +48,17 @@ def convert_types(*types):
 # or bytesigs + offsets
 ResultValues = dict[string.Template, typing.Any]
 
+# some virtual functions cannot be disambiguated outright, or need to check the resolved
+# state of the rest of the table - for the former we have to disambiguate them out-of-band
+#
+# ideally we would express a linear constraint on ordering, but for now we just specify the offset
+VTableConstraintDict = dict[str, dict[str, int]]
+
 
 class BaseBinary:
     path: pathlib.Path
     angr: angr.Project
+    hash: str
     _file: io.IOBase
     _mm: mmap.mmap
 
@@ -59,8 +66,8 @@ class BaseBinary:
         self.path = path
         self._file = open(self.path, "rb")
 
-        file_hash = hashlib.file_digest(self._file, "sha256")
-        cached_proj = (cache_path or pathlib.Path()) / f"{file_hash.hexdigest()}.angr.pkl"
+        self.hash = hashlib.file_digest(self._file, "sha256").hexdigest()
+        cached_proj = (cache_path or pathlib.Path()) / f"{self.hash}.angr.pkl"
         if not cached_proj.exists():
             self.angr = angr.Project(self.path, load_options={"auto_load_libs": False})
             cached_proj.write_bytes(pickle.dumps(self.angr))
@@ -83,8 +90,17 @@ class WindowsBinary(BaseBinary):
 
 
 class LinuxBinary(BaseBinary):
+    vtable_constraint: VTableConstraintDict
+
     def __init__(self, path: pathlib.Path, cache_path: pathlib.Path | None = None):
         super().__init__(path, cache_path)
+        self.vtable_constraint = {}
+
+        constraints_file = (cache_path or pathlib.Path()) / f"{self.hash}.constraints.toml"
+        if constraints_file.exists():
+            self.vtable_constraint = msgspec.toml.decode(
+                constraints_file.read_bytes(), type=VTableConstraintDict
+            )
 
     @functools.cached_property
     def vtable_disambiguator(self) -> VtableDisambiguator:

+ 18 - 0
src/smgdc/vtable.py

@@ -88,6 +88,19 @@ def get_windows_vtables_from(bin: "LinuxBinary", vt: Symbol) -> VTable:
     return vt_out
 
 
+def get_constrained_vfn(
+    bin: "LinuxBinary", vt: Symbol, vfnidx: int, candidate_fnsyms: set[Symbol]
+) -> Symbol | None:
+    """Pick the candidate that the binary's constraints pin to index ``vfnidx``.
+
+    Walks the symbols that ``bin.vtable_disambiguator.superclass_map`` lists for ``vt``
+    and, for each one that has an entry in ``bin.vtable_constraint``, returns the first
+    symbol in ``candidate_fnsyms`` whose name is mapped to exactly ``vfnidx``.
+
+    Returns ``None`` when no constraint matches, leaving disambiguation to the caller.
+
+    NOTE(review): ``candidate_fnsyms`` is a set, so if several candidates were constrained
+    to the same index the one returned would depend on iteration order - confirm that the
+    constraints file never does this.
+    """
+    for psym in bin.vtable_disambiguator.superclass_map[vt]:
+        # only symbols named in the constraints can resolve anything
+        if psym.name not in bin.vtable_constraint:
+            continue
+        vt_const = bin.vtable_constraint[psym.name]  # function name -> expected index
+        for fnsym in candidate_fnsyms:
+            if fnsym.name in vt_const and vt_const[fnsym.name] == vfnidx:
+                return fnsym
+    return None
+
+
 def get_vtables_from_address(bin: "LinuxBinary", vt: Symbol) -> list[VTable]:
     vtda = bin.vtable_disambiguator
 
@@ -107,6 +120,11 @@ def get_vtables_from_address(bin: "LinuxBinary", vt: Symbol) -> list[VTable]:
                 function_list.append(VTableFunction(table_index, fnsyms))
                 continue
             elif len(fnsyms) > 1:
+                constrained_sym = get_constrained_vfn(bin, vt, n, fnsyms)
+                if constrained_sym:
+                    function_list.append(VTableFunction(table_index, {constrained_sym}))
+                    continue
+
                 # function in vtable is referenced by multiple names; perform disambiguation
                 matched_overload = None
                 for related in vtda.get_possible_vtable_set_candidates(vt, n):