#!/usr/bin/python3
"""
Tests for ``smgdc.demangler_helpers``, driven by Itanium-mangled C++ symbols
parsed with ``itanium_demangler``.
"""

import itertools
from typing import Iterable

import itanium_demangler as dm
import pytest

import smgdc.demangler_helpers as dh


@pytest.mark.parametrize(
    "input_vtsym,input_fnsym,expected",
    [
        (
            # PASS: template class
            "_ZTV14CEntityFactoryI10CGunTargetE",
            "_ZN14CEntityFactoryI10CGunTargetE6CreateEPKc",
            True,
        ),
        (
            # PASS: namespaced class
            "_ZTVN6google8protobuf10TextFormat6Parser10ParserImpl20ParserErrorCollectorE",
            "_ZN6google8protobuf10TextFormat6Parser10ParserImpl20ParserErrorCollector8AddErrorEiiRKSs",
            True,
        ),
        (
            # PASS: qualified name in function
            "_ZTV13CTFWeaponBase",
            "_ZNK13CTFWeaponBase11GetWeaponIDEv",
            True,
        ),
        (
            # FAIL: unrelated classes
            "_ZTV13CTFBaseRocket",
            "_Z13GetScriptDescI11CBasePlayerEP17ScriptClassDesc_tPT_",
            False,
        ),
        (
            # FAIL: class in partial matching namespace
            "_ZTVN6google8protobuf10TextFormat17FieldValuePrinterE",
            "_ZN6google8protobuf10TextFormat6Parser10ParserImpl20ParserErrorCollector10AddWarningEiiRKSs",
            False,
        ),
        (
            # FAIL: specialized template and class implementation
            "_ZTV14CEntityFactoryI19CTFProjectile_FlareE",
            "_ZN19CTFProjectile_Flare14GetDataDescMapEv",
            False,
        ),
    ],
)
def test_function_in_virtual_class(input_vtsym: str, input_fnsym: str, expected: bool):
    """
    Checks if the function is for the vtable's class.

    Note that, obviously, this does not check that the function is actually in the vtable.
    """
    node_vtsym, node_fnsym = map(dm.parse, (input_vtsym, input_fnsym))
    vtable_typename = dh.extract_vtable_typename(node_vtsym)
    function_qualname = dh.extract_function_name(node_fnsym)
    # The function is considered a member of the vtable's class when the class's
    # name forms a prefix of the function's qualified name.
    # NOTE(review): assumes both helpers return sliceable sequences of the same
    # element type (e.g. name components) — confirm against demangler_helpers.
    assert (vtable_typename == function_qualname[: len(vtable_typename)]) == expected


@pytest.mark.parametrize(
    "fnsym_a,fnsym_b,expected",
    [
        (
            # PASS: differing class, same name, same parameter type(s)
            "_ZN11CBaseEntity10ChangeTeamEi",
            "_ZN9CTFPlayer10ChangeTeamEi",
            True,
        ),
        (
            # PASS: class and a non-virtual thunk
            "_ZN9CTFPlayer16GetAttributeListEv",
            "_ZThn4764_N9CTFPlayer16GetAttributeListEv",
            True,
        ),
        (
            # FAIL: same class, same name, differing parameter type(s)
            "_ZN9CTFPlayer10ChangeTeamEi",
            "_ZN9CTFPlayer10ChangeTeamEibbb",
            False,
        ),
        (
            # FAIL: same class, differing name, same parameter type(s)
            "_ZN12CTFGameRules27TrackWorkshopMapsInMapCycleEv",
            "_ZN12CTFGameRules16LoadMapCycleFileEv",
            False,
        ),
        (
            # FAIL: specialized static methods with the same template (DataMapInit)
            "_Z11DataMapInitI14SoundCommand_tEP9datamap_tPT_",
            "_Z11DataMapInitI10CKothLogicEP9datamap_tPT_",
            False,
        ),
    ],
)
def test_function_match_signature(fnsym_a: str, fnsym_b: str, expected: bool):
    """
    Checks that two symbols produce equal method signatures (as computed by
    ``dh.extract_method_signature``) exactly when ``expected`` is True.
    """
    sig_a, sig_b = map(dh.extract_method_signature, map(dm.parse, (fnsym_a, fnsym_b)))
    assert expected == (sig_a == sig_b)


@pytest.mark.parametrize(
    "fnsyms,expected",
    [
        (
            # PASS: same class, same name, differing parameter types
            (
                "_ZN11CBaseEntity8KeyValueEPKcS1_",
                "_ZN11CBaseEntity8KeyValueEPKcf",
                "_ZN11CBaseEntity8KeyValueEPKcRK6Vector",
            ),
            True,
        ),
        (
            # PASS: same class, same name, differing parameter types
            (
                "_ZN24CTFWeaponBaseGrenadeProj11InitGrenadeERK6VectorS2_P20CBaseCombatCharacterRK13CTFWeaponInfo",
                "_ZN24CTFWeaponBaseGrenadeProj11InitGrenadeERK6VectorS2_P20CBaseCombatCharacterif",
            ),
            True,
        ),
        (
            # PASS: different class, same name, differing parameter types
            (
                "_ZN20CBaseCombatCharacter8FVisibleEP11CBaseEntityiPS1_",
                "_ZN11CBaseEntity8FVisibleERK6VectoriPPS_",
            ),
            True,
        ),
        (
            # PASS: same class, one is const-qualified
            (
                "_ZN11CBaseObject13CanBeUpgradedEP9CTFPlayer",
                "_ZNK11CBaseObject13CanBeUpgradedEv",
            ),
            True,
        ),
        (
            # FAIL: static function and class method
            (
                "_Z19ReloadSceneFromDiskP11CBaseEntity",
                "_ZN18CServerChoreoTools19ReloadSceneFromDiskEi",
            ),
            False,
        ),
    ],
)
def test_function_overloaded(fnsyms: Iterable[str], expected: bool):
    """
    Tests whether all the given symbols are overloads of the same name.

    ``expected`` is True when every pair of symbols yields the same name from
    ``dh.extract_method_fname``, and False when at least one pair differs.
    Of course, this is only inference based on symbol name.
    """
    # Materialize once so the names can be compared pairwise.
    names = [dh.extract_method_fname(dm.parse(sym)) for sym in fnsyms]
    # Equality is symmetric, so each unordered pair only needs checking once.
    all_same = all(a == b for a, b in itertools.combinations(names, 2))
    # Previously `expected` was ignored and every case asserted equality,
    # which made the FAIL case meaningless.
    assert all_same == expected


@pytest.mark.parametrize(
    "sym,expected",
    [
        (
            # Itanium ABI "D0" destructor variant
            "_ZN12CTFGameRulesD0Ev",
            "deleting",
        ),
        (
            # Itanium ABI "D2" destructor variant
            "_ZN12CTFGameRulesD2Ev",
            "base",
        ),
    ],
)
def test_function_dtor(sym: str, expected: str):
    """
    Checks that ``dh.get_dtor_type`` maps a mangled destructor symbol to the
    expected destructor-kind label.
    """
    assert dh.get_dtor_type(dm.parse(sym)) == expected