Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server side json variables #13500

Merged
merged 8 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
client side variable changes (#13508)
  • Loading branch information
jakekaplan authored May 22, 2024
commit 90acd262a711a84f365e49d6287c5dec6034df49
48 changes: 44 additions & 4 deletions src/prefect/client/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
from uuid import UUID, uuid4

import jsonschema
from pydantic.v1 import Field, root_validator, validator
import orjson
from pydantic.v1 import (
Field,
StrictBool,
StrictFloat,
StrictInt,
StrictStr,
root_validator,
validator,
)

import prefect.client.schemas.objects as objects
from prefect._internal.compatibility.deprecated import DeprecatedInfraOverridesField
Expand Down Expand Up @@ -708,14 +717,29 @@ class VariableCreate(ActionBaseModel):
examples=["my_variable"],
max_length=objects.MAX_VARIABLE_NAME_LENGTH,
)
value: str = Field(
value: Union[
StrictStr, StrictFloat, StrictBool, StrictInt, None, Dict[str, Any], List[Any]
] = Field(
default=...,
description="The value of the variable",
examples=["my-value"],
max_length=objects.MAX_VARIABLE_VALUE_LENGTH,
)
tags: Optional[List[str]] = Field(default=None)

@validator("value")
def validate_value(cls, v):
try:
json_string = orjson.dumps(v)
except orjson.JSONDecodeError:
raise ValueError("Variable value must be serializable to JSON.")

if len(json_string) > objects.MAX_VARIABLE_VALUE_LENGTH:
raise ValueError(
f"value must less than {objects.MAX_VARIABLE_VALUE_LENGTH} characters when serialized."
)

return v

# validators
_validate_name_format = validator("name", allow_reuse=True)(validate_variable_name)

Expand All @@ -729,14 +753,30 @@ class VariableUpdate(ActionBaseModel):
examples=["my_variable"],
max_length=objects.MAX_VARIABLE_NAME_LENGTH,
)
value: Optional[str] = Field(
value: Union[
StrictStr, StrictFloat, StrictBool, StrictInt, None, Dict[str, Any], List[Any]
] = Field(
default=None,
description="The value of the variable",
examples=["my-value"],
max_length=objects.MAX_VARIABLE_NAME_LENGTH,
)
tags: Optional[List[str]] = Field(default=None)

@validator("value")
def validate_value(cls, v):
try:
json_string = orjson.dumps(v)
except orjson.JSONDecodeError:
raise ValueError("Variable value must be serializable to JSON.")

if len(json_string) > objects.MAX_VARIABLE_VALUE_LENGTH:
raise ValueError(
f"value must less than {objects.MAX_VARIABLE_VALUE_LENGTH} characters when serialized."
)

return v

# validators
_validate_name_format = validator("name", allow_reuse=True)(validate_variable_name)

Expand Down
16 changes: 13 additions & 3 deletions src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@

import orjson
import pendulum
from pydantic.v1 import Field, HttpUrl, root_validator, validator
from pydantic.v1 import (
Field,
HttpUrl,
StrictBool,
StrictFloat,
StrictInt,
StrictStr,
root_validator,
validator,
)
from typing_extensions import Literal

from prefect._internal.compatibility.deprecated import (
Expand Down Expand Up @@ -1478,11 +1487,12 @@ class Variable(ObjectBaseModel):
examples=["my_variable"],
max_length=MAX_VARIABLE_NAME_LENGTH,
)
value: str = Field(
value: Union[
StrictStr, StrictFloat, StrictBool, StrictInt, None, Dict[str, Any], List[Any]
] = Field(
default=...,
description="The value of the variable",
examples=["my_value"],
max_length=MAX_VARIABLE_VALUE_LENGTH,
)
tags: List[str] = Field(
default_factory=list,
Expand Down
101 changes: 59 additions & 42 deletions src/prefect/variables.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List, Optional
from typing import Any, Dict, List, Optional, Union

from prefect._internal.compatibility.deprecated import deprecated_callable
from prefect.client.schemas.actions import VariableCreate as VariableRequest
from prefect.client.schemas.actions import VariableUpdate as VariableUpdateRequest
from prefect.client.schemas.objects import Variable as VariableResponse
from prefect.client.utilities import get_or_create_client
from prefect.exceptions import ObjectNotFound
from prefect.utilities.asyncutils import sync_compatible


Expand All @@ -23,51 +24,60 @@ class Variable(VariableRequest):
async def set(
cls,
name: str,
value: str,
value: Union[str, int, float, bool, None, List[Any], Dict[str, Any]],
tags: Optional[List[str]] = None,
overwrite: bool = False,
) -> Optional[str]:
):
"""
Sets a new variable. If one exists with the same name, user must pass `overwrite=True`
Returns `True` if the variable was created or updated

```
from prefect.variables import Variable

@flow
def my_flow():
var = Variable.set(name="my_var",value="test_value", tags=["hi", "there"], overwrite=True)
Variable.set(name="my_var",value="test_value", tags=["hi", "there"], overwrite=True)
```
or
```
from prefect.variables import Variable

@flow
async def my_flow():
var = await Variable.set(name="my_var",value="test_value", tags=["hi", "there"], overwrite=True)
await Variable.set(name="my_var",value="test_value", tags=["hi", "there"], overwrite=True)
```
"""
client, _ = get_or_create_client()
variable = await client.read_variable_by_name(name)
var_dict = {"name": name, "value": value}
var_dict["tags"] = tags or []
var_dict = {"name": name, "value": value, "tags": tags or []}
if variable:
if not overwrite:
raise ValueError(
"You are attempting to save a variable with a name that is already in use. If you would like to overwrite the values that are saved, then call .set with `overwrite=True`."
"You are attempting to set a variable with a name that is already in use. "
"If you would like to overwrite it, pass `overwrite=True`."
)
var = VariableUpdateRequest(**var_dict)
await client.update_variable(variable=var)
variable = await client.read_variable_by_name(name)
await client.update_variable(variable=VariableUpdateRequest(**var_dict))
else:
var = VariableRequest(**var_dict)
variable = await client.create_variable(variable=var)

return variable if variable else None
await client.create_variable(variable=VariableRequest(**var_dict))

@classmethod
@sync_compatible
async def get(cls, name: str, default: Optional[str] = None) -> Optional[str]:
async def get(
cls,
name: str,
default: Union[str, int, float, bool, None, List[Any], Dict[str, Any]] = None,
as_object: bool = False,
) -> Union[
str, int, float, bool, None, List[Any], Dict[str, Any], VariableResponse
]:
"""
Get a variable by name. If doesn't exist return the default.
Get a variable's value by name.

If the variable does not exist, return the default value.

If `as_object=True`, return the full variable object. `default` is ignored in this case.

```
from prefect.variables import Variable

Expand All @@ -86,29 +96,36 @@ async def my_flow():
"""
client, _ = get_or_create_client()
variable = await client.read_variable_by_name(name)
return variable if variable else default
if as_object:
return variable

return variable.value if variable else default

@deprecated_callable(start_date="Apr 2024")
@sync_compatible
async def get(name: str, default: Optional[str] = None) -> Optional[str]:
"""
Get a variable by name. If doesn't exist return the default.
```
from prefect import variables

@flow
def my_flow():
var = variables.get("my_var")
```
or
```
from prefect import variables

@flow
async def my_flow():
var = await variables.get("my_var")
```
"""
variable = await Variable.get(name)
return variable.value if variable else default
@classmethod
@sync_compatible
async def unset(cls, name: str) -> bool:
"""
Unset a variable by name.

```
from prefect.variables import Variable

@flow
def my_flow():
Variable.unset("my_var")
```
or
```
from prefect.variables import Variable

@flow
async def my_flow():
await Variable.unset("my_var")
```
"""
client, _ = get_or_create_client()
try:
await client.delete_variable_by_name(name=name)
return True
except ObjectNotFound:
return False
27 changes: 27 additions & 0 deletions tests/client/test_prefect_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1990,6 +1990,33 @@ async def variables(
results.append(res.json())
return pydantic.parse_obj_as(List[Variable], results)

@pytest.mark.parametrize(
"value",
[
"string-value",
'"string-value"',
123,
12.3,
True,
False,
None,
{"key": "value"},
["value1", "value2"],
{"key": ["value1", "value2"]},
],
)
async def test_create_variable(self, prefect_client, value):
created_variable = await prefect_client.create_variable(
variable=VariableCreate(name="my_variable", value=value)
)
assert created_variable
assert created_variable.name == "my_variable"
assert created_variable.value == value

res = await prefect_client.read_variable_by_name(created_variable.name)
assert res.name == created_variable.name
assert res.value == value

async def test_read_variable_by_name(self, prefect_client, variable):
res = await prefect_client.read_variable_by_name(variable.name)
assert res.name == variable.name
Expand Down
Loading
Loading