Skip to content

Commit

Permalink
updates to write_score_code to correct formatting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
djm21 committed Oct 28, 2024
1 parent 9158ab2 commit d55f57d
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions src/sasctl/pzmm/write_score_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def score(var1, var2, var3, var4):
input_var_list,
missing_values=missing_values,
dtype_list=input_dtypes_list,
preprocess_function=preprocess_function
)
self._predictions_to_metrics(
score_metrics,
Expand All @@ -266,6 +267,7 @@ def score(var1, var2, var3, var4):
missing_values=missing_values,
statsmodels_model="statsmodels_model" in kwargs,
tf_model="tf_keras_model" in kwargs or "tf_core_model" in kwargs,
preprocess_function=preprocess_function
)
# Include check for numpy values and a conversion operation as needed
self.score_code += (
Expand Down Expand Up @@ -814,14 +816,15 @@ def _predict_method(
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
self.score_code += self._wrap_indent_string(input_frame, 8)
self.score_code += f"\n{'':4})\n"
if preprocess_function:
self.score_code += (
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
)

if missing_values:
self.score_code += (
f"{'':4}input_array = impute_missing_values(input_array)\n"
)
if preprocess_function:
self.score_code += (
f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n"
)
self.score_code += (
f"{'':4}column_types = {column_types}\n"
f"{'':4}h2o_array = h2o.H2OFrame(input_array, "
Expand Down Expand Up @@ -860,14 +863,14 @@ def _predict_method(
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
self.score_code += self._wrap_indent_string(input_frame, 8)
self.score_code += f"\n{'':4})\n"
if preprocess_function:
self.score_code += (
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
)
if missing_values:
self.score_code += (
f"{'':4}input_array = impute_missing_values(input_array)\n"
)
if preprocess_function:
self.score_code += (
f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n"
)
self.score_code += (
f"{'':4}prediction = model.{method.__name__}(input_array)\n"
)
Expand All @@ -885,14 +888,14 @@ def _predict_method(
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
self.score_code += self._wrap_indent_string(input_frame, 8)
self.score_code += f"\n{'':4})\n"
if preprocess_function:
self.score_code += (
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
)
if missing_values:
self.score_code += (
f"{'':4}input_array = impute_missing_values(input_array)\n"
)
if preprocess_function:
self.score_code += (
f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n"
)
self.score_code += (
f"{'':4}prediction = model.{method.__name__}(input_array)\n\n"
f"{'':4} # Check if model returns logits or probabilities\n"
Expand Down Expand Up @@ -921,14 +924,14 @@ def _predict_method(
input_frame = f'{{{", ".join(input_dict)}}}, index=index'
self.score_code += self._wrap_indent_string(input_frame, 8)
self.score_code += f"\n{'':4})\n"
if preprocess_function:
self.score_code += (
f"{'':4}input_array = {preprocess_function.__name__}(input_array)"
)
if missing_values:
self.score_code += (
f"{'':4}input_array = impute_missing_values(input_array)\n"
)
if preprocess_function:
self.score_code += (
f"{'':4}input_array = {preprocess_function.__name__}(input_array)\n"
)
self.score_code += (
f"{'':4}prediction = model.{method.__name__}(input_array).tolist()\n"
)
Expand Down

0 comments on commit d55f57d

Please sign in to comment.