diff --git a/CHANGELOG.md b/CHANGELOG.md index 11cb0b80..116966d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change Log -## Release 1.3.1 (TBD) +## Release 1.3.1 (10/29/23) ### Added @@ -8,9 +8,11 @@ - Generator/Streaming handlers supported with local testing - [BETA] CLI DevEx functionality to create development projects. +--- + ## Release 1.3.0 (10/12/23) -### Changes +### Changed - Backwards compatibility with Python >= 3.8 - Consolidated install dependencies to `requirements.txt` @@ -23,7 +25,7 @@ ## Release 1.2.6 (10/6/23) -### Changes +### Changed - Force `urllib3` logging to `WARNING` level to avoid spamming the console if global logging level is set to `DEBUG`. @@ -40,7 +42,7 @@ ## ~~Release (Patch) 1.2.3 (10/4/23)~~ Replaced by 1.2.5 -### Bug Fix +### Fixed - Job outputs that were not dictionaries, bool, or str were swallowed by the serverless worker. This has been fixed. @@ -55,7 +57,7 @@ - `network_volume_id` can now be passed in when creating new pods, correct data center is automatically selected. - `template_id` can now be passed in when creating new pods. -### Changes +### Changed - Dependencies updated to latest versions. - Reduced circular imports for version reference. @@ -91,7 +93,7 @@ - Can generate a credentials file from the CLI to store your API key. - `get_gpu` now supports `gpu_quantity` as a parameter. -### Changes +### Changed - Minimized the use of pytests in favor of unittests. - Re-named `api_wrapper` to `api` for consistency. diff --git a/docs/cli/demos/config.gif b/docs/cli/demos/config.gif index 74a2ef8e..28280be0 100644 Binary files a/docs/cli/demos/config.gif and b/docs/cli/demos/config.gif differ diff --git a/docs/cli/demos/help.gif b/docs/cli/demos/help.gif index 8bc28bb8..39150ae5 100644 Binary files a/docs/cli/demos/help.gif and b/docs/cli/demos/help.gif differ diff --git a/docs/cli/demos/ssh.gif b/docs/cli/demos/ssh.gif index 39f96e7e..c1954de0 100644 Binary files a/docs/cli/demos/ssh.gif and b/docs/cli/demos/ssh.gif differ diff --git a/examples/api/create_endpoint.py b/examples/api/create_endpoint.py new file mode 100644 index 00000000..4d7fedb4 --- /dev/null +++ b/examples/api/create_endpoint.py @@ -0,0 +1,30 @@ +""" Example of creating an endpoint with the Runpod API. """ + +import runpod + +# Set your global API key with `runpod config` or uncomment the line below: +# runpod.api_key = "YOUR_RUNPOD_API_KEY" + +try: + + new_template = runpod.create_template( + name="test", + image_name="runpod/base:0.1.0", + is_serverless=True + ) + + print(new_template) + + new_endpoint = runpod.create_endpoint( + name="test", + template_id=new_template["id"], + gpu_ids="AMPERE_16", + workers_min=0, + workers_max=1 + ) + + print(new_endpoint) + +except runpod.error.QueryError as err: + print(err) + print(err.query) diff --git a/examples/api/create_template.py b/examples/api/create_template.py new file mode 100644 index 00000000..6630ccd2 --- /dev/null +++ b/examples/api/create_template.py @@ -0,0 +1,19 @@ +""" Example of creating a template with the Runpod API. """ + +import runpod + +# Set your global API key with `runpod config` or uncomment the line below: +# runpod.api_key = "YOUR_RUNPOD_API_KEY" + +try: + + new_template = runpod.create_template( + name="test", + image_name="runpod/base:0.1.0" + ) + + print(new_template) + +except runpod.error.QueryError as err: + print(err) + print(err.query) diff --git a/runpod/__init__.py b/runpod/__init__.py index 5bad0843..a9f35d21 100644 --- a/runpod/__init__.py +++ b/runpod/__init__.py @@ -9,9 +9,10 @@ from .version import __version__ from .api.ctl_commands import( get_user, update_user_settings, - get_gpus, get_gpu, - get_pods, get_pod, - create_pod, stop_pod, resume_pod, terminate_pod + get_gpu, get_gpus, + get_pod, get_pods, create_pod, stop_pod, resume_pod, terminate_pod, + create_template, + create_endpoint ) from .cli.groups.config.functions import set_credentials, check_credentials, get_credentials diff --git a/runpod/api/ctl_commands.py b/runpod/api/ctl_commands.py index b7a17235..8f681b44 100644 --- a/runpod/api/ctl_commands.py +++ b/runpod/api/ctl_commands.py @@ -11,6 +11,10 @@ from .queries import pods as pod_queries from .graphql import run_graphql_query from .mutations import pods as pod_mutations +from .mutations import endpoints as endpoint_mutations + +# Templates +from .mutations import templates as template_mutations def get_user() -> dict: ''' @@ -188,3 +192,73 @@ def terminate_pod(pod_id: str): run_graphql_query( pod_mutations.generate_pod_terminate_mutation(pod_id) ) + + +def create_template( + name:str, image_name:str, docker_start_cmd:str=None, + container_disk_in_gb:int=10, volume_in_gb:int=None, volume_mount_path:str=None, + ports:str=None, env:dict=None, is_serverless:bool=False +): + ''' + Create a template + + :param name: the name of the template + :param image_name: the name of the docker image to be used by the template + :param docker_start_cmd: the command to start the docker container with + :param container_disk_in_gb: how big should the container disk be + :param volume_in_gb: how big should the volume be + :param ports: the ports to open in the pod, example format - "8888/http,666/tcp" + :param volume_mount_path: where to mount the volume? + :param env: the environment variables to inject into the pod, + for example {EXAMPLE_VAR:"example_value", EXAMPLE_VAR2:"example_value 2"}, will + inject EXAMPLE_VAR and EXAMPLE_VAR2 into the pod with the mentioned values + :param is_serverless: is the template serverless? + + :example: + + >>> template_id = runpod.create_template("test", "runpod/stack", "python3 main.py") + ''' + raw_response = run_graphql_query( + template_mutations.generate_pod_template( + name, image_name, docker_start_cmd, + container_disk_in_gb, volume_in_gb, volume_mount_path, + ports, env, is_serverless + ) + ) + + return raw_response["data"]["saveTemplate"] + +def create_endpoint( + name:str, template_id:str, gpu_ids:str="AMPERE_16", + network_volume_id:str=None, locations:str=None, + idle_timeout:int=5, scaler_type:str="QUEUE_DELAY", scaler_value:int=4, + workers_min:int=0, workers_max:int=3 +): + ''' + Create an endpoint + + :param name: the name of the endpoint + :param template_id: the id of the template to use for the endpoint + :param gpu_ids: the ids of the GPUs to use for the endpoint + :param network_volume_id: the id of the network volume to use for the endpoint + :param locations: the locations to use for the endpoint + :param idle_timeout: the idle timeout for the endpoint + :param scaler_type: the scaler type for the endpoint + :param scaler_value: the scaler value for the endpoint + :param workers_min: the minimum number of workers for the endpoint + :param workers_max: the maximum number of workers for the endpoint + + :example: + + >>> endpoint_id = runpod.create_endpoint("test", "template_id") + ''' + raw_response = run_graphql_query( + endpoint_mutations.generate_endpoint_mutation( + name, template_id, gpu_ids, + network_volume_id, locations, + idle_timeout, scaler_type, scaler_value, + workers_min, workers_max + ) + ) + + return raw_response["data"]["saveEndpoint"] diff --git a/runpod/api/graphql.py b/runpod/api/graphql.py index f64adbc9..f42f150e 100644 --- a/runpod/api/graphql.py +++ b/runpod/api/graphql.py @@ -27,6 +27,9 @@ def run_graphql_query(query: str) -> Dict[str, Any]: raise error.AuthenticationError("Unauthorized request, please check your API key.") if "errors" in response.json(): - raise error.QueryError(response.json()["errors"][0]["message"]) + raise error.QueryError( + response.json()["errors"][0]["message"], + query + ) return response.json() diff --git a/runpod/api/mutations/endpoints.py b/runpod/api/mutations/endpoints.py new file mode 100644 index 00000000..503e94dc --- /dev/null +++ b/runpod/api/mutations/endpoints.py @@ -0,0 +1,59 @@ +""" RunPod | API Wrapper | Mutations | Endpoints """ + +# pylint: disable=too-many-arguments + +def generate_endpoint_mutation( + name:str, template_id:str, gpu_ids:str="AMPERE_16", + network_volume_id:str=None, locations:str=None, + idle_timeout:int=5, scaler_type:str="QUEUE_DELAY", scaler_value:int=4, + workers_min:int=0, workers_max:int=3 +): + """ Generate a string for a GraphQL mutation to create a new endpoint. """ + input_fields = [] + + # ------------------------------ Required Fields ----------------------------- # + input_fields.append(f'name: "{name}"') + input_fields.append(f'templateId: "{template_id}"') + input_fields.append(f'gpuIds: "{gpu_ids}"') + + # ------------------------------ Optional Fields ----------------------------- # + if network_volume_id is not None: + input_fields.append(f'networkVolumeId: "{network_volume_id}"') + else: + input_fields.append('networkVolumeId: ""') + + if locations is not None: + input_fields.append(f'locations: "{locations}"') + else: + input_fields.append('locations: ""') + + input_fields.append(f'idleTimeout: {idle_timeout}') + input_fields.append(f'scalerType: "{scaler_type}"') + input_fields.append(f'scalerValue: {scaler_value}') + input_fields.append(f'workersMin: {workers_min}') + input_fields.append(f'workersMax: {workers_max}') + + # Format the input fields into a string + input_fields_string = ", ".join(input_fields) + + return f""" + mutation {{ + saveEndpoint( + input: {{ + {input_fields_string} + }} + ) {{ + id + name + templateId + gpuIds + networkVolumeId + locations + idleTimeout + scalerType + scalerValue + workersMin + workersMax + }} + }} + """ diff --git a/runpod/api/mutations/pods.py b/runpod/api/mutations/pods.py index e6219fab..63a3e098 100644 --- a/runpod/api/mutations/pods.py +++ b/runpod/api/mutations/pods.py @@ -62,13 +62,7 @@ def generate_pod_deployment_mutation( input_fields.append(f'templateId: "{template_id}"') if network_volume_id is not None: - # network_volume_fragment = f''' - # networkVolume: {{ - # id: "{network_volume_id}" - # }} - # ''' - network_volume_fragment = f'networkVolumeId: "{network_volume_id}"' - input_fields.append(network_volume_fragment) + input_fields.append(f'networkVolumeId: "{network_volume_id}"') # Format input fields input_string = ", ".join(input_fields) diff --git a/runpod/api/mutations/templates.py b/runpod/api/mutations/templates.py new file mode 100644 index 00000000..17e660cc --- /dev/null +++ b/runpod/api/mutations/templates.py @@ -0,0 +1,83 @@ +""" RunPod | API Wrapper | Mutations | Templates """ + +# pylint: disable=too-many-arguments + +def generate_pod_template( + name:str, image_name:str, docker_start_cmd:str=None, + container_disk_in_gb:int=10, volume_in_gb:int=None, volume_mount_path:str=None, + ports:str=None, env:dict=None, is_serverless:bool=False +): + """ Generate a string for a GraphQL mutation to create a new pod template. """ + input_fields = [] + + # ------------------------------ Required Fields ----------------------------- # + input_fields.append(f'name: "{name}"') + input_fields.append(f'imageName: "{image_name}"') + + # ------------------------------ Optional Fields ----------------------------- # + if docker_start_cmd is not None: + docker_start_cmd = docker_start_cmd.replace('"', '\\"') + input_fields.append(f'dockerArgs: "{docker_start_cmd}"') + else: + input_fields.append('dockerArgs: ""') + + input_fields.append(f'containerDiskInGb: {container_disk_in_gb}') + + if volume_in_gb is not None: + input_fields.append(f'volumeInGb: {volume_in_gb}') + else: + input_fields.append('volumeInGb: 0') + + if volume_mount_path is not None: + input_fields.append(f'volumeMountPath: "{volume_mount_path}"') + + if ports is not None: + ports = ports.replace(" ", "") + input_fields.append(f'ports: "{ports}"') + else: + input_fields.append('ports: ""') + + if env is not None: + env_string = ", ".join( + [f'{{ key: "{key}", value: "{value}" }}' for key, value in env.items()]) + input_fields.append(f"env: [{env_string}]") + else: + input_fields.append('env: []') + + + if is_serverless: + input_fields.append('isServerless: true') + else: + input_fields.append('isServerless: false') + + # ------------------------------ Enforced Fields ----------------------------- # + input_fields.append('startSsh: true') + input_fields.append('isPublic: false') + input_fields.append('readme: ""') + + # Format the input fields into a string + input_fields_string = ", ".join(input_fields) + + return f""" + mutation {{ + saveTemplate( + input: {{ + {input_fields_string} + }} + ) {{ + id + name + imageName + dockerArgs + containerDiskInGb + volumeInGb + volumeMountPath + ports + env {{ + key + value + }} + isServerless + }} + }} + """ diff --git a/runpod/cli/groups/project/commands.py b/runpod/cli/groups/project/commands.py index bc504ae5..bcad4176 100644 --- a/runpod/cli/groups/project/commands.py +++ b/runpod/cli/groups/project/commands.py @@ -5,7 +5,9 @@ import os import click -from .functions import create_new_project, launch_project, start_project_api +from .functions import ( + create_new_project, launch_project, start_project_api, create_project_endpoint +) from .helpers import validate_project_name @click.group('project') @@ -84,3 +86,18 @@ def start_project_pod(): ''' click.echo("Starting project API server...") start_project_api() + + +# ------------------------------ Deploy Project ------------------------------ # +@project_cli.command('deploy') +def deploy_project(): + """ Deploy the project to RunPod. """ + click.echo("Deploying project...") + + endpoint_id = create_project_endpoint() + + click.echo(f"Project deployed successfully! Endpoint ID: {endpoint_id}") + click.echo("The following urls are available:") + click.echo(f" - https://api.runpod.ai/v2/{endpoint_id}/runsync") + click.echo(f" - https://api.runpod.ai/v2/{endpoint_id}/run") + click.echo(f" - https://api.runpod.ai/v2/{endpoint_id}/health") diff --git a/runpod/cli/groups/project/functions.py b/runpod/cli/groups/project/functions.py index 0eef6f2a..7e0f7323 100644 --- a/runpod/cli/groups/project/functions.py +++ b/runpod/cli/groups/project/functions.py @@ -10,7 +10,7 @@ import tomlkit from tomlkit import document, comment, table, nl -from runpod import get_pod, __version__ +from runpod import __version__, get_pod, create_template, create_endpoint from runpod.cli.utils.ssh_cmd import SSHConnection from .helpers import get_project_pod, copy_template_files, attempt_pod_launch, load_project_config from ...utils.rp_sync import sync_directory @@ -281,3 +281,39 @@ def start_project_api(): ssh_conn.run_commands(launch_api_server) finally: ssh_conn.close() + + +# ------------------------------ Deploy Project ------------------------------ # +def create_project_endpoint(): + """ Create a project endpoint. + - Create a serverless template for the project + - Create a new endpoint using the template + """ + config = load_project_config() + + environment_variables = {} + for variable in config['project']['env_vars']: + environment_variables[variable] = config['project']['env_vars'][variable] + + # Construct the docker start command + docker_start_cmd_prefix = 'bash -c "' + activate_cmd = f'. /runpod-volume/{config["project"]["uuid"]}/venv/bin/activate' + python_cmd = f'python -u /runpod-volume/{config["project"]["uuid"]}/{config["project"]["name"]}/{config["runtime"]["handler_path"]}' # pylint: disable=line-too-long + docker_start_cmd_suffix = '"' + docker_start_cmd = docker_start_cmd_prefix + activate_cmd + ' && ' + python_cmd + docker_start_cmd_suffix # pylint: disable=line-too-long + + project_endpoint_template = create_template( + name = f'{config["project"]["name"]}-endpoint | {config["project"]["uuid"]}', + image_name = config['project']['base_image'], + container_disk_in_gb = config['project']['container_disk_size_gb'], + docker_start_cmd = docker_start_cmd, + env = environment_variables, is_serverless = True + ) + + deployed_endpoint = create_endpoint( + name = f'{config["project"]["name"]}-endpoint | {config["project"]["uuid"]}', + template_id = project_endpoint_template['id'], + network_volume_id=config['project']['storage_id'], + ) + + return deployed_endpoint['id'] diff --git a/runpod/error.py b/runpod/error.py index 83b4313a..4e4d92ad 100644 --- a/runpod/error.py +++ b/runpod/error.py @@ -13,9 +13,12 @@ class RunPodError(Exception): ''' def __init__(self, message: Optional[str] = None): super().__init__(message) - self.message = message + def __str__(self): + if self.message: + return self.message + return super().__str__() class AuthenticationError(RunPodError): @@ -28,3 +31,6 @@ class QueryError(RunPodError): ''' Raised when a GraphQL query fails ''' + def __init__(self, message: Optional[str] = None, query: Optional[str] = None): + super().__init__(message) + self.query = query diff --git a/tests/test_api/test_ctl_commands.py b/tests/test_api/test_ctl_commands.py index 39e71fe3..5fe311b4 100644 --- a/tests/test_api/test_ctl_commands.py +++ b/tests/test_api/test_ctl_commands.py @@ -304,3 +304,51 @@ def test_get_pod(self): pods = ctl_commands.get_pod("POD_ID") self.assertEqual(pods["id"], "POD_ID") + + def test_create_template(self): + ''' + Tests create_template + ''' + with patch("runpod.api.graphql.requests.post") as patch_request, \ + patch("runpod.api.ctl_commands.get_gpu") as patch_get_gpu: + + patch_request.return_value.json.return_value = { + "data": { + "saveTemplate": { + "id": "TEMPLATE_ID" + } + } + } + + patch_get_gpu.return_value = None + + template = ctl_commands.create_template( + name="TEMPLATE_NAME", + image_name="IMAGE_NAME" + ) + + self.assertEqual(template["id"], "TEMPLATE_ID") + + def test_create_endpoint(self): + ''' + Tests create_endpoint + ''' + with patch("runpod.api.graphql.requests.post") as patch_request, \ + patch("runpod.api.ctl_commands.get_gpu") as patch_get_gpu: + + patch_request.return_value.json.return_value = { + "data": { + "saveEndpoint": { + "id": "ENDPOINT_ID" + } + } + } + + patch_get_gpu.return_value = None + + endpoint = ctl_commands.create_endpoint( + name="ENDPOINT_NAME", + template_id="TEMPLATE_ID" + ) + + self.assertEqual(endpoint["id"], "ENDPOINT_ID") diff --git a/tests/test_api/test_mutation_endpoints.py b/tests/test_api/test_mutation_endpoints.py new file mode 100644 index 00000000..f1db3a03 --- /dev/null +++ b/tests/test_api/test_mutation_endpoints.py @@ -0,0 +1,34 @@ +"""Tests for the endpoint mutation generation.""" + +import unittest + +from runpod.api.mutations.endpoints import generate_endpoint_mutation + +class TestGenerateEndpointMutation(unittest.TestCase): + """Tests for the endpoint mutation generation.""" + + def test_required_fields(self): + """Test the required fields.""" + result = generate_endpoint_mutation("test_name", "test_template_id") + self.assertIn('name: "test_name"', result) + self.assertIn('templateId: "test_template_id"', result) + self.assertIn('gpuIds: "AMPERE_16"', result) # Default value + self.assertIn('networkVolumeId: ""', result) # Default value + self.assertIn('locations: ""', result) # Default value + + def test_all_fields(self): + """Test all the fields.""" + result = generate_endpoint_mutation( + "test_name", "test_template_id", "AMPERE_20", + "test_volume_id", "US_WEST", 10, "WORKER_COUNT", 5, 2, 4 + ) + self.assertIn('name: "test_name"', result) + self.assertIn('templateId: "test_template_id"', result) + self.assertIn('gpuIds: "AMPERE_20"', result) + self.assertIn('networkVolumeId: "test_volume_id"', result) + self.assertIn('locations: "US_WEST"', result) + self.assertIn('idleTimeout: 10', result) + self.assertIn('scalerType: "WORKER_COUNT"', result) + self.assertIn('scalerValue: 5', result) + self.assertIn('workersMin: 2', result) + self.assertIn('workersMax: 4', result) diff --git a/tests/test_api/test_mutations_templates.py b/tests/test_api/test_mutations_templates.py new file mode 100644 index 00000000..eff253de --- /dev/null +++ b/tests/test_api/test_mutations_templates.py @@ -0,0 +1,35 @@ +""" Unit tests for the function generate_pod_template in the file api_wrapper.py """ + +import unittest + +from runpod.api.mutations.templates import generate_pod_template + +class TestGeneratePodTemplate(unittest.TestCase): + """ Unit tests for the function generate_pod_template in the file api_wrapper.py """ + + def test_basic_required_fields(self): + """ Test the basic required fields are present in the generated template """ + result = generate_pod_template("test_name", "test_image_name") + self.assertIn('name: "test_name"', result) + self.assertIn('imageName: "test_image_name"', result) + self.assertIn('dockerArgs: ""', result) # Defaults + self.assertIn('containerDiskInGb: 10', result) # Defaults + self.assertIn('volumeInGb: 0', result) # Defaults + self.assertIn('ports: ""', result) # Defaults + self.assertIn('env: []', result) # Defaults + self.assertIn('isServerless: false', result) # Defaults + + def test_optional_fields(self): + """ Test the optional fields are present in the generated template """ + result = generate_pod_template( + "test_name", "test_image_name", docker_start_cmd="test_cmd", + volume_in_gb=5, volume_mount_path="/path/to/volume", + ports="8000, 8001", env={"VAR1": "val1", "VAR2": "val2"}, is_serverless=True + ) + self.assertIn('dockerArgs: "test_cmd"', result) + self.assertIn('volumeInGb: 5', result) + self.assertIn('volumeMountPath: "/path/to/volume"', result) + self.assertIn('ports: "8000,8001"', result) + self.assertIn( + 'env: [{ key: "VAR1", value: "val1" }, { key: "VAR2", value: "val2" }]', result) + self.assertIn('isServerless: true', result) diff --git a/tests/test_cli/test_cli_groups/test_project_commands.py b/tests/test_cli/test_cli_groups/test_project_commands.py index 4ddfceb6..d5af3e83 100644 --- a/tests/test_cli/test_cli_groups/test_project_commands.py +++ b/tests/test_cli/test_cli_groups/test_project_commands.py @@ -6,7 +6,9 @@ from click.testing import CliRunner from runpod.cli.groups.project.commands import ( - new_project_wizard, launch_project_pod, start_project_pod) + new_project_wizard, launch_project_pod, start_project_pod, + deploy_project +) class TestProjectCLI(unittest.TestCase): ''' A collection of tests for the Project CLI commands. ''' @@ -83,3 +85,22 @@ def test_start_project_pod(self): self.assertEqual(result.exit_code, 0) self.assertIn("Starting project API server...", result.output) + + + @patch('runpod.cli.groups.project.commands.click.echo') + @patch('runpod.cli.groups.project.commands.create_project_endpoint') + def test_deploy_project(self, mock_create_project_endpoint, mock_click_echo): + """ Test the deploy_project function. """ + mock_create_project_endpoint.return_value = 'test_endpoint_id' + + result = self.runner.invoke(deploy_project) + + mock_create_project_endpoint.assert_called_once() + + mock_click_echo.assert_any_call("Deploying project...") + mock_click_echo.assert_any_call("The following urls are available:") + mock_click_echo.assert_any_call(" - https://api.runpod.ai/v2/test_endpoint_id/runsync") + mock_click_echo.assert_any_call(" - https://api.runpod.ai/v2/test_endpoint_id/run") + mock_click_echo.assert_any_call(" - https://api.runpod.ai/v2/test_endpoint_id/health") + + self.assertEqual(result.exit_code, 0) diff --git a/tests/test_cli/test_cli_groups/test_project_functions.py b/tests/test_cli/test_cli_groups/test_project_functions.py index c29a66ce..75a30fe9 100644 --- a/tests/test_cli/test_cli_groups/test_project_functions.py +++ b/tests/test_cli/test_cli_groups/test_project_functions.py @@ -8,7 +8,8 @@ from runpod.cli.groups.project.functions import( STARTER_TEMPLATES, create_new_project, - launch_project, start_project_api + launch_project, start_project_api, + create_project_endpoint ) class TestCreateNewProject(unittest.TestCase): @@ -229,3 +230,45 @@ def test_start_project_api_pod_not_found(self, mock_ssh_connection, mock_get_pro ) assert mock_ssh_connection.called is False assert mock_get_project_pod.called + +class TestCreateProjectEndpoint(unittest.TestCase): + """ Test the create_project_endpoint function. """ + + @patch('runpod.cli.groups.project.functions.load_project_config') + @patch('runpod.cli.groups.project.functions.create_template') + @patch('runpod.cli.groups.project.functions.create_endpoint') + def test_create_project_endpoint(self, mock_create_endpoint, + mock_create_template, mock_load_project_config): + """ Test that a project endpoint is created successfully. """ + mock_load_project_config.return_value = { + 'project': { + 'name': 'test_project', + 'uuid': '123456', + 'env_vars': {'TEST_VAR': 'value'}, + 'base_image': 'test_image', + 'container_disk_size_gb': 10, + 'storage_id': 'test_storage_id', + }, + 'runtime': { + 'handler_path': 'handler.py' + } + } + mock_create_template.return_value = {'id': 'test_template_id'} + mock_create_endpoint.return_value = {'id': 'test_endpoint_id'} + + result = create_project_endpoint() + + self.assertEqual(result, 'test_endpoint_id') + mock_create_template.assert_called_once_with( + name='test_project-endpoint | 123456', + image_name='test_image', + container_disk_in_gb=10, + docker_start_cmd='bash -c ". /runpod-volume/123456/venv/bin/activate && python -u /runpod-volume/123456/test_project/handler.py"', # pylint: disable=line-too-long + env={'TEST_VAR': 'value'}, + is_serverless=True + ) + mock_create_endpoint.assert_called_once_with( + name='test_project-endpoint | 123456', + template_id='test_template_id', + network_volume_id='test_storage_id' + ) diff --git a/tests/test_error.py b/tests/test_error.py new file mode 100644 index 00000000..52068299 --- /dev/null +++ b/tests/test_error.py @@ -0,0 +1,40 @@ +"""Unit tests for the error classes in the runpod.error module.""" + +import unittest + +# Assuming the error classes are in a file named 'error.py' +from runpod.error import RunPodError, AuthenticationError, QueryError + +class TestErrorClasses(unittest.TestCase): + """Unit tests for the error classes in the runpod.error module.""" + + def test_run_pod_error_with_message(self): + """Test the RunPodError class with a message.""" + error_msg = "An error occurred" + err = RunPodError(error_msg) + self.assertEqual(str(err), error_msg) + + def test_run_pod_error_without_message(self): + """Test the RunPodError class without a message.""" + err = RunPodError() + self.assertEqual(str(err), 'None') + + def test_authentication_error(self): + """Test the AuthenticationError class.""" + error_msg = "Authentication failed" + err = AuthenticationError(error_msg) + self.assertEqual(str(err), error_msg) + + def test_query_error_with_message_and_query(self): + """Test the QueryError class with a message and query.""" + error_msg = "Query failed" + query_str = "SELECT * FROM some_table WHERE condition" + err = QueryError(error_msg, query_str) + self.assertEqual(str(err), error_msg) + self.assertEqual(err.query, query_str) + + def test_query_error_without_message_and_query(self): + """Test the QueryError class without a message or query.""" + err = QueryError() + self.assertEqual(str(err), 'None') + self.assertIsNone(err.query)