|
6 | 6 | from fastapi import Path, Query, Body |
7 | 7 |
|
8 | 8 |
|
| 9 | +async def _search_user_events_by_tags_internal( |
| 10 | + user_id: UUID, |
| 11 | + project_id: str, |
| 12 | + tags: str = None, |
| 13 | + tag_values: str = None, |
| 14 | + topk: int = 10, |
| 15 | +) -> res.UserEventsDataResponse: |
| 16 | + """Internal function to search user events by tags.""" |
| 17 | + has_event_tag = None |
| 18 | + if tags: |
| 19 | + has_event_tag = [tag.strip() for tag in tags.split(",") if tag.strip()] |
| 20 | + |
| 21 | + event_tag_equal = None |
| 22 | + if tag_values: |
| 23 | + event_tag_equal = {} |
| 24 | + for pair in tag_values.split(","): |
| 25 | + if "=" in pair: |
| 26 | + tag_name, tag_value = pair.split("=", 1) |
| 27 | + event_tag_equal[tag_name.strip()] = tag_value.strip() |
| 28 | + |
| 29 | + p = await controllers.event.filter_user_events( |
| 30 | + user_id, project_id, has_event_tag, event_tag_equal, topk |
| 31 | + ) |
| 32 | + |
| 33 | + return p.to_response(res.UserEventsDataResponse) |
| 34 | + |
| 35 | + |
9 | 36 | async def get_user_events( |
10 | 37 | request: Request, |
11 | 38 | user_id: UUID = Path(..., description="The ID of the user"), |
@@ -66,10 +93,19 @@ async def search_user_events( |
66 | 93 | use_gists: bool = Query( |
67 | 94 | True, description="Whether to search event gists (default) or event tip" |
68 | 95 | ), |
69 | | -) -> res.UserEventGistsDataResponse |res.UserEventsDataResponse: |
| 96 | + use_tag: bool = Query( |
| 97 | + False, description="Whether to search by tags instead of query" |
| 98 | + ), |
| 99 | + tags: str = Query(None, description="Comma-separated list of tag names that events must have (e.g.'emotion,romance')"), |
| 100 | + tag_values: str = Query(None, description="Comma-separated tag=value pairs for exact matches (e.g., 'emotion=happy,topic=work')"), |
| 101 | +) -> res.UserEventGistsDataResponse | res.UserEventsDataResponse: |
70 | 102 | project_id = request.state.memobase_project_id |
71 | 103 |
|
72 | | - if use_gists: |
| 104 | + if use_tag: |
| 105 | + return await _search_user_events_by_tags_internal( |
| 106 | + user_id, project_id, tags, tag_values, topk |
| 107 | + ) |
| 108 | + elif use_gists: |
73 | 109 | p = await controllers.event_gist.search_user_event_gists( |
74 | 110 | user_id, project_id, query, topk, similarity_threshold, time_range_in_days |
75 | 111 | ) |
@@ -108,21 +144,6 @@ async def search_user_events_by_tags( |
108 | 144 | topk: int = Query(10, description="Number of events to retrieve, default is 10"), |
109 | 145 | ) -> res.UserEventsDataResponse: |
110 | 146 | project_id = request.state.memobase_project_id |
111 | | - |
112 | | - has_event_tag = None |
113 | | - if tags: |
114 | | - has_event_tag = [tag.strip() for tag in tags.split(",") if tag.strip()] |
115 | | - |
116 | | - event_tag_equal = None |
117 | | - if tag_values: |
118 | | - event_tag_equal = {} |
119 | | - for pair in tag_values.split(","): |
120 | | - if "=" in pair: |
121 | | - tag_name, tag_value = pair.split("=", 1) |
122 | | - event_tag_equal[tag_name.strip()] = tag_value.strip() |
123 | | - |
124 | | - p = await controllers.event.filter_user_events( |
125 | | - user_id, project_id, has_event_tag, event_tag_equal, topk |
| 147 | + return await _search_user_events_by_tags_internal( |
| 148 | + user_id, project_id, tags, tag_values, topk |
126 | 149 | ) |
127 | | - |
128 | | - return p.to_response(res.UserEventsDataResponse) |
|
0 commit comments