6 changes: 3 additions & 3 deletions .bazelrc
@@ -31,13 +31,13 @@ build:py3.12 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.12
 build:build --config=_build
 
 # Config to sync files
-run:pre_build --config=_build --config=py3.9
+run:pre_build --config=_build --config=py3.10
 
 # Config to run type check
-build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.9
+build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.10
 
 # Config to build the doc
-build:docs --config=_all --config=py3.9
+build:docs --config=_all --config=py3.10
 
 # Public the extended setting
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,15 @@
 # Release History
 
+## 1.14.0
+
+### Bug Fixes
+
+### Behavior Changes
+
+### New Features
+
+* ML Job: The `additional_payloads` argument is now **deprecated** in favor of `imports`.
+
 ## 1.13.0
 
 ### Bug Fixes
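For reference, a minimal sketch of the renamed argument from the caller's side. The entrypoint script, compute pool, and stage names here are hypothetical, and the keyword-only layout of `submit_file` is assumed from the docstrings later in this diff:

```python
from snowflake.ml import jobs

# Before: additional_payloads=[("utils/helpers.py", "helpers.py")]  (now deprecated)
job = jobs.submit_file(
    "train.py",                # hypothetical entrypoint script
    "MY_COMPUTE_POOL",         # hypothetical compute pool
    stage_name="payload_stage",
    imports=[("utils/helpers.py", "helpers.py")],  # (source, destination) pairs
)
```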
2 changes: 1 addition & 1 deletion bazel/environments/fetch_conda_env_config.bzl
@@ -3,7 +3,7 @@ load("//bazel/platforms:optional_dependency_groups.bzl", "OPTIONAL_DEPENDENCY_GR
 def _fetch_conda_env_config_impl(rctx):
     # read the particular environment variable we are interested in
     env_name = rctx.os.environ.get("BAZEL_CONDA_ENV_NAME", "core").lower()
-    python_ver = rctx.os.environ.get("BAZEL_CONDA_PYTHON_VERSION", "3.9").lower()
+    python_ver = rctx.os.environ.get("BAZEL_CONDA_PYTHON_VERSION", "3.10").lower()
 
     # necessary to create empty BUILD file for this rule
     # which will be located somewhere in the Bazel build files
6 changes: 3 additions & 3 deletions bazel/requirements/templates/bazelrc.tpl
@@ -28,13 +28,13 @@ build:py3.12 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.12
 build:build --config=_build
 
 # Config to sync files
-run:pre_build --config=_build --config=py3.9
+run:pre_build --config=_build --config=py3.10
 
 # Config to run type check
-build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.9
+build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.10
 
 # Config to build the doc
-build:docs --config=_all --config=py3.9
+build:docs --config=_all --config=py3.10
 
 # Public the extended setting
2 changes: 1 addition & 1 deletion ci/build_and_run_tests.sh
@@ -42,7 +42,7 @@ WITH_SNOWPARK=false
 WITH_SPCS_IMAGE=false
 RUN_GRYPE=false
 MODE="continuous_run"
-PYTHON_VERSION=3.9
+PYTHON_VERSION=3.10
 PYTHON_ENABLE_SCRIPT="bin/activate"
 SNOWML_DIR="snowml"
 SNOWPARK_DIR="snowpark-python"
2 changes: 1 addition & 1 deletion ci/conda_recipe/meta.yaml
@@ -17,7 +17,7 @@ build:
   noarch: python
 package:
   name: snowflake-ml-python
-  version: 1.13.0
+  version: 1.14.0
 requirements:
   build:
     - python
2 changes: 1 addition & 1 deletion snowflake/ml/jobs/_utils/constants.py
@@ -25,7 +25,7 @@
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.6.2"
+DEFAULT_IMAGE_TAG = "1.8.0"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
16 changes: 9 additions & 7 deletions snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
@@ -234,12 +234,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     if payload_dir and payload_dir not in sys.path:
         sys.path.insert(0, payload_dir)
 
-    # Create a Snowpark session before running the script
-    # Session can be retrieved using snowflake.snowpark.context.get_active_session()
-    config = SnowflakeLoginOptions()
-    config["client_session_keep_alive"] = "True"
-    session = Session.builder.configs(config).create()  # noqa: F841
-
     try:
 
         if main_func:
@@ -266,7 +260,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     finally:
         # Restore original sys.argv
         sys.argv = original_argv
-        session.close()
 
 
 def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
@@ -297,6 +290,12 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     except ModuleNotFoundError:
         warnings.warn("Ray is not installed, skipping Ray initialization", ImportWarning, stacklevel=1)
 
+    # Create a Snowpark session before starting
+    # Session can be retrieved using snowflake.snowpark.context.get_active_session()
+    config = SnowflakeLoginOptions()
+    config["client_session_keep_alive"] = "True"
+    session = Session.builder.configs(config).create()  # noqa: F841
+
     try:
         # Wait for minimum required instances if specified
         min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
@@ -352,6 +351,9 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
             f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
         )
 
+    # Close the session after serializing the result
+    session.close()
+
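Per the comments in this diff, the launcher now owns a single session for the whole job rather than creating one per `run_script` call. A payload script can retrieve it instead of building its own; a minimal sketch (the warehouse lookup is just illustrative):

```python
# Hypothetical payload script executed by mljob_launcher.py
from snowflake.snowpark.context import get_active_session

session = get_active_session()  # the session created in main(), kept alive for the job
print(session.get_current_warehouse())
```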
16 changes: 12 additions & 4 deletions snowflake/ml/jobs/job.py
@@ -83,6 +83,8 @@ def _service_spec(self) -> dict[str, Any]:
     def _container_spec(self) -> dict[str, Any]:
         """Get the job's main container spec."""
         containers = self._service_spec["spec"]["containers"]
+        if len(containers) == 1:
+            return cast(dict[str, Any], containers[0])
         try:
             container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
         except StopIteration:
@@ -163,7 +165,7 @@ def get_logs(
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(self._session, self.id, limit, instance_id, verbose)
+        logs = _get_logs(self._session, self.id, limit, instance_id, self._container_spec["name"], verbose)
         assert isinstance(logs, str)  # mypy
         if as_list:
             return logs.splitlines()
@@ -281,7 +283,12 @@ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
 def _get_logs(
-    session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = True
+    session: snowpark.Session,
+    job_id: str,
+    limit: int = -1,
+    instance_id: Optional[int] = None,
+    container_name: str = constants.DEFAULT_CONTAINER_NAME,
+    verbose: bool = True,
 ) -> str:
     """
     Retrieve the job's execution logs.
@@ -291,6 +298,7 @@ def _get_logs(
         limit: The maximum number of lines to return. Negative values are treated as no limit.
         session: The Snowpark session to use. If none specified, uses active session.
         instance_id: Optional instance ID to get logs from a specific instance.
+        container_name: The name of the container to get logs from.
         verbose: Whether to return the full log or just the portion between START and END messages.
 
     Returns:
@@ -311,7 +319,7 @@ def _get_logs(
     params: list[Any] = [
         job_id,
         0 if instance_id is None else instance_id,
-        constants.DEFAULT_CONTAINER_NAME,
+        container_name,
     ]
     if limit > 0:
         params.append(limit)
@@ -337,7 +345,7 @@ def _get_logs(
         job_id,
         limit=limit,
         instance_id=instance_id if instance_id else 0,
-        container_name=constants.DEFAULT_CONTAINER_NAME,
+        container_name=container_name,
     )
     full_log = os.linesep.join(row[0] for row in logs)
 
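From the caller's side nothing changes except that logs now come from the job's actual container: `get_logs` resolves the name via `_container_spec` (with the new single-container shortcut) and forwards it to `_get_logs`. A usage sketch, assuming `job` is an MLJob handle returned by a submit call:

```python
print(job.get_logs(limit=100))        # last 100 lines from the job's container
print(job.get_logs(instance_id=0))    # logs for a specific instance
lines = job.get_logs(as_list=True)    # as a list of lines
```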
11 changes: 9 additions & 2 deletions snowflake/ml/jobs/jobs_test.py
@@ -10,6 +10,13 @@
 from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark.row import Row
 
+SERVICE_SPEC = """
+spec:
+  containers:
+  - name: main
+    image: test-image
+"""
+
 
 class JobTest(parameterized.TestCase):
     @parameterized.named_parameters(  # type: ignore[misc]
@@ -83,7 +90,7 @@ def test_get_logs_negative(self) -> None:
 
         def sql_side_effect(session: snowpark.Session, query_str: str, *args: Any, **kwargs: Any) -> Any:
             if query_str.startswith("DESCRIBE SERVICE IDENTIFIER"):
-                return [Row(target_instances=2)]
+                return [Row(target_instances=2, spec=SERVICE_SPEC)]
             else:
                 raise sp_exceptions.SnowparkSQLException("Waiting to start, Container Status: PENDING")
 
@@ -97,7 +104,7 @@ def test_get_logs_from_event_table(self) -> None:
         def sql_side_effect(session: snowpark.Session, query_str: str, *args: Any, **kwargs: Any) -> Any:
            if query_str.startswith("DESCRIBE SERVICE IDENTIFIER"):
                 return [
-                    Row(target_instances=2),
+                    Row(target_instances=2, spec=SERVICE_SPEC),
                 ]
             elif query_str.startswith("SELECT VALUE FROM "):
                 return [
41 changes: 34 additions & 7 deletions snowflake/ml/jobs/manager.py
@@ -232,6 +232,7 @@ def submit_file(
         enable_metrics (bool): Whether to enable metrics publishing for the job.
         query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
         spec_overrides (dict): A dictionary of overrides for the service spec.
+        imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
 
     Returns:
         An object representing the submitted job.
@@ -286,6 +287,7 @@ def submit_directory(
         enable_metrics (bool): Whether to enable metrics publishing for the job.
         query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
         spec_overrides (dict): A dictionary of overrides for the service spec.
+        imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
 
     Returns:
         An object representing the submitted job.
@@ -341,6 +343,7 @@ def submit_from_stage(
         enable_metrics (bool): Whether to enable metrics publishing for the job.
         query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
         spec_overrides (dict): A dictionary of overrides for the service spec.
+        imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.
 
     Returns:
         An object representing the submitted job.
@@ -404,6 +407,8 @@ def _submit_job(
         "num_instances",  # deprecated
         "target_instances",
         "min_instances",
+        "enable_metrics",
+        "query_warehouse",
     ],
 )
 def _submit_job(
@@ -447,6 +452,13 @@ def _submit_job(
         )
         target_instances = max(target_instances, kwargs.pop("num_instances"))
 
+    imports = None
+    if "additional_payloads" in kwargs:
+        logger.warning(
+            "'additional_payloads' is deprecated and will be removed in a future release. Use 'imports' instead."
+        )
+        imports = kwargs.pop("additional_payloads")
+
     # Use kwargs for less common optional parameters
     database = kwargs.pop("database", None)
     schema = kwargs.pop("schema", None)
     spec_overrides = kwargs.pop("spec_overrides", None)
     enable_metrics = kwargs.pop("enable_metrics", True)
     query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
-    additional_payloads = kwargs.pop("additional_payloads", None)
-
-    if additional_payloads:
-        logger.warning("'additional_payloads' is in private preview since 1.9.1. Do not use it in production.")
+    imports = kwargs.pop("imports", None) or imports
 
     # Warn if there are unknown kwargs
     if kwargs:
@@ -492,7 +501,7 @@ def _submit_job(
     try:
         # Upload payload
         uploaded_payload = payload_utils.JobPayload(
-            source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
+            source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
         ).upload(session, stage_path)
     except snowpark.exceptions.SnowparkSQLException as e:
         if e.sql_error_code == 90106:
             raise
             )
         raise
@@ -501,6 +510,22 @@ def _submit_job(
+    # FIXME: Temporary patches, remove this after v1 is deprecated
+    if target_instances > 1:
+        default_spec_overrides = {
+            "spec": {
+                "endpoints": [
+                    {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
+                ]
+            },
+        }
+        if spec_overrides:
+            spec_overrides = spec_utils.merge_patch(
+                default_spec_overrides, spec_overrides, display_name="spec_overrides"
+            )
+        else:
+            spec_overrides = default_spec_overrides
+
     if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
         # Add default env vars (extracted from spec_utils.generate_service_spec)
         combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
@@ -668,8 +693,10 @@ def _ensure_session(session: Optional[snowpark.Session]) -> snowpark.Session:
         session = session or get_active_session()
     except snowpark.exceptions.SnowparkSessionException as e:
         if "More than one active session" in e.message:
-            raise RuntimeError("Please specify the session as a parameter in API call")
+            raise RuntimeError(
+                "More than one active session is found. Please specify the session explicitly as a parameter"
+            ) from None
         if "No default Session is found" in e.message:
-            raise RuntimeError("Please create a session before API call")
+            raise RuntimeError("No active session is found. Please create a session") from None
         raise
     return session
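The new multi-instance patch merge-patches any user-supplied `spec_overrides` on top of a default that exposes the Ray dashboard endpoint. A rough sketch of the merge semantics (an assumption: `spec_utils.merge_patch` is not shown in this diff, and its handling of lists may differ):

```python
def merge_patch(base: dict, patch: dict) -> dict:
    # Simplified JSON-merge-patch-style combine: dicts merge recursively,
    # everything else (including lists) is replaced by the patch value.
    merged = dict(base)
    for key, value in patch.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_patch(merged[key], value)
        else:
            merged[key] = value
    return merged

default_spec_overrides = {
    "spec": {"endpoints": [{"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"}]}
}
user_overrides = {"spec": {"endpoints": [{"name": "custom", "port": 8080}]}}
print(merge_patch(default_spec_overrides, user_overrides))
# {'spec': {'endpoints': [{'name': 'custom', 'port': 8080}]}}
```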
1 change: 0 additions & 1 deletion snowflake/ml/lineage/lineage_node.py
@@ -83,7 +83,6 @@ def _load_from_lineage_node(session: snowpark.Session, name: str, version: str)
         raise NotImplementedError()
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT)
-    @snowpark._internal.utils.private_preview(version="1.5.3")
     def lineage(
         self,
         direction: Literal["upstream", "downstream"] = "downstream",
2 changes: 1 addition & 1 deletion snowflake/ml/lineage/notebooks/ML Lineage Workflows.ipynb
@@ -1774,7 +1774,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "py38_env",
    "language": "python",
    "name": "python3"
   },
17 changes: 1 addition & 16 deletions snowflake/ml/model/_client/model/model_version_impl.py
@@ -788,7 +788,7 @@ def _enrich_inference_engine_args(
         inference_engine_args: service_ops.InferenceEngineArgs,
         gpu_requests: Optional[Union[str, int]] = None,
     ) -> Optional[service_ops.InferenceEngineArgs]:
-        """Enrich inference engine args with model path and tensor parallelism settings.
+        """Enrich inference engine args with tensor parallelism settings.
 
         Args:
             inference_engine_args: The original inference engine args
@@ -803,21 +803,6 @@
         if inference_engine_args.inference_engine_args_override is None:
             inference_engine_args.inference_engine_args_override = []
 
-        # Get model stage path and strip off "snow://" prefix
-        model_stage_path = self._model_ops.get_model_version_stage_path(
-            database_name=None,
-            schema_name=None,
-            model_name=self._model_name,
-            version_name=self._version_name,
-        )
-
-        # Strip "snow://" prefix
-        if model_stage_path.startswith("snow://"):
-            model_stage_path = model_stage_path.replace("snow://", "", 1)
-
-        # Always overwrite the model key by appending
-        inference_engine_args.inference_engine_args_override.append(f"--model={model_stage_path}")
-
         gpu_count = None
 
         # Set tensor-parallelism if gpu_requests is specified