123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
- #!/usr/bin/python3
- import configparser
- import contextlib
- import enum
- import functools
- import hashlib
- import io
- import itertools
- import mmap
- import pathlib
- import pickle
- import string
- import struct
- import typing
- import angr
- import msgspec
- from . import vtable as vt_helpers
- from .angr.vtable_disamb import VtableDisambiguator
- from .types import ByteSignature, Code, IntLiteral
# result-key template that renders to the entry name with nothing appended
KEY_AS_IS = string.Template("${name}")


def KEY_SUFFIX(s: str):
    """Return a result-key template rendering as ``<name> [<s>]``.

    The ``${name}`` placeholder is left unresolved in the returned template;
    only the bracketed suffix is filled in here.
    """
    return string.Template("${{name}} [{}]".format(s))
def _truncate(val, num_bits):
    """Keep only the lowest ``num_bits`` bits of ``val``."""
    # masking approach taken from https://stackoverflow.com/a/53424236
    low_bits_mask = 2**num_bits - 1
    return val & low_bits_mask


# helper functions made available by name while evaluating value-read expressions
eval_functions = {
    "truncate": _truncate,
}
def convert_types(*types):
    """Build a decode hook that constructs any of ``types`` from a raw value.

    The returned callable takes ``(requested_type, obj)`` positionally. When
    ``requested_type`` is one of ``types`` it returns ``requested_type(obj)``;
    for any other type it raises ``NotImplementedError``.
    """

    def _dec_hook(wanted: typing.Type, obj: typing.Any) -> typing.Any:
        # guard first: only the explicitly listed types are handled here
        if wanted not in types:
            raise NotImplementedError
        return wanted(obj)

    return _dec_hook
- # entries may output multiple values, such as separate Windows / Linux vtable indices
- # or bytesigs + offsets
- ResultValues = dict[string.Template, typing.Any]
- # some virtual functions cannot be disambiguated outright, or need to check the resolved
- # state of the rest of the table - for the former we have to disambiguate them out-of-band
- #
- # ideally we would do some linear constraint on ordering, but for now we just spec the offset
- VTableConstraintDict = dict[str, dict[str, int]]
class BaseBinary:
    """A binary file exposed as an angr project and as a read-only memory map.

    Construction opens the file, computes its SHA-256 digest, then either
    loads a pickled angr project from ``<cache_path>/<hash>.angr.pkl`` or
    builds a new project and writes that cache file. The file is finally
    memory-mapped read-only for raw byte access.

    The instance owns an open file handle and a memory map; release them with
    ``close()`` or by using the instance as a context manager.
    """

    path: pathlib.Path
    angr: angr.Project
    hash: str

    _file: io.IOBase
    _mm: mmap.mmap

    def __init__(self, path: pathlib.Path, cache_path: pathlib.Path | None = None):
        """Open ``path`` and load (or build and cache) its angr project.

        ``cache_path`` is the directory holding cached projects; when omitted
        the current working directory is used.
        """
        self.path = path
        self._file = open(self.path, "rb")
        try:
            self.hash = hashlib.file_digest(self._file, "sha256").hexdigest()

            cached_proj = (cache_path or pathlib.Path()) / f"{self.hash}.angr.pkl"
            if not cached_proj.exists():
                self.angr = angr.Project(self.path, load_options={"auto_load_libs": False})
                cached_proj.write_bytes(pickle.dumps(self.angr))
            else:
                # NOTE(security): unpickling runs arbitrary code from the cache
                # file, so the cache directory must only hold trusted files
                self.angr = pickle.loads(cached_proj.read_bytes())

            self._mm = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)
        except BaseException:
            # do not leak the handle when hashing / loading / mapping fails
            self._file.close()
            raise

    def close(self) -> None:
        """Release the memory map and the underlying file handle."""
        mapped = getattr(self, "_mm", None)
        if mapped is not None:
            mapped.close()
        self._file.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    @contextlib.contextmanager
    def mmap(self):
        """Yield the read-only memory map of the file.

        The map stays open after the ``with`` block ends; it is owned by the
        instance and released by ``close()``.
        """
        yield self._mm

    def read(self, address: int, size: int) -> bytes:
        """Return ``size`` bytes starting at file offset ``address``."""
        # shorthand to read a value from a physical file
        return self._mm[address : address + size]
class WindowsBinary(BaseBinary):
    """Platform marker subclass; adds no state or behavior beyond ``BaseBinary``."""

    def __init__(self, path: pathlib.Path, cache_path: pathlib.Path | None = None):
        # plain pass-through to the shared loader
        super(WindowsBinary, self).__init__(path, cache_path)
class LinuxBinary(BaseBinary):
    """``BaseBinary`` that also carries vtable constraints and a vtable disambiguator.

    Constraints are read from ``<cache_path>/<hash>.constraints.toml`` when
    that file exists; otherwise the constraint mapping is empty.
    """

    vtable_constraint: VTableConstraintDict

    def __init__(self, path: pathlib.Path, cache_path: pathlib.Path | None = None):
        super().__init__(path, cache_path)

        # the constraints file sits next to the cached project, keyed by hash
        base_dir = cache_path or pathlib.Path()
        constraints_file = base_dir / f"{self.hash}.constraints.toml"
        self.vtable_constraint = (
            msgspec.toml.decode(constraints_file.read_bytes(), type=VTableConstraintDict)
            if constraints_file.exists()
            else {}
        )

    @functools.cached_property
    def vtable_disambiguator(self) -> VtableDisambiguator:
        """Run the ``VtableDisambiguator`` analysis once and reuse the result."""
        analyses = self.angr.analyses
        return analyses.VtableDisambiguator()
# any of the concrete binary classes an entry may be processed against
PlatformBinary: typing.TypeAlias = WindowsBinary | LinuxBinary
- class NumericOutputFormat(enum.StrEnum):
- INT = "int"
- HEX = "hex"
- HEX_SUFFIX = "hex_suffix"
- def format_value(self, value: int) -> str:
- if self == NumericOutputFormat.HEX:
- return hex(value)
- elif self == NumericOutputFormat.HEX_SUFFIX:
- return f"{value:X}h"
- elif self == NumericOutputFormat.INT:
- return f"{value}"
- raise NotImplementedError(f"Missing numeric output for {self}")
class BaseEntry(msgspec.Struct, kw_only=True):
    """Common base for config entries: each one names the binary it applies to."""

    # the partial path pointing to a binary
    target: pathlib.Path

    def process(self, bin: PlatformBinary) -> ResultValues:
        """Produce this entry's result values for ``bin``; subclasses override."""
        raise NotImplementedError(f"Cannot process {type(self).__qualname__}")

    def get_target_match(self, candidates: list[PlatformBinary]) -> PlatformBinary | None:
        """Return the first candidate whose absolute path ends with ``target``.

        Matching is done on whole path components. ``None`` is returned when
        no candidate matches.
        """
        return next(
            (
                binary
                for binary in candidates
                if binary.path.absolute().parts[-len(self.target.parts) :] == self.target.parts
            ),
            None,
        )
class LocationEntry(BaseEntry):
    """Entry that refers to a location in a binary.

    The location is anchored either by a ``bytescan`` pattern or by a
    ``symbol`` name, and ``offset`` is added to the anchor. At least one
    anchor must be given.
    """

    symbol: str | None = None
    offset: IntLiteral = IntLiteral("0")
    bytescan: ByteSignature | None = None

    # how 'offset' is rendered when it is emitted as a result value
    offset_fmt: NumericOutputFormat = NumericOutputFormat.INT

    def __post_init__(self):
        """Reject entries that have neither a ``bytescan`` nor a ``symbol`` anchor."""
        if not (self.bytescan or self.symbol):
            raise ValueError("Missing location anchor (expected either 'bytescan' or 'symbol')")

    def calculate_phys_address(self, bin: PlatformBinary) -> int:
        """Return the physical offset within the file for this entry.

        With ``bytescan`` set, this is the start of the first match in the
        file plus ``offset``. Otherwise the symbol's rebased address plus
        ``offset`` is translated to a file offset through the loader.

        Raises ``AssertionError`` when the pattern has no match, the symbol
        cannot be found, or the address does not translate to a usable offset.
        """
        if self.bytescan:
            with bin.mmap() as memory:
                first_match = next(self.bytescan.expr.finditer(memory), None)
                if first_match is None:
                    raise AssertionError(
                        f"No matches found for 'bytescan' value {self.bytescan.display_str}"
                    )
                return first_match.start() + self.offset

        sym = bin.angr.loader.find_symbol(self.symbol)
        if not sym:
            raise AssertionError(f"Could not find symbol {self.symbol}")

        offset = bin.angr.loader.main_object.addr_to_offset(sym.rebased_addr + self.offset)
        # raised explicitly (not via 'assert') so the check survives 'python -O'
        if not offset:
            raise AssertionError(f"Received invalid {offset = }")
        return offset
class VirtualFunctionEntry(BaseEntry, tag="vfn"):
    """Linux-specific entry resolving a virtual function to vtable indices.

    Takes a (mangled) function symbol and returns its index in the Linux
    vtable and in the derived Windows vtable.
    """

    symbol: str

    # mangled type name whose vtable is searched (config key 'vtable'); when
    # absent it is derived from 'symbol'
    typename: str | None = msgspec.field(name="vtable", default=None)

    def __post_init__(self):
        """Reject entries without a function symbol."""
        # previously raised unconditionally, which made every entry unusable
        if not self.symbol:
            raise ValueError("Missing vfn?")

    @property
    def typename_from_symbol(self):
        """Extract the length-prefixed class name from an ``_ZN`` / ``_ZNK`` symbol.

        Returns ``None`` when the symbol does not start with ``_ZN``. The
        returned text keeps its length prefix (``_ZN7CObject4FuncEv`` gives
        ``7CObject``) so it can be appended to ``_ZTV`` directly.

        Raises ``ValueError`` when no type name can be parsed.
        """
        if not self.symbol.startswith("_ZN"):
            return None

        start_range = 4 if self.symbol.startswith("_ZNK") else 3
        parse_error = f"Could not parse function symbol {self.symbol} into a type name"

        # this only handles the simple case of a non-template classname
        int_prefix = "".join(itertools.takewhile(str.isdigit, self.symbol[start_range:]))
        if not int_prefix:
            raise ValueError(parse_error)

        end_range = start_range + len(int_prefix) + int(int_prefix)

        # the class name must be followed by another length-prefixed name;
        # slicing (instead of indexing) also covers a symbol that is too short
        if not self.symbol[end_range : end_range + 1].isdigit():
            raise ValueError(parse_error)
        return self.symbol[start_range:end_range]

    def process(self, bin: PlatformBinary) -> ResultValues:
        """Return the Linux and Windows vtable indices of ``symbol``.

        Each value is ``None`` when the symbol is not present in the
        corresponding table.

        Raises ``AssertionError`` when ``bin`` is not a ``LinuxBinary`` and
        ``ValueError`` when the type name or its vtable symbol cannot be
        resolved.
        """
        # raised explicitly (not via 'assert') so the check survives 'python -O'
        if not isinstance(bin, LinuxBinary):
            raise AssertionError("Expected Linux binary for virtual function handling")

        self.typename = self.typename or self.typename_from_symbol
        if not self.typename:
            raise ValueError(f"Could not determine vtable type name for symbol {self.symbol}")

        vtsym = bin.angr.loader.find_symbol(f"_ZTV{self.typename}")
        if not vtsym:
            raise ValueError(f"Could not find vtable symbol _ZTV{self.typename}")

        # only the first table is needed; any further tables are ignored
        orig_vtable, *_ = vt_helpers.get_vtables_from_address(bin, vtsym)
        win_vtable = vt_helpers.get_windows_vtables_from(bin, vtsym)

        sym = bin.angr.loader.find_symbol(self.symbol)
        return {
            KEY_SUFFIX("LINUX"): orig_vtable.index(sym) if sym in orig_vtable else None,
            KEY_SUFFIX("WINDOWS"): win_vtable.index(sym) if sym in win_vtable else None,
        }
class ByteSigEntry(LocationEntry, tag="bytesig", kw_only=True):
    """Entry validating that a byte signature is present in the binary.

    With a ``symbol`` or ``bytescan`` anchor, the bytes at the anchored
    location must match ``contents``. Without an anchor, ``contents`` itself
    is searched for across the whole file.
    """

    # value to be inserted into gameconf after asserting that the given location matches
    contents: ByteSignature

    # most bytesigs are expected to be unique; escape hatch for those that are just typecasted
    allow_multiple: bool = False

    def __post_init__(self):
        """Accept entries without a location anchor.

        The inherited check requires ``symbol`` or ``bytescan``, which made
        the whole-file scan in ``process`` unreachable; here ``contents``
        is a sufficient anchor on its own.
        """

    def process(self, bin: PlatformBinary) -> ResultValues:
        """Validate the signature against ``bin`` and return its output values.

        Raises ``AssertionError`` when the anchored bytes do not match, when
        an unanchored signature is not found, or when it is found more than
        once while ``allow_multiple`` is false.
        """
        if self.symbol or self.bytescan:
            self._verify_at_anchor(bin)
        else:
            self._verify_in_whole_file(bin)

        return {
            KEY_AS_IS: self.contents.gameconf_str,
            KEY_SUFFIX("OFFSET"): self.offset_fmt.format_value(self.offset),
        }

    def _verify_at_anchor(self, bin: PlatformBinary) -> None:
        # compare the bytes at the anchored location against the signature
        address = self.calculate_phys_address(bin)
        data = bin.read(address, self.contents.length)
        if not self.contents.expr.match(data):
            actual_disp = f"[{data.hex(' ')}]"
            raise AssertionError(
                f"Assertion failed: {self.contents.display_str} != {actual_disp}"
            )

    def _verify_in_whole_file(self, bin: PlatformBinary) -> None:
        # scan the whole file; require a match, and a unique one unless allowed
        with bin.mmap() as memory:
            matches = self.contents.expr.finditer(memory)
            if next(matches, None) is None:
                # no matches found at all, fail validation
                raise AssertionError(f"No matches found for {self.contents.display_str}")
            if not self.allow_multiple and next(matches, None) is not None:
                # non-unique byte pattern, fail validation
                raise AssertionError(f"Multiple matches found for {self.contents.display_str}")
class ValueReadEntry(LocationEntry, tag="value", kw_only=True):
    """Entry that decodes a value from the binary at the anchored location.

    The bytes are unpacked with ``struct``, optionally transformed by the
    ``modify`` expression, and optionally checked by the ``assert`` expression.
    """

    # value to decode at a given symbol / offset
    struct: struct.Struct

    assert_stmt: Code | None = msgspec.field(default=None, name="assert")
    modify_stmt: Code | None = msgspec.field(default=None, name="modify")

    def process(self, bin: PlatformBinary) -> ResultValues:
        """Read, transform and validate the value, then return it with its offset.

        Only the first unpacked field is used. Raises ``AssertionError`` when
        the ``assert`` expression evaluates falsy for the final value.
        """
        address = self.calculate_phys_address(bin)
        data = bin.read(address, self.struct.size)
        result, *_ = self.struct.unpack(data)

        if self.modify_stmt:
            # expose the same helpers (e.g. 'truncate') that 'assert' receives
            result = self.modify_stmt.eval(value=result, **eval_functions)

        # run assertion to ensure value is expected
        if self.assert_stmt and not self.assert_stmt.eval(value=result, **eval_functions):
            raise AssertionError(f"'{self.assert_stmt}' failed for value {result}")

        return {
            KEY_AS_IS: result,
            KEY_SUFFIX("OFFSET"): self.offset_fmt.format_value(self.offset),
        }
# every entry kind that may appear in a config section
ConfigEntry: typing.TypeAlias = typing.Union[VirtualFunctionEntry, ByteSigEntry, ValueReadEntry]

# parsed configuration: section name mapped to its entry
GameConfDict: typing.TypeAlias = dict[str, ConfigEntry]
def read_config(config: configparser.ConfigParser) -> GameConfDict:
    """Convert every section of ``config`` into its typed config entry.

    Conversion is non-strict, and the project value types plus
    ``struct.Struct`` and ``pathlib.Path`` are constructed from the raw
    values through a decode hook.
    """
    raw_sections = {section: config[section] for section in config.sections()}
    value_hook = convert_types(ByteSignature, Code, IntLiteral, struct.Struct, pathlib.Path)
    return msgspec.convert(
        raw_sections,
        type=GameConfDict,
        strict=False,
        dec_hook=value_hook,
    )
|