Skip to content

Fix class detection for namespaced classes (Py) #897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion schema_salad/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .typescript_codegen import TypeScriptCodeGen
from .utils import aslist

FIELD_SORT_ORDER = ["id", "class", "name"]
FIELD_SORT_ORDER = ["class", "id", "name"]


def codegen(
Expand Down
28 changes: 28 additions & 0 deletions schema_salad/metaschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,8 @@ class RecordField(Documented):
A field of a record.
"""

class_uri = "https://w3id.org/cwl/salad#RecordField"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -1428,6 +1430,8 @@ def save(


class RecordSchema(Saveable):
class_uri = "https://w3id.org/cwl/salad#RecordSchema"

def __init__(
self,
type_: Any,
Expand Down Expand Up @@ -1632,6 +1636,8 @@ class EnumSchema(Saveable):

"""

class_uri = "https://w3id.org/cwl/salad#EnumSchema"

def __init__(
self,
symbols: Any,
Expand Down Expand Up @@ -1898,6 +1904,8 @@ def save(


class ArraySchema(Saveable):
class_uri = "https://w3id.org/cwl/salad#ArraySchema"

def __init__(
self,
items: Any,
Expand Down Expand Up @@ -2097,6 +2105,8 @@ def save(


class MapSchema(Saveable):
class_uri = "https://w3id.org/cwl/salad#MapSchema"

def __init__(
self,
type_: Any,
Expand Down Expand Up @@ -2296,6 +2306,8 @@ def save(


class UnionSchema(Saveable):
class_uri = "https://w3id.org/cwl/salad#UnionSchema"

def __init__(
self,
names: Any,
Expand Down Expand Up @@ -2501,6 +2513,8 @@ class JsonldPredicate(Saveable):

"""

class_uri = "https://w3id.org/cwl/salad#JsonldPredicate"

def __init__(
self,
_id: Optional[Any] = None,
Expand Down Expand Up @@ -3239,6 +3253,8 @@ def save(


class SpecializeDef(Saveable):
class_uri = "https://w3id.org/cwl/salad#SpecializeDef"

def __init__(
self,
specializeFrom: Any,
Expand Down Expand Up @@ -3463,6 +3479,8 @@ class SaladRecordField(RecordField):
A field of a record.
"""

class_uri = "https://w3id.org/cwl/salad#SaladRecordField"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -3844,6 +3862,8 @@ def save(


class SaladRecordSchema(NamedType, RecordSchema, SchemaDefinedType):
class_uri = "https://w3id.org/cwl/salad#SaladRecordSchema"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -4705,6 +4725,8 @@ class SaladEnumSchema(NamedType, EnumSchema, SchemaDefinedType):

"""

class_uri = "https://w3id.org/cwl/salad#SaladEnumSchema"

def __init__(
self,
symbols: Any,
Expand Down Expand Up @@ -5446,6 +5468,8 @@ class SaladMapSchema(NamedType, MapSchema, SchemaDefinedType):

"""

class_uri = "https://w3id.org/cwl/salad#SaladMapSchema"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -6131,6 +6155,8 @@ class SaladUnionSchema(NamedType, UnionSchema, DocType):

"""

class_uri = "https://w3id.org/cwl/salad#SaladUnionSchema"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -6757,6 +6783,8 @@ class Documentation(NamedType, DocType):

"""

class_uri = "https://w3id.org/cwl/salad#Documentation"

def __init__(
self,
name: Any,
Expand Down
51 changes: 22 additions & 29 deletions schema_salad/python_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def begin_class(
idfield: str,
optional_fields: set[str],
) -> None:
class_uri = classname
classname = self.safe_name(classname)

if extends:
Expand All @@ -163,6 +164,8 @@ def begin_class(
self.out.write(" pass\n\n\n")
return

self.out.write(f' class_uri = "{class_uri}"\n\n')

required_field_names = [f for f in field_names if f not in optional_fields]
optional_field_names = [f for f in field_names if f in optional_fields]

Expand Down Expand Up @@ -276,27 +279,6 @@ def save(
"""
)

if "class" in field_names:
self.out.write(
"""
if "class" not in _doc:
raise ValidationException("Missing 'class' field")
if _doc.get("class") != "{class_}":
raise ValidationException("tried `{class_}` but")

""".format(
class_=classname
)
)

self.serializer.write(
"""
r["class"] = "{class_}"
""".format(
class_=classname
)
)

def end_class(self, classname: str, field_names: list[str]) -> None:
"""Signal that we are done with this class."""
if self.current_class_is_abstract:
Expand Down Expand Up @@ -554,9 +536,6 @@ def declare_field(
if self.current_class_is_abstract:
return

if shortname(name) == "class":
return

if optional:
self.out.write(f""" {self.safe_name(name)} = None\n""")
self.out.write(f""" if "{shortname(name)}" in _doc:\n""") # noqa: B907
Expand Down Expand Up @@ -608,8 +587,22 @@ def declare_field(
spc=spc,
)
)
self.out.write(
"""

if shortname(name) == "class":
self.out.write(
"""{spc} if {safename} != cls.__name__ and {safename} != cls.class_uri:
{spc} raise ValidationException(f"tried `{{cls.__name__}}` but")
{spc} except ValidationException as e:
{spc} raise e
""".format(
safename=self.safe_name(name),
spc=spc,
)
)

else:
self.out.write(
"""
{spc} except ValidationException as e:
{spc} error_message, to_print, verb_tensage = parse_errors(str(e))

Expand Down Expand Up @@ -647,10 +640,10 @@ def declare_field(
{spc} )
{spc} )
""".format(
fieldname=shortname(name),
spc=spc,
fieldname=shortname(name),
spc=spc,
)
)
)

if name == self.idfield or not self.idfield:
baseurl = "base_url"
Expand Down
6 changes: 3 additions & 3 deletions schema_salad/tests/test_codegen_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def test_error_message5(tmp_path: Path) -> None:
def test_error_message6(tmp_path: Path) -> None:
t = "test_schema/test6.cwl"
match = r"""\*\s+tried\s+`CommandLineTool`\s+but
\s+Missing\s+'class'\s+field
\s+missing\s+required\s+field\s+`class`
+\*\s+tried\s+`ExpressionTool`\s+but
\s+Missing\s+'class'\s+field
\s+missing\s+required\s+field\s+`class`
+\*\s+tried\s+`Workflow`\s+but
\s+Missing\s+'class'\s+field"""
\s+missing\s+required\s+field\s+`class`"""
path = get_data("tests/" + t)
assert path
with pytest.raises(ValidationException, match=match):
Expand Down
Loading