Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions examples/EmergencyManagement/Server/controllers_emergency.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
from dataclasses import asdict
from reticulum_openapi.controller import Controller, handle_exceptions
from examples.EmergencyManagement.Server.database import async_session
from examples.EmergencyManagement.Server.models_emergency import (
EmergencyActionMessage,
Event,
Expand All @@ -13,73 +15,71 @@ class EmergencyController(Controller):
@handle_exceptions
async def CreateEmergencyActionMessage(self, req: EmergencyActionMessage):
self.logger.info(f"CreateEAM: {req}")
await asyncio.sleep(0.1)
async with async_session() as session:
await EmergencyActionMessage.create(session, **asdict(req))
return req

@handle_exceptions
async def DeleteEmergencyActionMessage(self, callsign: str):
self.logger.info(f"DeleteEAM callsign={callsign}")
await asyncio.sleep(0.1)
return {"status": "deleted", "callsign": callsign}
async with async_session() as session:
deleted = await EmergencyActionMessage.delete(session, callsign)
return {"status": "deleted" if deleted else "not_found", "callsign": callsign}

@handle_exceptions
async def ListEmergencyActionMessage(self):
self.logger.info("ListEAM")
await asyncio.sleep(0.1)
return []
async with async_session() as session:
items = await EmergencyActionMessage.list(session)
return items

@handle_exceptions
async def PatchEmergencyActionMessage(self, req: EmergencyActionMessage):
self.logger.info(f"PatchEAM: {req}")
await asyncio.sleep(0.1)
return req
async with async_session() as session:
updated = await EmergencyActionMessage.update(session, req.callsign, **asdict(req))
return updated

@handle_exceptions
async def RetrieveEmergencyActionMessage(self, callsign: str):
self.logger.info(f"RetrieveEAM callsign={callsign}")
await asyncio.sleep(0.1)
return EmergencyActionMessage(
callsign=callsign, groupName="Alpha",
securityStatus=EAMStatus.Green, securityCapability=EAMStatus.Green,
preparednessStatus=EAMStatus.Green, medicalStatus=EAMStatus.Green,
mobilityStatus=EAMStatus.Green, commsStatus=EAMStatus.Green,
commsMethod="Radio"
)
async with async_session() as session:
item = await EmergencyActionMessage.get(session, callsign)
return item


class EventController(Controller):
@handle_exceptions
async def CreateEvent(self, req: Event):
self.logger.info(f"CreateEvent: {req}")
await asyncio.sleep(0.1)
async with async_session() as session:
await Event.create(session, **asdict(req))
return req

@handle_exceptions
async def DeleteEvent(self, uid: str):
self.logger.info(f"DeleteEvent uid={uid}")
await asyncio.sleep(0.1)
return {"status": "deleted", "uid": uid}
async with async_session() as session:
deleted = await Event.delete(session, int(uid))
return {"status": "deleted" if deleted else "not_found", "uid": uid}

@handle_exceptions
async def ListEvent(self):
self.logger.info("ListEvent")
await asyncio.sleep(0.1)
return []
async with async_session() as session:
events = await Event.list(session)
return events

@handle_exceptions
async def PatchEvent(self, req: Event):
self.logger.info(f"PatchEvent: {req}")
await asyncio.sleep(0.1)
return req
async with async_session() as session:
updated = await Event.update(session, req.uid, **asdict(req))
return updated

@handle_exceptions
async def RetrieveEvent(self, uid: str):
self.logger.info(f"RetrieveEvent uid={uid}")
await asyncio.sleep(0.1)
return Event(
uid=int(uid), how="m-g", version=1, time=0, type="Emergency",
stale="PT1H", start="PT0S", access="public",
opex=0, qos=1,
detail=Detail(emergencyActionMessage=None),
point=Point(0, 0, 0, 0, 0)
)
async with async_session() as session:
event = await Event.get(session, int(uid))
return event
11 changes: 11 additions & 0 deletions examples/EmergencyManagement/Server/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from .models_emergency import Base

DATABASE_URL = "sqlite+aiosqlite:///emergency.db"

engine = create_async_engine(DATABASE_URL, echo=False)
async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)

async def init_db():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
35 changes: 35 additions & 0 deletions examples/EmergencyManagement/Server/models_emergency.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,38 @@
from dataclasses import dataclass
from reticulum_openapi.model import BaseModel
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, Integer, String, Float, JSON

Base = declarative_base()


class EmergencyActionMessageORM(Base):
__tablename__ = "emergency_action_messages"
callsign = Column(String, primary_key=True)
groupName = Column(String)
securityStatus = Column(String)
securityCapability = Column(String)
preparednessStatus = Column(String)
medicalStatus = Column(String)
mobilityStatus = Column(String)
commsStatus = Column(String)
commsMethod = Column(String)


class EventORM(Base):
__tablename__ = "events"
uid = Column(Integer, primary_key=True)
how = Column(String)
version = Column(Integer)
time = Column(Integer)
type = Column(String)
stale = Column(String)
start = Column(String)
access = Column(String)
opex = Column(Integer)
qos = Column(Integer)
detail = Column(JSON)
point = Column(JSON)


class EAMStatus(str):
Expand All @@ -19,6 +52,7 @@ class EmergencyActionMessage(BaseModel):
mobilityStatus: EAMStatus
commsStatus: EAMStatus
commsMethod: str
__orm_model__ = EmergencyActionMessageORM


@dataclass
Expand Down Expand Up @@ -49,3 +83,4 @@ class Event(BaseModel):
qos: int
detail: Detail
point: Point
__orm_model__ = EventORM
2 changes: 2 additions & 0 deletions examples/EmergencyManagement/Server/server_emergency.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
from examples.EmergencyManagement.Server.service_emergency import EmergencyService
from examples.EmergencyManagement.Server.database import init_db


async def main():
await init_db()
svc = EmergencyService()
svc.announce()
service_task = asyncio.create_task(svc.start())
Expand Down
57 changes: 42 additions & 15 deletions reticulum_openapi/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
# reticulum_openapi/model.py
from dataclasses import dataclass, asdict, is_dataclass
from dataclasses import dataclass, asdict, is_dataclass, fields
import json
import zlib
from typing import Type, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy import select

__all__ = [
"dataclass_to_json",
"dataclass_from_json",
"BaseModel",
"create_async_engine",
"async_sessionmaker",
]

T = TypeVar('T')


Expand Down Expand Up @@ -59,18 +67,33 @@ def from_json_bytes(cls: Type[T], data: bytes) -> T:
"""Deserialize compressed JSON bytes to a dataclass instance."""
return dataclass_from_json(cls, data)

def to_orm(self):
"""Create an ORM instance from this dataclass."""
if self.__orm_model__ is None:
raise NotImplementedError(
"Subclasses must define __orm_model__ for persistence"
)
return self.__orm_model__(**asdict(self))

@classmethod
def from_orm(cls: Type[T], orm_obj) -> T:
"""Instantiate a dataclass from an ORM row."""
kwargs = {f.name: getattr(orm_obj, f.name) for f in fields(cls)}
return cls(**kwargs)

@classmethod
async def create(cls, session: AsyncSession, **kwargs):
async def create(cls, session: AsyncSession, **kwargs) -> T:
"""
Create and persist a new record using the associated ORM model.
Returns the ORM instance.
Returns the dataclass instance.
"""
if cls.__orm_model__ is None:
raise NotImplementedError("Subclasses must define __orm_model__ for persistence")
obj = cls.__orm_model__(**kwargs)
session.add(obj)
await session.commit()
return obj
await session.refresh(obj)
return cls.from_orm(obj)

@classmethod
async def get(cls, session: AsyncSession, id_):
Expand All @@ -80,7 +103,10 @@ async def get(cls, session: AsyncSession, id_):
"""
if cls.__orm_model__ is None:
raise NotImplementedError("Subclasses must define __orm_model__ for persistence")
return await session.get(cls.__orm_model__, id_)
orm_obj = await session.get(cls.__orm_model__, id_)
if orm_obj is None:
return None
return cls.from_orm(orm_obj)

@classmethod
async def list(cls, session: AsyncSession, **filters):
Expand All @@ -95,7 +121,7 @@ async def list(cls, session: AsyncSession, **filters):
for attr, value in filters.items():
stmt = stmt.where(getattr(cls.__orm_model__, attr) == value)
result = await session.execute(stmt)
return result.scalars().all()
return [cls.from_orm(obj) for obj in result.scalars().all()]

@classmethod
async def update(cls, session: AsyncSession, id_, **kwargs):
Expand All @@ -105,14 +131,15 @@ async def update(cls, session: AsyncSession, id_, **kwargs):
"""
if cls.__orm_model__ is None:
raise NotImplementedError("Subclasses must define __orm_model__ for persistence")
obj = await cls.get(session, id_)
if obj is None:
orm_obj = await session.get(cls.__orm_model__, id_)
if orm_obj is None:
return None
for attr, value in kwargs.items():
setattr(obj, attr, value)
session.add(obj)
setattr(orm_obj, attr, value)
session.add(orm_obj)
await session.commit()
return obj
await session.refresh(orm_obj)
return cls.from_orm(orm_obj)

@classmethod
async def delete(cls, session: AsyncSession, id_):
Expand All @@ -122,9 +149,9 @@ async def delete(cls, session: AsyncSession, id_):
"""
if cls.__orm_model__ is None:
raise NotImplementedError("Subclasses must define __orm_model__ for persistence")
obj = await cls.get(session, id_)
if obj is None:
orm_obj = await session.get(cls.__orm_model__, id_)
if orm_obj is None:
return False
await session.delete(obj)
await session.delete(orm_obj)
await session.commit()
return True
39 changes: 39 additions & 0 deletions tests/test_persistence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from dataclasses import dataclass
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession

from reticulum_openapi.model import BaseModel

Base = declarative_base()

class ItemORM(Base):
__tablename__ = "items"
id = Column(Integer, primary_key=True)
name = Column(String)

@dataclass
class Item(BaseModel):
id: int
name: str
__orm_model__ = ItemORM

@pytest.mark.asyncio
async def test_crud_roundtrip():
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with async_session() as session:
await Item.create(session, id=1, name="foo")
item = await Item.get(session, 1)
assert item.name == "foo"
await Item.update(session, 1, name="bar")
updated = await Item.get(session, 1)
assert updated.name == "bar"
items = await Item.list(session)
assert len(items) == 1
assert items[0].name == "bar"
deleted = await Item.delete(session, 1)
assert deleted
Loading