Browse Source

Strongly enforce type annotations

nosoop 10 months ago
parent
commit
48fa8ed5e7
4 changed files with 19 additions and 18 deletions
  1. 1 1
      pyproject.toml
  2. 5 5
      src/smgdc/demangler_helpers.py
  3. 6 5
      src/smgdc/types.py
  4. 7 7
      src/smgdc/validate.py

+ 1 - 1
pyproject.toml

@@ -29,7 +29,7 @@ requires = [
 
 [tool.ruff]
 line-length = 96
-lint.extend-select = ["ANN001"]
+lint.extend-select = ["ANN001", "ANN201", "ANN202"]
 
 [tool.mypy]
 disable_error_code = ["import-untyped"]

+ 5 - 5
src/smgdc/demangler_helpers.py

@@ -39,7 +39,7 @@ def get_dtor_type(node: demangler.Node) -> bool:
 
 
 def extract_method_classname(node: demangler.Node) -> tuple[demangler.Node, ...]:
-    def _extract(node: demangler.Node):
+    def _extract(node: demangler.Node) -> tuple[demangler.Node, ...]:
         match node:
             case node if node.kind == "qual_name":
                 return node.value[:-1]
@@ -53,7 +53,7 @@ def extract_method_classname(node: demangler.Node) -> tuple[demangler.Node, ...]
 
 def extract_function_name(node: demangler.Node) -> tuple[demangler.Node, ...]:
     # this name will include the class at the start
-    def _extract(node: demangler.Node):
+    def _extract(node: demangler.Node) -> tuple[demangler.Node, ...]:
         match node:
             case node if node.kind in "name":
                 return node.value
@@ -69,9 +69,9 @@ def extract_function_name(node: demangler.Node) -> tuple[demangler.Node, ...]:
     raise ValueError(f"{node} is not a function")
 
 
-def extract_vtable_typename(node: demangler.Node):
+def extract_vtable_typename(node: demangler.Node) -> tuple[demangler.Node, ...]:
     # extracts the typename of a vtable, guaranteed as a tuple of nodes
-    def _extract(node: demangler.Node):
+    def _extract(node: demangler.Node) -> tuple[demangler.Node, ...]:
         match node:
             case node if node.kind == "qual_name":
                 return node.value
@@ -87,7 +87,7 @@ def extract_vtable_typename(node: demangler.Node):
 
 def extract_method_fname(node: demangler.Node) -> tuple[demangler.Node, ...]:
     # returns method name with any associated qualifiers
-    def _extract(node: demangler.Node):
+    def _extract(node: demangler.Node) -> tuple[demangler.Node, ...]:
         match node:
             case node if node.kind == "cv_qual":
                 return _extract(node.value)

+ 6 - 5
src/smgdc/types.py

@@ -2,6 +2,7 @@
 
 import ast
 import re
+import typing
 
 
 class ByteSequence:
@@ -38,17 +39,17 @@ class ByteSignature:
         )
 
     @property
-    def display_str(self):
+    def display_str(self) -> str:
         # render as a bracketed, space-delimited hex string where wildcard bytes are denoted with '??'
         return f"[{' '.join(self.pattern)}]".lower()
 
     @property
-    def gameconf_str(self):
+    def gameconf_str(self) -> str:
         # render as a SourceMod-style escaped string
         return "".join(r"\x" + b.upper() if b != "??" else r"\x2A" for b in self.pattern)
 
     @property
-    def length(self):
+    def length(self) -> int:
         return len(self.pattern)
 
     def __repr__(self):
@@ -60,12 +61,12 @@ class Code:
     A class to parse and execute Python code in a reduced (but not secure) environment.
     """
 
-    code_ast: ast.AST
+    code_ast: ast.Expression
 
     def __init__(self, v: str):
         self.code_ast = ast.parse(v, mode="eval")
 
-    def eval(self, **locals):
+    def eval(self, **locals) -> typing.Any:
         code = compile(self.code_ast, "<CONFIG>", mode="eval")
         return eval(code, None, locals)
 

+ 7 - 7
src/smgdc/validate.py

@@ -24,7 +24,7 @@ from .types import ByteSequence, ByteSignature, Code, IntLiteral
 KEY_AS_IS = string.Template("${name}")
 
 
-def KEY_SUFFIX(s: str):
+def KEY_SUFFIX(s: str) -> string.Template:
     return string.Template(f"${{name}} [{s}]")
 
 
@@ -36,7 +36,7 @@ eval_functions = {
 }
 
 
-def convert_types(*types):
+def convert_types(*types) -> typing.Callable[[typing.Type, typing.Any], typing.Any]:
     def _dec_hook(type: typing.Type, obj: typing.Any) -> typing.Any:
         if type in types:
             return type(obj)
@@ -75,7 +75,7 @@ class BaseBinary:
         self._mm = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)
 
     @contextlib.contextmanager
-    def mmap(self):
+    def mmap(self) -> typing.Generator[mmap.mmap, None, None]:
         yield self._mm
 
     def read(self, address: int, size: int) -> bytes:
@@ -180,9 +180,9 @@ class VirtualFunctionEntry(BaseEntry, tag="vfn"):
         raise ValueError("Missing vfn?")
 
     @property
-    def typename_from_symbol(self):
+    def typename_from_symbol(self) -> str:
         if not self.symbol.startswith("_ZN"):
-            return
+            return ""
         start_range = 3
         if self.symbol.startswith("_ZNK"):
             start_range = 4
@@ -250,11 +250,11 @@ class ByteSigEntry(LocationEntry, tag="bytesig", kw_only=True):
 
         with bin.mmap() as memory:
             matches = self.contents.expr.finditer(memory)
-            match = next(matches, False)
+            match = next(matches, None)
             if not match:
                 # no matches found at all, fail validation
                 raise AssertionError(f"No matches found for {self.contents.display_str}")
-            if not self.allow_multiple and next(matches, False):
+            if not self.allow_multiple and next(matches, None) is not None:
                 # non-unique byte pattern, fail validation
                 raise AssertionError(f"Multiple matches found for {self.contents.display_str}")
             return outputs