Skip to content

Feature/performance test #749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 0 additions & 37 deletions assets/fixtures.sql

This file was deleted.

73 changes: 73 additions & 0 deletions openadapt/a11y.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Any, Union
from loguru import logger
import AppKit
import ApplicationServices

def get_attribute(element, attribute):
result, value = ApplicationServices.AXUIElementCopyAttributeValue(element, attribute, None)
if result == 0:
return value
return None

def find_element_by_attribute(element, attribute, value):
if get_attribute(element, attribute) == value:
return element
children = get_attribute(element, ApplicationServices.kAXChildrenAttribute) or []
for child in children:
found = find_element_by_attribute(child, attribute, value)
if found:
return found
return None

def find_application(app_name: str):
"""Find an application by its name and return its accessibility element.

Args:
app_name (str): The name of the application to find.

Returns:
AXUIElement or None: The AXUIElement representing the application,
or None if the application is not running.
"""
workspace = AppKit.NSWorkspace.sharedWorkspace()
running_apps = workspace.runningApplications()
app = next((app for app in running_apps if app.localizedName() == app_name), None)
if app is None:
logger.error(f"{app_name} application is not running.")
return None

app_element = ApplicationServices.AXUIElementCreateApplication(app.processIdentifier())
return app_element

def get_main_window(app_element):
"""Get the main window of an application.

Args:
app_element: The AXUIElement of the application.

Returns:
AXUIElement or None: The AXUIElement representing the main window,
or None if no windows are found.
"""
error_code, windows = ApplicationServices.AXUIElementCopyAttributeValue(app_element, ApplicationServices.kAXWindowsAttribute, None)
if error_code or not windows:
return None

return windows[0] if windows else None

def get_element_value(element, role: str):
"""Get the value of a specific element by its role.

Args:
element: The AXUIElement to search within.
role (str): The role of the element to find (e.g., "AXStaticText").

Returns:
str: The value of the element, or an error message if not found.
"""
target_element = find_element_by_attribute(element, ApplicationServices.kAXRoleAttribute, role)
if not target_element:
return f"{role} element not found."

value = get_attribute(target_element, ApplicationServices.kAXValueAttribute)
return value if value else f"No value for {role} element."
4 changes: 3 additions & 1 deletion openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
CAPTURE_DIR_PATH = (DATA_DIR_PATH / "captures").absolute()
VIDEO_DIR_PATH = DATA_DIR_PATH / "videos"
DATABASE_LOCK_FILE_PATH = DATA_DIR_PATH / "openadapt.db.lock"
DB_FILE_PATH = (DATA_DIR_PATH / "openadapt.db").absolute()

STOP_STRS = [
"oa.stop",
Expand Down Expand Up @@ -124,7 +125,8 @@ class SegmentationAdapter(str, Enum):

# Database
DB_ECHO: bool = False
DB_URL: ClassVar[str] = f"sqlite:///{(DATA_DIR_PATH / 'openadapt.db').absolute()}"
DB_FILE_PATH: str = str(DB_FILE_PATH)
DB_URL: ClassVar[str] = f"sqlite:///{DB_FILE_PATH}"

# Error reporting
ERROR_REPORTING_ENABLED: bool = True
Expand Down
30 changes: 26 additions & 4 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import time

from typing import List
from loguru import logger
from sqlalchemy.orm import Session as SaSession
import psutil
Expand Down Expand Up @@ -279,21 +280,25 @@ def delete_recording(session: SaSession, recording: Recording) -> None:
delete_video_file(recording_timestamp)


def get_all_recordings(session: SaSession) -> list[Recording]:
"""Get all recordings.
def get_recordings(session: SaSession, max_rows=None) -> list[Recording]:
"""Get recordings.

Args:
session (sa.orm.Session): The database session.
max_rows: The number of recordings to return, starting from the most recent.
Defaults to all if max_rows is not specified.

Returns:
list[Recording]: A list of all original recordings.
"""
return (
query = (
session.query(Recording)
.filter(Recording.original_recording_id == None) # noqa: E711
.order_by(sa.desc(Recording.timestamp))
.all()
)
if max_rows:
query = query.limit(max_rows)
return query.all()


def get_all_scrubbed_recordings(
Expand Down Expand Up @@ -350,6 +355,23 @@ def get_recording(session: SaSession, timestamp: float) -> Recording:
return session.query(Recording).filter(Recording.timestamp == timestamp).first()


def get_recordings_by_desc(session: SaSession, description_str: str) -> List[Recording]:
"""Get recordings by task description.

Args:
session (sa.orm.Session): The database session.
task_description (str): The task description to search for.

Returns:
List[Recording]: A list of recordings whose task descriptions contain the given string.
"""
return (
session.query(Recording)
.filter(Recording.task_description.contains(description_str))
.all()
)


BaseModelType = TypeVar("BaseModelType")


Expand Down
4 changes: 2 additions & 2 deletions openadapt/scripts/reset_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

def reset_db() -> None:
"""Clears the database by removing the db file and running a db migration."""
if os.path.exists(config.DB_FPATH):
os.remove(config.DB_FPATH)
if os.path.exists(config.DB_FILE_PATH):
os.remove(config.DB_FILE_PATH)

# Prevents duplicate logging of config values by piping stderr
# and filtering the output.
Expand Down
123 changes: 123 additions & 0 deletions scripts/generate_db_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import os
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import sessionmaker
from openadapt.db.db import Base
from openadapt.config import DATA_DIR_PATH, PARENT_DIR_PATH, RECORDING_DIR_PATH
import openadapt.db.crud as crud

def get_session():
db_url = RECORDING_DIR_PATH / "recording.db"
print(f"Database URL: {db_url}")
engine = create_engine(f"sqlite:///{db_url}")
# SessionLocal = sessionmaker(bind=engine)
Base.metadata.create_all(bind=engine)
session = crud.get_new_session(read_only=True)
print("Database connection established.")
return session, engine

def check_tables_exist(engine):
inspector = inspect(engine)
tables = inspector.get_table_names()
expected_tables = [
'recording',
'action_event',
'screenshot',
'window_event',
'performance_stat',
'memory_stat'
]
for table in expected_tables:
if table in tables:
print(f"Table '{table}' exists.")
else:
print(f"Table '{table}' does NOT exist.")
return tables

def fetch_data(session):
# get the most recent three recordings
recordings = crud.get_recordings(session, max_rows=3)
recording_ids = [recording.id for recording in recordings]

action_events = []
screenshots = []
window_events = []
performance_stats = []
memory_stats = []

for recording in recordings:
action_events.extend(crud.get_action_events(session, recording))
screenshots.extend(crud.get_screenshots(session, recording))
window_events.extend(crud.get_window_events(session, recording))
performance_stats.extend(crud.get_perf_stats(session, recording))
memory_stats.extend(crud.get_memory_stats(session, recording))

data = {
"recordings": recordings,
"action_events": action_events,
"screenshots": screenshots,
"window_events": window_events,
"performance_stats": performance_stats,
"memory_stats": memory_stats,
}

# Debug prints to verify data fetching
print(f"Recordings: {len(data['recordings'])} found.")
print(f"Action Events: {len(data['action_events'])} found.")
print(f"Screenshots: {len(data['screenshots'])} found.")
print(f"Window Events: {len(data['window_events'])} found.")
print(f"Performance Stats: {len(data['performance_stats'])} found.")
print(f"Memory Stats: {len(data['memory_stats'])} found.")

return data

def format_sql_insert(table_name, rows):
Copy link
Member

@abrichr abrichr Jun 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add types.

Also, what do you think about something like this:

from sqlalchemy import insert, Table, MetaData
from sqlalchemy.engine import Engine  # Assuming you have an engine instance

def generate_sql_insert(engine: Engine, table_name: str, rows: list) -> str:
    """Generates a SQL INSERT statement using SQLAlchemy for given rows and table.

    Args:
        engine (Engine): SQLAlchemy Engine connected to the database.
        table_name (str): Name of the table to insert data into.
        rows (list): List of dictionaries representing the rows to insert.

    Returns:
        str: A string representation of the SQL INSERT statement suitable for fixtures.sql.
    """
    metadata = MetaData(bind=engine)
    table = Table(table_name, metadata, autoload_with=engine)

    stmt = insert(table).values(rows)
    compiled = stmt.compile(engine, compile_kwargs={"literal_binds": True})
    return str(compiled) + ";"

Please fill out the type for rows, e.g. list[dict], list[BaseModelType] , list[db.Base], or similar.

if not rows:
return ""

columns = rows[0].__table__.columns.keys()
sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES\n"
values = []

for row in rows:
row_values = [getattr(row, col) for col in columns]
row_values = [f"'{value}'" if isinstance(value, str) else str(value) for value in row_values]
values.append(f"({', '.join(row_values)})")

sql += ",\n".join(values) + ";\n"
return sql

def dump_to_fixtures(filepath):
session, engine = get_session()
check_tables_exist(engine)
data = fetch_data(session)

with open(filepath, "a") as file:
if data["recordings"]:
file.write("-- Insert sample recordings\n")
file.write(format_sql_insert("recording", data["recordings"]))

if data["action_events"]:
file.write("-- Insert sample action_events\n")
file.write(format_sql_insert("action_event", data["action_events"]))

if data["screenshots"]:
file.write("-- Insert sample screenshots\n")
file.write(format_sql_insert("screenshot", data["screenshots"]))

if data["window_events"]:
file.write("-- Insert sample window_events\n")
file.write(format_sql_insert("window_event", data["window_events"]))

if data["performance_stats"]:
file.write("-- Insert sample performance_stats\n")
file.write(format_sql_insert("performance_stat", data["performance_stats"]))

if data["memory_stats"]:
file.write("-- Insert sample memory_stats\n")
file.write(format_sql_insert("memory_stat", data["memory_stats"]))
Copy link
Member

@abrichr abrichr Jun 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something more DRY, e.g.:

rows_by_table_name = fetch_data(session)
for table_name, rows in rows_by_table_name.items():
    if not rows:
        logger.warning(f"no rows for table_name=}")
        continue
    with open(file_path, "a") as file:
        logger.info(f"writing {len(rows)=} to {file_path=} for {table_name=}")
        file.write(f"-- Insert sample rows for {table_name}\n")
        file.write(format_sql_insert(table_name, rows))


print(f"Data appended to {filepath}")

if __name__ == "__main__":
fixtures_path = PARENT_DIR_PATH / "tests/assets/fixtures.sql"
dump_to_fixtures(fixtures_path)
Loading