Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch delta table not found error #1625

Merged
merged 17 commits into from
Oct 30, 2024
24 changes: 24 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from llmfoundry.utils.exceptions import (
ClusterDoesNotExistError,
ClusterInvalidAccessMode,
DeltaTableNotFoundError,
FailedToConnectToDatabricksError,
FailedToCreateSQLConnectionError,
FaultyDataPrepCluster,
Expand Down Expand Up @@ -503,6 +504,29 @@ def fetch(
raise InsufficientPermissionsError(str(e)) from e
elif 'UC_NOT_ENABLED' in str(e):
raise UCNotEnabledError() from e
elif 'DELTA_TABLE_NOT_FOUND' in str(e):
err_str = str(e)
# Error string should be in this format:
# ---
# Error processing `catalog`.`volume_name`.`table_name`:
# [DELTA_TABLE_NOT_FOUND] Delta table `volume_name`.`table_name`
# doesn't exist.
# ---
parts = err_str.split('`')
if len(parts) < 7:
# Failed to parse error, our codebase is brittle
# with respect to the string representations of
# errors in the spark library.
catalog_name, volume_name, table_name = ['unknown'] * 3
else:
catalog_name = parts[1]
volume_name = parts[3]
table_name = parts[5]
raise DeltaTableNotFoundError(
catalog_name,
volume_name,
table_name,
) from e

if isinstance(e, InsufficientPermissionsError):
raise
Expand Down
22 changes: 21 additions & 1 deletion llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
'MisconfiguredHfDatasetError',
'DatasetTooSmallError',
'RunTimeoutError',
'UCNotEnabledError',
'DeltaTableNotFoundError',
]

ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
Expand Down Expand Up @@ -530,5 +532,23 @@ class UCNotEnabledError(UserError):
"""Error thrown when user does not have UC enabled on their cluster."""

def __init__(self) -> None:
message = f'Unity Catalog is not enabled on your cluster.'
message = 'Unity Catalog is not enabled on your cluster.'
super().__init__(message)


class DeltaTableNotFoundError(UserError):
    """Raised when the delta table referenced for training cannot be found.

    Carries the three components of the Unity Catalog path
    (``catalog.volume.table``) so callers can inspect which part of the
    path was wrong.
    """

    def __init__(
        self,
        catalog_name: str,
        volume_name: str,
        table_name: str,
    ) -> None:
        # Assemble the full three-part UC path once for the user-facing text.
        full_path = f'{catalog_name}.{volume_name}.{table_name}'
        message = f'Your data path {full_path} does not exist. Please double check your delta table name'
        # Forward the path components as structured kwargs so the base
        # UserError can record them alongside the message.
        super().__init__(
            message=message,
            catalog_name=catalog_name,
            volume_name=volume_name,
            table_name=table_name,
        )
76 changes: 57 additions & 19 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from unittest.mock import MagicMock, mock_open, patch

import grpc
from pyspark.errors import AnalysisException

from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
FaultyDataPrepCluster,
Expand All @@ -19,6 +20,7 @@
iterative_combine_jsons,
run_query,
)
from llmfoundry.utils.exceptions import DeltaTableNotFoundError


class TestConvertDeltaToJsonl(unittest.TestCase):
Expand Down Expand Up @@ -139,25 +141,24 @@ def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any):

mock_listdir.assert_called_once_with(json_directory)
mock_file.assert_called()
"""
Diagnostic print
for call_args in mock_file().write.call_args_list:
print(call_args)
--------------------
call('{')
call('"key"')
call(': ')
call('"value"')
call('}')
call('\n')
call('{')
call('"key"')
call(': ')
call('"value"')
call('}')
call('\n')
--------------------
"""
# Diagnostic print
# for call_args in mock_file().write.call_args_list:
# print(call_args)
# --------------------
# call('{')
# call('"key"')
# call(': ')
# call('"value"')
# call('}')
# call('\n')
# call('{')
# call('"key"')
# call(': ')
# call('"value"')
# call('}')
# call('\n')
# --------------------

self.assertEqual(mock_file().write.call_count, 2)

@patch(
Expand Down Expand Up @@ -582,3 +583,40 @@ def test_fetch_DT_grpc_error_handling(

# Verify that fetch was called
mock_fetch.assert_called_once()

@patch(
    'llmfoundry.command_utils.data_prep.convert_delta_to_json.get_total_rows',
)
def test_fetch_nonexistent_table_error(
    self,
    mock_gtr: MagicMock,
):
    """fetch() should raise DeltaTableNotFoundError when Spark reports
    a DELTA_TABLE_NOT_FOUND analysis failure while counting rows."""
    # Simulate Spark failing to resolve the delta table: the row-count
    # helper raises an AnalysisException carrying the sentinel error code.
    mock_gtr.side_effect = AnalysisException(
        message='[DELTA_TABLE_NOT_FOUND] yada yada',
    )

    # The exception must be translated into the user-facing
    # DeltaTableNotFoundError by fetch's error handling.
    with self.assertRaises(DeltaTableNotFoundError) as ctx:
        fetch(
            method='dbsql',
            tablename='test_table',
            json_output_folder='/tmp/to/jsonl',
        )

    # The wrapped error should tell the user to re-check the table name.
    self.assertIn(
        'Please double check your delta table name',
        str(ctx.exception),
    )

    # The failure happened inside get_total_rows, which was hit exactly once.
    mock_gtr.assert_called_once()
Loading