Selaa lähdekoodia

Add helpers to extract vtable typename and function name

nosoop 10 kuukautta sitten
vanhempi
commit
a8a1b24683
4 muutettua tiedostoa jossa 129 lisäystä ja 1 poistoa
  1. 3 0
      .justfile
  2. 2 1
      pyproject.toml
  3. 32 0
      src/smgdc/demangler_helpers.py
  4. 92 0
      tests/test_demangler.py

+ 3 - 0
.justfile

@@ -1,7 +1,10 @@
 test:
   ruff check src/smgdc
+  pytest tests
   mypy -p src
 
 format:
   ruff check src --select I001 --fix
   ruff format src/smgdc
+  ruff check tests --select I001 --fix
+  ruff format tests

+ 2 - 1
pyproject.toml

@@ -17,7 +17,8 @@ smgdc = "smgdc.app:main"
 [project.optional-dependencies]
 dev = [
     "mypy == 1.9.0",
-    "ruff == 0.4.0"
+    "pytest == 8.2.0",
+    "ruff == 0.4.0",
 ]
 
 [build-system]

+ 32 - 0
src/smgdc/demangler_helpers.py

@@ -21,6 +21,38 @@ def extract_method_classname(node) -> tuple[demangler.Node, ...]:
     raise ValueError(f"{node} is not a function")
 
 
+def extract_function_name(node) -> tuple[demangler.Node, ...]:
+    # this name will include the class at the start
+    def _extract(node: demangler.Node):
+        match node:
+            case node if node.kind in ("name", "cv_qual"):
+                return node.value
+            case node if node.kind == "qual_name":
+                return node.value
+        raise ValueError(f"Unexpected node {node!r}")
+
+    if node.kind == "func":
+        return _extract(node.name)
+
+    raise ValueError(f"{node} is not a function")
+
+
+def extract_vtable_typename(node):
+    # extracts the typename of a vtable, guaranteed as a tuple of nodes
+    def _extract(node):
+        match node:
+            case node if node.kind == "qual_name":
+                return node.value
+            case node if node.kind == "name":
+                return (node,)
+        raise ValueError(f"Unexpected node {node!r}")
+
+    if node.kind == "vtable":
+        return _extract(node.value)
+
+    raise ValueError(f"{node} is not a vtable")
+
+
 def extract_method_fname(node: demangler.Node) -> tuple[demangler.Node, ...]:
     # returns method name with any associated qualifiers
     def _extract(node: demangler.Node):

+ 92 - 0
tests/test_demangler.py

@@ -0,0 +1,92 @@
+#!/usr/bin/python3
+
+import itanium_demangler as dm
+import pytest
+import smgdc.demangler_helpers as dh
+
+
+@pytest.mark.parametrize(
+    "input_vtsym,input_fnsym,expected",
+    [
+        (
+            # PASS: template class
+            "_ZTV14CEntityFactoryI10CGunTargetE",
+            "_ZN14CEntityFactoryI10CGunTargetE6CreateEPKc",
+            True,
+        ),
+        (
+            # PASS: namespaced class
+            "_ZTVN6google8protobuf10TextFormat6Parser10ParserImpl20ParserErrorCollectorE",
+            "_ZN6google8protobuf10TextFormat6Parser10ParserImpl20ParserErrorCollector8AddErrorEiiRKSs",
+            True,
+        ),
+        (
+            # FAIL: unrelated classes
+            "_ZTV13CTFBaseRocket",
+            "_Z13GetScriptDescI11CBasePlayerEP17ScriptClassDesc_tPT_",
+            False,
+        ),
+        (
+            # FAIL: class in partial matching namespace
+            "_ZTVN6google8protobuf10TextFormat17FieldValuePrinterE",
+            "_ZN6google8protobuf10TextFormat6Parser10ParserImpl20ParserErrorCollector10AddWarningEiiRKSs",
+            False,
+        ),
+        (
+            # FAIL: specialized template and class implementation
+            "_ZTV14CEntityFactoryI19CTFProjectile_FlareE",
+            "_ZN19CTFProjectile_Flare14GetDataDescMapEv",
+            False,
+        ),
+    ],
+)
+def test_function_in_virtual_class(input_vtsym: str, input_fnsym: str, expected: bool):
+    """
+    Checks if the function is for the vtable's class.
+    Note that, obviously, this does not check that the function is actually in the vtable.
+    """
+    node_vtsym, node_fnsym = map(dm.parse, (input_vtsym, input_fnsym))
+    vtable_typename = dh.extract_vtable_typename(node_vtsym)
+    function_qualname = dh.extract_function_name(node_fnsym)
+
+    assert (vtable_typename == function_qualname[: len(vtable_typename)]) == expected
+
+
+@pytest.mark.parametrize(
+    "fnsym_a,fnsym_b,expected",
+    [
+        (
+            # PASS: differing class, same name, same parameter type(s)
+            "_ZN11CBaseEntity10ChangeTeamEi",
+            "_ZN9CTFPlayer10ChangeTeamEi",
+            True,
+        ),
+        (
+            # PASS: class and a non-virtual thunk
+            "_ZN9CTFPlayer16GetAttributeListEv",
+            "_ZThn4764_N9CTFPlayer16GetAttributeListEv",
+            True,
+        ),
+        (
+            # FAIL: same class, same name, differing parameter type(s)
+            "_ZN9CTFPlayer10ChangeTeamEi",
+            "_ZN9CTFPlayer10ChangeTeamEibbb",
+            False,
+        ),
+        (
+            # FAIL: same class, differing name, same parameter type(s)
+            "_ZN12CTFGameRules27TrackWorkshopMapsInMapCycleEv",
+            "_ZN12CTFGameRules16LoadMapCycleFileEv",
+            False,
+        ),
+        (
+            # FAIL: specialized static methods with the same template (DataMapInit<T>)
+            "_Z11DataMapInitI14SoundCommand_tEP9datamap_tPT_",
+            "_Z11DataMapInitI10CKothLogicEP9datamap_tPT_",
+            False,
+        ),
+    ],
+)
+def test_function_match_signature(fnsym_a: str, fnsym_b: str, expected: bool):
+    sig_a, sig_b = map(dh.extract_method_signature, map(dm.parse, (fnsym_a, fnsym_b)))
+    assert expected == (sig_a == sig_b)