diff --git a/src/check_jsonschema/schema_loader/main.py b/src/check_jsonschema/schema_loader/main.py index 3b10c3b01..63c471934 100644 --- a/src/check_jsonschema/schema_loader/main.py +++ b/src/check_jsonschema/schema_loader/main.py @@ -100,8 +100,8 @@ def _get_schema_reader(self) -> LocalSchemaReader | HttpSchemaReader: f"detected parsed URL had an unrecognized scheme: {self.url_info}" ) - def get_schema_ref_base(self) -> str | None: - return self.reader.get_ref_base() + def get_schema_retrieval_uri(self) -> str | None: + return self.reader.get_retrieval_uri() def get_schema(self) -> dict[str, t.Any]: return self.reader.read_schema() @@ -113,7 +113,7 @@ def get_validator( format_opts: FormatOptions, fill_defaults: bool, ) -> jsonschema.Validator: - schema_uri = self.get_schema_ref_base() + retrieval_uri = self.get_schema_retrieval_uri() schema = self.get_schema() schema_dialect = schema.get("$schema") @@ -123,7 +123,7 @@ def get_validator( # reference resolution # with support for YAML, TOML, and other formats from the parsers - reference_registry = make_reference_registry(self._parsers, schema_uri, schema) + reference_registry = make_reference_registry(self._parsers, retrieval_uri, schema) # get the correct validator class and check the schema under its metaschema validator_cls = jsonschema.validators.validator_for(schema) @@ -147,7 +147,7 @@ def __init__(self, schema_name: str) -> None: self.schema_name = schema_name self._parsers = ParserSet() - def get_schema_ref_base(self) -> str | None: + def get_schema_retrieval_uri(self) -> str | None: return None def get_schema(self) -> dict[str, t.Any]: diff --git a/src/check_jsonschema/schema_loader/readers.py b/src/check_jsonschema/schema_loader/readers.py index 98e24c011..c2e469781 100644 --- a/src/check_jsonschema/schema_loader/readers.py +++ b/src/check_jsonschema/schema_loader/readers.py @@ -30,7 +30,7 @@ def __init__(self, filename: str) -> None: self.filename = str(self.path) self.parsers = ParserSet() - def get_ref_base(self) -> str: + def get_retrieval_uri(self) -> str: return self.path.as_uri() def _read_impl(self) -> t.Any: @@ -55,7 +55,7 @@ def __init__( validation_callback=json.loads, ) - def get_ref_base(self) -> str: + def get_retrieval_uri(self) -> str: return self.url def _read_impl(self) -> t.Any: diff --git a/src/check_jsonschema/schema_loader/resolver.py b/src/check_jsonschema/schema_loader/resolver.py index 5d658eadc..b227b5da9 100644 --- a/src/check_jsonschema/schema_loader/resolver.py +++ b/src/check_jsonschema/schema_loader/resolver.py @@ -12,21 +12,25 @@ def make_reference_registry( - parsers: ParserSet, schema_uri: str | None, schema: dict + parsers: ParserSet, retrieval_uri: str | None, schema: dict ) -> referencing.Registry: + id_attribute_: t.Any = schema.get("$id") + if isinstance(id_attribute_, str): + id_attribute: str | None = id_attribute_ + else: + id_attribute = None + schema_resource = referencing.Resource.from_contents( schema, default_specification=DRAFT202012 ) # mypy does not recognize that Registry is an `attrs` class and has `retrieve` as an # argument to its implicit initializer registry: referencing.Registry = referencing.Registry( # type: ignore[call-arg] - retrieve=create_retrieve_callable(parsers, schema_uri) + retrieve=create_retrieve_callable(parsers, retrieval_uri, id_attribute) ) - if schema_uri is not None: - registry = registry.with_resource(uri=schema_uri, resource=schema_resource) - - id_attribute = schema.get("$id") + if retrieval_uri is not None: + registry = registry.with_resource(uri=retrieval_uri, resource=schema_resource) if id_attribute is not None: registry = registry.with_resource(uri=id_attribute, resource=schema_resource) @@ -34,16 +38,20 @@ def make_reference_registry( def create_retrieve_callable( - parser_set: ParserSet, schema_uri: str | None + parser_set: ParserSet, retrieval_uri: str | None, id_attribute: str | None ) -> t.Callable[[str], referencing.Resource[Schema]]: + base_uri = id_attribute + if base_uri is None: + base_uri = retrieval_uri + def get_local_file(uri: str) -> t.Any: path = filename2path(uri) return parser_set.parse_file(path, "json") def retrieve_reference(uri: str) -> referencing.Resource[Schema]: scheme = urllib.parse.urlsplit(uri).scheme - if scheme == "" and schema_uri is not None: - full_uri = urllib.parse.urljoin(schema_uri, uri) + if scheme == "" and base_uri is not None: + full_uri = urllib.parse.urljoin(base_uri, uri) else: full_uri = uri