Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 8 additions & 27 deletions data_rentgen/db/repositories/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@
asc,
bindparam,
desc,
func,
select,
)
from sqlalchemy.orm import selectinload

from data_rentgen.db.models import Tag, TagValue
from data_rentgen.db.models import Tag
from data_rentgen.db.repositories.base import Repository
from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank
from data_rentgen.dto.pagination import PaginationDTO
Expand All @@ -36,44 +34,27 @@ async def paginate(
tag_ids: Collection[int],
search_query: str | None,
) -> PaginationDTO[Tag]:
where = []
if tag_ids:
where.append(Tag.id == any_(list(tag_ids))) # type: ignore[arg-type]

query: Select | CompoundSelect
order_by: list[ColumnElement | SQLColumnExpression]
if search_query:
tsquery = make_tsquery(search_query)

tag_stmt = select(Tag.id, Tag.name, ts_rank(Tag.search_vector, tsquery).label("search_rank")).where(
ts_match(Tag.search_vector, tsquery),
*where,
)
value_stmt = (
select(Tag.id, Tag.name, ts_rank(TagValue.search_vector, tsquery).label("search_rank"))
.join(TagValue, TagValue.tag_id == Tag.id)
.where(ts_match(TagValue.search_vector, tsquery), *where)
)
union_cte = tag_stmt.union_all(value_stmt).cte("tag_union")
query = select(
union_cte.c.id,
union_cte.c.name,
func.max(union_cte.c.search_rank).label("search_rank"),
).group_by(union_cte.c.id, union_cte.c.name)
Tag.id,
Tag.name,
ts_rank(Tag.search_vector, tsquery).label("search_rank"),
).where(ts_match(Tag.search_vector, tsquery))

order_by = [desc("search_rank"), asc("name")]
else:
query = select(Tag).where(*where)
query = select(Tag)
order_by = [Tag.name]

options = [
selectinload(Tag.tag_values),
]
if tag_ids:
query = query.where(Tag.id == any_(list(tag_ids))) # type: ignore[arg-type]

return await self._paginate_by_query(
query=query,
order_by=order_by,
options=options,
page=page,
page_size=page_size,
)
Expand Down
51 changes: 51 additions & 0 deletions data_rentgen/db/repositories/tag_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,30 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from collections.abc import Collection

from sqlalchemy import (
ARRAY,
ColumnElement,
CompoundSelect,
Integer,
Select,
SQLColumnExpression,
String,
any_,
asc,
bindparam,
cast,
desc,
func,
select,
tuple_,
)

from data_rentgen.db.models.tag_value import TagValue
from data_rentgen.db.repositories.base import Repository
from data_rentgen.db.utils.search import make_tsquery, ts_match, ts_rank
from data_rentgen.dto.pagination import PaginationDTO
from data_rentgen.dto.tag import TagValueDTO

fetch_bulk_query = select(TagValue).where(
Expand All @@ -33,6 +44,46 @@


class TagValueRepository(Repository[TagValue]):
async def paginate(
self,
page: int,
page_size: int,
tag_id: int | None,
tag_value_ids: Collection[int],
search_query: str | None,
) -> PaginationDTO[TagValue]:
query: Select | CompoundSelect
order_by: list[ColumnElement | SQLColumnExpression]
if search_query:
tsquery = make_tsquery(search_query)
query = (
select(
TagValue.id,
TagValue.tag_id,
TagValue.value,
ts_rank(TagValue.search_vector, tsquery).label("search_rank"),
)
.where(ts_match(TagValue.search_vector, tsquery))
.order_by(desc("search_rank"))
)
order_by = [desc("search_rank"), asc("value")]
else:
query = select(TagValue)
order_by = [TagValue.value]

if tag_id is not None:
query = query.where(TagValue.tag_id == tag_id)

if tag_value_ids:
query = query.where(TagValue.id == any_(list(tag_value_ids))) # type: ignore[arg-type]

return await self._paginate_by_query(
query=query,
order_by=order_by,
page=page,
page_size=page_size,
)

async def fetch_bulk(self, tag_values_dto: list[TagValueDTO]) -> list[tuple[TagValueDTO, TagValue | None]]:
if not tag_values_dto:
return []
Expand Down
2 changes: 2 additions & 0 deletions data_rentgen/server/api/v1/router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from data_rentgen.server.api.v1.router.personal_token import router as personal_token_router
from data_rentgen.server.api.v1.router.run import router as run_router
from data_rentgen.server.api.v1.router.tag import router as tag_router
from data_rentgen.server.api.v1.router.tag_value import router as tag_value_router
from data_rentgen.server.api.v1.router.user import router as user_router

router = APIRouter(prefix="/v1")
Expand All @@ -22,3 +23,4 @@
router.include_router(user_router)
router.include_router(personal_token_router)
router.include_router(tag_router)
router.include_router(tag_value_router)
2 changes: 1 addition & 1 deletion data_rentgen/server/api/v1/router/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from data_rentgen.server.schemas.v1 import (
PageResponseV1,
TagDetailedResponseV1,
TagPaginateQueryV1,
)
from data_rentgen.server.schemas.v1.tag import TagPaginateQueryV1
from data_rentgen.server.services import get_user
from data_rentgen.server.services.tag import TagService

Expand Down
37 changes: 37 additions & 0 deletions data_rentgen/server/api/v1/router/tag_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: 2025-present MTS PJSC
# SPDX-License-Identifier: Apache-2.0
from typing import Annotated

from fastapi import APIRouter, Depends, Query

from data_rentgen.db.models.user import User
from data_rentgen.server.errors import get_error_responses
from data_rentgen.server.errors.schemas import InvalidRequestSchema, NotAuthorizedRedirectSchema, NotAuthorizedSchema
from data_rentgen.server.schemas.v1 import (
PageResponseV1,
TagValueDetailedResponseV1,
TagValuePaginateQueryV1,
)
from data_rentgen.server.services import TagValueService, get_user

router = APIRouter(
prefix="/tag-values",
tags=["TagValues"],
responses=get_error_responses(include={NotAuthorizedSchema, NotAuthorizedRedirectSchema, InvalidRequestSchema}),
)


@router.get("", summary="Paginated list of TagValues")
async def paginate_tag_values(
query_args: Annotated[TagValuePaginateQueryV1, Query()],
tag_value_service: Annotated[TagValueService, Depends()],
current_user: Annotated[User, Depends(get_user())],
) -> PageResponseV1[TagValueDetailedResponseV1]:
pagination = await tag_value_service.paginate(
page=query_args.page,
page_size=query_args.page_size,
tag_id=query_args.tag_id,
tag_value_ids=query_args.tag_value_id,
search_query=query_args.search_query,
)
return PageResponseV1[TagValueDetailedResponseV1].from_pagination(pagination)
18 changes: 17 additions & 1 deletion data_rentgen/server/schemas/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,16 @@
RunsPaginateQueryV1,
RunStatisticsReponseV1,
)
from data_rentgen.server.schemas.v1.tag import TagDetailedResponseV1
from data_rentgen.server.schemas.v1.tag import (
NestedTagValueResponseV1,
TagDetailedResponseV1,
TagPaginateQueryV1,
TagResponseV1,
TagValueDetailedResponseV1,
TagValuePaginateQueryV1,
TagValueResponseV1,
TagWithValuesResponseV1,
)
from data_rentgen.server.schemas.v1.user import UserResponseV1

__all__ = [
Expand Down Expand Up @@ -103,6 +112,7 @@
"LocationDetailedResponseV1",
"LocationPaginateQueryV1",
"LocationResponseV1",
"NestedTagValueResponseV1",
"OperationDetailedResponseV1",
"OperationIOStatisticsReponseV1",
"OperationLineageQueryV1",
Expand All @@ -128,6 +138,12 @@
"RunStatisticsReponseV1",
"RunsPaginateQueryV1",
"TagDetailedResponseV1",
"TagPaginateQueryV1",
"TagResponseV1",
"TagValueDetailedResponseV1",
"TagValuePaginateQueryV1",
"TagValueResponseV1",
"TagWithValuesResponseV1",
"UpdateLocationRequestV1",
"UserResponseV1",
]
4 changes: 2 additions & 2 deletions data_rentgen/server/schemas/v1/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from data_rentgen.server.schemas.v1.location import LocationResponseV1
from data_rentgen.server.schemas.v1.pagination import PaginateQueryV1
from data_rentgen.server.schemas.v1.tag import TagResponseV1
from data_rentgen.server.schemas.v1.tag import TagWithValuesResponseV1


class DatasetSchemaFieldV1(BaseModel):
Expand Down Expand Up @@ -47,7 +47,7 @@ class DatasetResponseV1(BaseModel):
class DatasetDetailedResponseV1(BaseModel):
id: str = Field(description="Dataset id", coerce_numbers_to_str=True)
data: DatasetResponseV1 = Field(description="Dataset data")
tags: list[TagResponseV1] = Field(default_factory=list, description="Dataset tags")
tags: list[TagWithValuesResponseV1] = Field(default_factory=list, description="Dataset tags")

model_config = ConfigDict(from_attributes=True)

Expand Down
4 changes: 2 additions & 2 deletions data_rentgen/server/schemas/v1/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from data_rentgen.server.schemas.v1.location import LocationResponseV1
from data_rentgen.server.schemas.v1.pagination import PaginateQueryV1
from data_rentgen.server.schemas.v1.tag import TagResponseV1
from data_rentgen.server.schemas.v1.tag import TagWithValuesResponseV1


class JobResponseV1(BaseModel):
Expand All @@ -23,7 +23,7 @@ class JobResponseV1(BaseModel):
class JobDetailedResponseV1(BaseModel):
id: str = Field(description="Job id", coerce_numbers_to_str=True)
data: JobResponseV1 = Field(description="Job data")
tags: list[TagResponseV1] = Field(default_factory=list, description="Job tags")
tags: list[TagWithValuesResponseV1] = Field(default_factory=list, description="Job tags")

model_config = ConfigDict(from_attributes=True)

Expand Down
62 changes: 57 additions & 5 deletions data_rentgen/server/schemas/v1/tag.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
# SPDX-FileCopyrightText: 2025-present MTS PJSC
# SPDX-License-Identifier: Apache-2.0
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator

from data_rentgen.server.schemas.v1.pagination import PaginateQueryV1


class TagValueResponseV1(BaseModel):
class NestedTagValueResponseV1(BaseModel):
id: int = Field(description="Tag value id")
value: str = Field(description="Tag value")

model_config = ConfigDict(from_attributes=True)


class TagWithValuesResponseV1(BaseModel):
id: int = Field(description="Tag id")
name: str = Field(description="Tag name")
values: list[NestedTagValueResponseV1] = Field(default_factory=list, description="Values for the tag")

model_config = ConfigDict(from_attributes=True)


class TagResponseV1(BaseModel):
id: int = Field(description="Tag id")
name: str = Field(description="Tag name")
values: list[TagValueResponseV1] = Field(default_factory=list, description="Values for the tag")

model_config = ConfigDict(from_attributes=True)


class TagValueResponseV1(BaseModel):
id: int = Field(description="TagValue id")
tag_id: int = Field(description="Tag id")
value: str = Field(description="Tag value")

model_config = ConfigDict(from_attributes=True)

Expand All @@ -27,6 +42,13 @@ class TagDetailedResponseV1(BaseModel):
model_config = ConfigDict(from_attributes=True)


class TagValueDetailedResponseV1(BaseModel):
id: int = Field(description="TagValue id")
data: TagValueResponseV1 = Field(description="TagValue data")

model_config = ConfigDict(from_attributes=True)


class TagPaginateQueryV1(PaginateQueryV1):
"""Query params for Tag paginate request."""

Expand All @@ -38,8 +60,38 @@ class TagPaginateQueryV1(PaginateQueryV1):
search_query: str | None = Field(
default=None,
min_length=3,
description="Search query, partial match by tag name or any value",
examples=["my tag"],
description="Search query, partial match by tag name",
examples=["my.tag"],
)

model_config = ConfigDict(extra="forbid")


class TagValuePaginateQueryV1(PaginateQueryV1):
"""Query params for TagValue paginate request."""

tag_id: int | None = Field(
default=None,
description="Get only values for specific tag",
examples=[[123]],
)
tag_value_id: list[int] = Field(
default_factory=list,
description="Ids of tag_values to fetch specific items only",
examples=[[123]],
)
search_query: str | None = Field(
default=None,
min_length=3,
description="Search query, partial match by tag value",
examples=["my value"],
)

@model_validator(mode="after")
def _check_tag_id_and_tag_value_id(self):
if not self.tag_id and not self.tag_value_id:
msg = "input should contain either 'tag_id' or 'tag_value_id' field"
raise ValueError(msg)
return self

model_config = ConfigDict(extra="forbid")
4 changes: 4 additions & 0 deletions data_rentgen/server/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from data_rentgen.server.services.operation import OperationService
from data_rentgen.server.services.personal_token import PersonalTokenService
from data_rentgen.server.services.run import RunService
from data_rentgen.server.services.tag import TagService
from data_rentgen.server.services.tag_value import TagValueService

__all__ = [
"DatasetService",
Expand All @@ -18,5 +20,7 @@
"PersonalTokenPolicy",
"PersonalTokenService",
"RunService",
"TagService",
"TagValueService",
"get_user",
]
Loading