Skip to content

Commit d55f57d

Browse files
committed
updates to write_score_code to correct formatting errors
1 parent 9158ab2 commit d55f57d

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

src/sasctl/pzmm/write_score_code.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def score(var1, var2, var3, var4):
250250
input_var_list,
251251
missing_values=missing_values,
252252
dtype_list=input_dtypes_list,
253+
preprocess_function=preprocess_function
253254
)
254255
self._predictions_to_metrics(
255256
score_metrics,
@@ -266,6 +267,7 @@ def score(var1, var2, var3, var4):
266267
missing_values=missing_values,
267268
statsmodels_model="statsmodels_model" in kwargs,
268269
tf_model="tf_keras_model" in kwargs or "tf_core_model" in kwargs,
270+
preprocess_function=preprocess_function
269271
)
270272
# Include check for numpy values and a conversion operation as needed
271273
self.score_code += (
@@ -814,14 +816,15 @@ def _predict_method(
814816
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
815817
self.score_code += self._wrap_indent_string(input_frame, 8)
816818
self.score_code += f"\n{'':4})\n"
817-
if preprocess_function:
818-
self.score_code += (
819-
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
820-
)
819+
821820
if missing_values:
822821
self.score_code += (
823822
f"{'':4}input_array = impute_missing_values(input_array)\n"
824823
)
824+
if preprocess_function:
825+
self.score_code += (
826+
f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n"
827+
)
825828
self.score_code += (
826829
f"{'':4}column_types = {column_types}\n"
827830
f"{'':4}h2o_array = h2o.H2OFrame(input_array, "
@@ -860,14 +863,14 @@ def _predict_method(
860863
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
861864
self.score_code += self._wrap_indent_string(input_frame, 8)
862865
self.score_code += f"\n{'':4})\n"
863-
if preprocess_function:
864-
self.score_code += (
865-
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
866-
)
867866
if missing_values:
868867
self.score_code += (
869868
f"{'':4}input_array = impute_missing_values(input_array)\n"
870869
)
870+
if preprocess_function:
871+
self.score_code += (
872+
f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n"
873+
)
871874
self.score_code += (
872875
f"{'':4}prediction = model.{method.__name__}(input_array)\n"
873876
)
@@ -885,14 +888,14 @@ def _predict_method(
885888
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
886889
self.score_code += self._wrap_indent_string(input_frame, 8)
887890
self.score_code += f"\n{'':4})\n"
888-
if preprocess_function:
889-
self.score_code += (
890-
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
891-
)
892891
if missing_values:
893892
self.score_code += (
894893
f"{'':4}input_array = impute_missing_values(input_array)\n"
895894
)
895+
if preprocess_function:
896+
self.score_code += (
897+
f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n"
898+
)
896899
self.score_code += (
897900
f"{'':4}prediction = model.{method.__name__}(input_array)\n\n"
898901
f"{'':4} # Check if model returns logits or probabilities\n"
@@ -921,14 +924,14 @@ def _predict_method(
921924
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
922925
self.score_code += self._wrap_indent_string(input_frame, 8)
923926
self.score_code += f"\n{'':4})\n"
924-
if preprocess_function:
925-
self.score_code += (
926-
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
927-
)
928927
if missing_values:
929928
self.score_code += (
930929
f"{'':4}input_array = impute_missing_values(input_array)\n"
931930
)
931+
if preprocess_function:
932+
self.score_code += (
933+
f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n"
934+
)
932935
self.score_code += (
933936
f"{'':4}prediction = model.{method.__name__}(input_array).tolist()\n"
934937
)

0 commit comments

Comments
 (0)