Przeglądaj źródła

Enforce type annotations

nosoop 10 miesięcy temu
rodzic
commit
ccd20751a7

+ 1 - 0
pyproject.toml

@@ -29,6 +29,7 @@ requires = [
 
 [tool.ruff]
 line-length = 96
+lint.extend-select = ["ANN001"]
 
 [tool.mypy]
 disable_error_code = ["import-untyped"]

+ 6 - 2
src/smgdc/angr/vtable_disamb.py

@@ -14,7 +14,9 @@ import angr
 from .. import demangler_helpers as dh
 
 
-def read_cstring(mem: Clemory, addr: int, encoding="utf-8", chunk=16, **kwargs) -> str:
+def read_cstring(
+    mem: Clemory, addr: int, encoding: str = "utf-8", chunk: int = 16, **kwargs
+) -> str:
     """unpacks a variable-length, zero-terminated string from the binary"""
     try:
         unpacker = itertools.takewhile(
@@ -159,7 +161,9 @@ class VtableDisambiguator(angr.Analysis):
                     return fnsym
         return None
 
-    def get_possible_vtable_set_candidates(self, vtsym, vtidx) -> Iterable[set[Symbol]]:
+    def get_possible_vtable_set_candidates(
+        self, vtsym: Symbol, vtidx: int
+    ) -> Iterable[set[Symbol]]:
         # Yield the ordered list of subclasses for the given class.
         yield self.subclass_map[vtsym]
 

+ 4 - 4
src/smgdc/demangler_helpers.py

@@ -8,7 +8,7 @@
 import itanium_demangler as demangler
 
 
-def extract_method_classname(node) -> tuple[demangler.Node, ...]:
+def extract_method_classname(node: demangler.Node) -> tuple[demangler.Node, ...]:
     def _extract(node: demangler.Node):
         match node:
             case node if node.kind == "qual_name":
@@ -21,7 +21,7 @@ def extract_method_classname(node) -> tuple[demangler.Node, ...]:
     raise ValueError(f"{node} is not a function")
 
 
-def extract_function_name(node) -> tuple[demangler.Node, ...]:
+def extract_function_name(node: demangler.Node) -> tuple[demangler.Node, ...]:
     # this name will include the class at the start
     def _extract(node: demangler.Node):
         match node:
@@ -39,9 +39,9 @@ def extract_function_name(node) -> tuple[demangler.Node, ...]:
     raise ValueError(f"{node} is not a function")
 
 
-def extract_vtable_typename(node):
+def extract_vtable_typename(node: demangler.Node):
     # extracts the typename of a vtable, guaranteed as a tuple of nodes
-    def _extract(node):
+    def _extract(node: demangler.Node):
         match node:
             case node if node.kind == "qual_name":
                 return node.value

+ 2 - 2
src/smgdc/types.py

@@ -12,7 +12,7 @@ class ByteSignature:
     pattern: list[str]
     expr: re.Pattern
 
-    def __init__(self, pattern):
+    def __init__(self, pattern: str):
         self.pattern = pattern.split()
 
         # creates escaped byte pattern with hex literals or wildcard as appropriate
@@ -48,7 +48,7 @@ class Code:
 
     code_ast: ast.AST
 
-    def __init__(self, v):
+    def __init__(self, v: str):
         self.code_ast = ast.parse(v, mode="eval")
 
     def eval(self, **locals):

+ 3 - 3
src/smgdc/validate.py

@@ -23,7 +23,7 @@ from .types import ByteSignature, Code, IntLiteral
 KEY_AS_IS = string.Template("${name}")
 
 
-def KEY_SUFFIX(s):
+def KEY_SUFFIX(s: str):
     return string.Template(f"${{name}} [{s}]")
 
 
@@ -72,7 +72,7 @@ class BaseBinary:
     def mmap(self):
         yield self._mm
 
-    def read(self, address, size) -> bytes:
+    def read(self, address: int, size: int) -> bytes:
         # shorthand to read a value from a physical file
         return self._mm[address : address + size]
 
@@ -99,7 +99,7 @@ class NumericOutputFormat(enum.StrEnum):
     HEX = "hex"
     HEX_SUFFIX = "hex_suffix"
 
-    def format_value(self, value) -> str:
+    def format_value(self, value: int) -> str:
         if self == NumericOutputFormat.HEX:
             return hex(value)
         elif self == NumericOutputFormat.HEX_SUFFIX:

+ 1 - 1
src/smgdc/vtable.py

@@ -22,7 +22,7 @@ VTable = list[Symbol]
 demangler._is_ctor_or_dtor = demangler.is_ctor_or_dtor
 
 
-def reorder_vfns_windows_estimate(symbols: list[Symbol], start_pos) -> list[Symbol]:
+def reorder_vfns_windows_estimate(symbols: list[Symbol], start_pos: int) -> list[Symbol]:
     # reorders a given subclass-level slice of linux symbols to reflect windows ordering
     name_buckets = collections.defaultdict(list)
     for n, symbol in enumerate(symbols):