validate.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. #!/usr/bin/python3
  2. import configparser
  3. import contextlib
  4. import enum
  5. import functools
  6. import hashlib
  7. import io
  8. import itertools
  9. import mmap
  10. import pathlib
  11. import pickle
  12. import string
  13. import struct
  14. import typing
  15. import angr
  16. import msgspec
  17. from . import vtable as vt_helpers
  18. from .angr.vtable_disamb import VtableDisambiguator
  19. from .types import ByteSignature, Code, IntLiteral
  20. KEY_AS_IS = string.Template("${name}")
  21. def KEY_SUFFIX(s: str):
  22. return string.Template(f"${{name}} [{s}]")
  23. # collection of functions that can be used during value read operations
  24. eval_functions = {
  25. # truncate the given value to the given number of bits
  26. # https://stackoverflow.com/a/53424236
  27. "truncate": lambda val, num_bits: val & (2**num_bits - 1),
  28. }
  29. def convert_types(*types):
  30. def _dec_hook(type: typing.Type, obj: typing.Any) -> typing.Any:
  31. if type in types:
  32. return type(obj)
  33. raise NotImplementedError
  34. return _dec_hook
# Mapping of output-key template to the value produced for it.  Entries may
# output multiple values, such as separate Windows / Linux vtable indices or
# bytesigs + offsets.
ResultValues = dict[string.Template, typing.Any]
# Some virtual functions cannot be disambiguated outright, or need to check the
# resolved state of the rest of the table - for the former we have to
# disambiguate them out-of-band.
#
# Ideally we would do some linear constraint on ordering, but for now we just
# spec the offset.
VTableConstraintDict = dict[str, dict[str, int]]
  43. class BaseBinary:
  44. path: pathlib.Path
  45. angr: angr.Project
  46. hash: str
  47. _file: io.IOBase
  48. _mm: mmap.mmap
  49. def __init__(self, path: pathlib.Path, cache_path: pathlib.Path | None = None):
  50. self.path = path
  51. self._file = open(self.path, "rb")
  52. self.hash = hashlib.file_digest(self._file, "sha256").hexdigest()
  53. cached_proj = (cache_path or pathlib.Path()) / f"{self.hash}.angr.pkl"
  54. if not cached_proj.exists():
  55. self.angr = angr.Project(self.path, load_options={"auto_load_libs": False})
  56. cached_proj.write_bytes(pickle.dumps(self.angr))
  57. else:
  58. self.angr = pickle.loads(cached_proj.read_bytes())
  59. self._mm = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)
  60. @contextlib.contextmanager
  61. def mmap(self):
  62. yield self._mm
  63. def read(self, address: int, size: int) -> bytes:
  64. # shorthand to read a value from a physical file
  65. return self._mm[address : address + size]
  66. class WindowsBinary(BaseBinary):
  67. def __init__(self, path: pathlib.Path, cache_path: pathlib.Path | None = None):
  68. super().__init__(path, cache_path)
  69. class LinuxBinary(BaseBinary):
  70. vtable_constraint: VTableConstraintDict
  71. def __init__(self, path: pathlib.Path, cache_path: pathlib.Path | None = None):
  72. super().__init__(path, cache_path)
  73. self.vtable_constraint = {}
  74. constraints_file = (cache_path or pathlib.Path()) / f"{self.hash}.constraints.toml"
  75. if constraints_file.exists():
  76. self.vtable_constraint = msgspec.toml.decode(
  77. constraints_file.read_bytes(), type=VTableConstraintDict
  78. )
  79. @functools.cached_property
  80. def vtable_disambiguator(self) -> VtableDisambiguator:
  81. return self.angr.analyses.VtableDisambiguator()
# any of the concrete binary types that an entry's process() accepts
PlatformBinary = WindowsBinary | LinuxBinary
  83. class NumericOutputFormat(enum.StrEnum):
  84. INT = "int"
  85. HEX = "hex"
  86. HEX_SUFFIX = "hex_suffix"
  87. def format_value(self, value: int) -> str:
  88. if self == NumericOutputFormat.HEX:
  89. return hex(value)
  90. elif self == NumericOutputFormat.HEX_SUFFIX:
  91. return f"{value:X}h"
  92. elif self == NumericOutputFormat.INT:
  93. return f"{value}"
  94. raise NotImplementedError(f"Missing numeric output for {self}")
  95. class BaseEntry(msgspec.Struct, kw_only=True):
  96. # the partial path pointing to a binary
  97. target: pathlib.Path
  98. def process(self, bin: PlatformBinary) -> ResultValues:
  99. raise NotImplementedError(f"Cannot process {type(self).__qualname__}")
  100. def get_target_match(self, candidates: list[PlatformBinary]) -> PlatformBinary | None:
  101. for candidate in candidates:
  102. if candidate.path.absolute().parts[-len(self.target.parts) :] == self.target.parts:
  103. return candidate
  104. return None
  105. class LocationEntry(BaseEntry):
  106. symbol: str | None = None
  107. offset: IntLiteral = IntLiteral("0")
  108. bytescan: ByteSignature | None = None
  109. offset_fmt: NumericOutputFormat = NumericOutputFormat.INT
  110. def __post_init__(self):
  111. if self.bytescan:
  112. return
  113. if self.symbol:
  114. return
  115. raise ValueError("Missing location anchor (expected either 'bytescan' or 'symbol')")
  116. def calculate_phys_address(self, bin: PlatformBinary) -> int:
  117. # returns the physical offset within the file
  118. if self.bytescan:
  119. with bin.mmap() as memory:
  120. matches = self.bytescan.expr.finditer(memory)
  121. match = next(matches, None)
  122. if match:
  123. return match.start() + self.offset
  124. else:
  125. raise AssertionError(
  126. "No matches found for 'bytescan' value " f"{self.bytescan.display_str}"
  127. )
  128. sym = bin.angr.loader.find_symbol(self.symbol)
  129. if not sym:
  130. raise AssertionError("Could not find symbol {self.symbol}")
  131. offset = bin.angr.loader.main_object.addr_to_offset(sym.rebased_addr + self.offset)
  132. assert offset, "Received invalid {offset = }"
  133. return offset
  134. class VirtualFunctionEntry(BaseEntry, tag="vfn"):
  135. # linux-specific entry that takes a symbol and returns values for Windows / Linux
  136. symbol: str
  137. typename: str | None = msgspec.field(name="vtable", default=None)
  138. def __post_init__(self):
  139. raise ValueError("Missing vfn?")
  140. @property
  141. def typename_from_symbol(self):
  142. if not self.symbol.startswith("_ZN"):
  143. return
  144. start_range = 3
  145. if self.symbol.startswith("_ZNK"):
  146. start_range = 4
  147. # this only handles the simple case of a non-template classname
  148. int_prefix = "".join(itertools.takewhile(str.isdigit, self.symbol[start_range:]))
  149. chars_to_read = int(int_prefix)
  150. end_range = start_range + len(int_prefix) + chars_to_read
  151. if not self.symbol[end_range].isdigit():
  152. raise ValueError(f"Could not parse function symbol {self.symbol} into a type name")
  153. return self.symbol[start_range:end_range]
  154. def process(self, bin: PlatformBinary) -> ResultValues:
  155. # returns windows and linux vtable offsets
  156. # TODO: implement
  157. assert isinstance(
  158. bin, LinuxBinary
  159. ), "Expected Linux binary for virtual function handling"
  160. self.typename = self.typename or self.typename_from_symbol
  161. vtsym = bin.angr.loader.find_symbol(f"_ZTV{self.typename}")
  162. if not vtsym:
  163. raise ValueError(f"Could not find vtable symbol _ZTV{self.typename}")
  164. orig_vtable, *thunk_vtables = vt_helpers.get_vtables_from_address(bin, vtsym)
  165. win_vtable = vt_helpers.get_windows_vtables_from(bin, vtsym)
  166. sym = bin.angr.loader.find_symbol(self.symbol)
  167. return {
  168. KEY_SUFFIX("LINUX"): orig_vtable.index(sym) if sym in orig_vtable else None,
  169. KEY_SUFFIX("WINDOWS"): win_vtable.index(sym) if sym in win_vtable else None,
  170. }
  171. class ByteSigEntry(LocationEntry, tag="bytesig", kw_only=True):
  172. # value to be inserted into gameconf after asserting that the given location matches
  173. contents: ByteSignature
  174. # most bytesigs are expected to be unique; escape hatch for those that are just typecasted
  175. allow_multiple: bool = False
  176. def process(self, bin: PlatformBinary) -> ResultValues:
  177. if self.symbol or self.bytescan:
  178. address = self.calculate_phys_address(bin)
  179. data = bin.read(address, self.contents.length)
  180. if not self.contents.expr.match(data):
  181. actual_disp = f"[{data.hex(' ')}]"
  182. raise AssertionError(
  183. f"Assertion failed: {self.contents.display_str} != {actual_disp}"
  184. )
  185. return {
  186. KEY_AS_IS: self.contents.gameconf_str,
  187. KEY_SUFFIX("OFFSET"): self.offset_fmt.format_value(self.offset),
  188. }
  189. with bin.mmap() as memory:
  190. matches = self.contents.expr.finditer(memory)
  191. match = next(matches, False)
  192. if not match:
  193. # no matches found at all, fail validation
  194. raise AssertionError(f"No matches found for {self.contents.display_str}")
  195. if not self.allow_multiple and next(matches, False):
  196. # non-unique byte pattern, fail validation
  197. raise AssertionError(f"Multiple matches found for {self.contents.display_str}")
  198. return {
  199. KEY_AS_IS: self.contents.gameconf_str,
  200. KEY_SUFFIX("OFFSET"): self.offset_fmt.format_value(self.offset),
  201. }
  202. class ValueReadEntry(LocationEntry, tag="value", kw_only=True):
  203. # value to decode at a given symbol / offset
  204. struct: struct.Struct
  205. assert_stmt: Code | None = msgspec.field(default=None, name="assert")
  206. modify_stmt: Code | None = msgspec.field(default=None, name="modify")
  207. def process(self, bin: PlatformBinary) -> ResultValues:
  208. address = self.calculate_phys_address(bin)
  209. data = bin.read(address, self.struct.size)
  210. result, *_ = self.struct.unpack(data)
  211. if self.modify_stmt:
  212. result = self.modify_stmt.eval(value=result)
  213. # run assertion to ensure value is expected
  214. if self.assert_stmt and not self.assert_stmt.eval(value=result, **eval_functions):
  215. raise AssertionError(f"'{self.assert_stmt}' failed for value {result}")
  216. return {
  217. KEY_AS_IS: result,
  218. KEY_SUFFIX("OFFSET"): self.offset_fmt.format_value(self.offset),
  219. }
# union of every tagged entry kind that a config section may decode into
ConfigEntry = typing.Union[VirtualFunctionEntry, ByteSigEntry, ValueReadEntry]

# section name -> decoded entry
GameConfDict = dict[str, ConfigEntry]
  222. def read_config(config: configparser.ConfigParser) -> GameConfDict:
  223. return msgspec.convert(
  224. {s: config[s] for s in config.sections()},
  225. type=GameConfDict,
  226. strict=False,
  227. dec_hook=convert_types(ByteSignature, Code, IntLiteral, struct.Struct, pathlib.Path),
  228. )