11from dataclasses import dataclass , field
2- from typing import Any , cast , Optional
2+ from typing import Any , cast , Optional , Tuple
33from enum import Enum
4+ import more_itertools
45
56import algosdk .abi as sdk_abi
67
@@ -87,7 +88,7 @@ def is_never(self) -> bool:
8788 and self .delete_application == CallConfig .NEVER
8889 )
8990
90- def oc_under_call_config (self , call_config : CallConfig ) -> list [EnumInt ]:
91+ def _oc_under_call_config (self , call_config : CallConfig ) -> list [EnumInt ]:
9192 if not isinstance (call_config , CallConfig ):
9293 raise TealInputError (
9394 "generate condition based on OCMethodCallConfigs should be based on OCMethodConfig"
@@ -427,15 +428,21 @@ def add_method_handler(
427428 raise TealInputError (
428429 f"registered method { method_signature } is never executed"
429430 )
430- oc_create : list [EnumInt ] = call_configs .oc_under_call_config (CallConfig .CREATE )
431- oc_call : list [EnumInt ] = call_configs .oc_under_call_config (CallConfig .CALL )
432431 if method_signature in self .added_method_sig :
433432 raise TealInputError (f"re-registering method { method_signature } detected" )
434433 self .added_method_sig .add (method_signature )
435434
436435 wrapped = Router .wrap_handler (True , method_call )
437436
438- if any (str (OnComplete .ClearState ) == str (x ) for x in oc_create ):
437+ def partition (cc : CallConfig ) -> Tuple [bool , list [EnumInt ]]:
438+ (not_clear_states , clear_states ) = more_itertools .partition (
439+ lambda x : str (x ) == str (OnComplete .ClearState ),
440+ call_configs ._oc_under_call_config (cc ),
441+ )
442+ return (len (list (clear_states )) > 0 , list (not_clear_states ))
443+
444+ (create_has_clear_state , create_others ) = partition (CallConfig .CREATE )
445+ if create_has_clear_state :
439446 self .categorized_clear_state_ast .method_calls_create .append (
440447 CondNode (
441448 And (
@@ -445,35 +452,33 @@ def add_method_handler(
445452 wrapped ,
446453 )
447454 )
448- oc_create = [
449- oc for oc in oc_create if str (oc ) != str (OnComplete .ClearState )
450- ]
451- if any (str (OnComplete .ClearState ) == str (x ) for x in oc_call ):
455+
456+ (call_has_clear_state , call_others ) = partition (CallConfig .CALL )
457+ if call_has_clear_state :
452458 self .categorized_clear_state_ast .method_calls .append (
453459 CondNode (
454460 Txn .application_args [0 ] == MethodSignature (method_signature ),
455461 wrapped ,
456462 )
457463 )
458- oc_call = [oc for oc in oc_call if str (oc ) != str (OnComplete .ClearState )]
459464
460- if oc_create :
465+ if create_others :
461466 self .categorized_approval_ast .method_calls_create .append (
462467 CondNode (
463468 And (
464469 Txn .application_id () == Int (0 ),
465470 Txn .application_args [0 ] == MethodSignature (method_signature ),
466- Or (* [Txn .on_completion () == oc for oc in oc_create ]),
471+ Or (* [Txn .on_completion () == oc for oc in create_others ]),
467472 ),
468473 wrapped ,
469474 )
470475 )
471- if oc_call :
476+ if call_others :
472477 self .categorized_approval_ast .method_calls .append (
473478 CondNode (
474479 And (
475480 Txn .application_args [0 ] == MethodSignature (method_signature ),
476- Or (* [Txn .on_completion () == oc for oc in oc_call ]),
481+ Or (* [Txn .on_completion () == oc for oc in call_others ]),
477482 ),
478483 wrapped ,
479484 )
0 commit comments