33 FaultDisplacementFeature ,
44)
55from ....modelling .features import FeatureType
6- from ....modelling .features .fault ._fault_function import BaseFault , BaseFault3D
6+ from ....modelling .features .fault ._fault_function import BaseFault , BaseFault3D , FaultDisplacement
77from ....utils import getLogger , NegativeRegion , PositiveRegion
88from ....modelling .features import StructuralFrame
99
@@ -40,7 +40,7 @@ def __init__(
4040 StructuralFrame .__init__ (self , features , name , fold , model )
4141 self .type = FeatureType .FAULT
4242 self .displacement = displacement
43- self ._faultfunction = BaseFault .fault_displacement
43+ self ._faultfunction = BaseFault () .fault_displacement
4444 self .steps = steps
4545 self .regions = []
4646 self .faults_enabled = True
@@ -58,10 +58,13 @@ def faultfunction(self):
5858 def faultfunction (self , value ):
5959 if callable (value ):
6060 self ._faultfunction = value
61+ if issubclass (FaultDisplacement , type (value )):
62+ self ._faultfunction = value
63+
6164 elif isinstance (value , str ) and value == "BaseFault" :
62- self ._faultfunction = BaseFault .fault_displacement
65+ self ._faultfunction = BaseFault () .fault_displacement
6366 elif isinstance (value , str ) and value == "BaseFault3D" :
64- self ._faultfunction = BaseFault3D .fault_displacement
67+ self ._faultfunction = BaseFault3D () .fault_displacement
6568 else :
6669 raise ValueError ("Fault function must be a function or BaseFault" )
6770
0 commit comments