Skip to content

Commit

Permalink
Updates: Store emotion predictions (BasedHardware#645)
Browse files Browse the repository at this point in the history
  • Loading branch information
beastoin authored Aug 23, 2024
2 parents 7062172 + 40b0c58 commit 9733f02
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
29 changes: 27 additions & 2 deletions backend/database/memories.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import uuid
from typing import List
from datetime import datetime

from google.cloud import firestore
from google.cloud.firestore_v1 import FieldFilter

from models.memory import MemoryPhoto, PostProcessingStatus, PostProcessingModel
import utils.other.hume as hume
from models.transcript_segment import TranscriptSegment
from ._client import db

Expand All @@ -26,8 +28,6 @@ def get_memory(uid, memory_id):
return memory_ref.get().to_dict()




def get_memories(uid: str, limit: int = 100, offset: int = 0, include_discarded: bool = False):
memories_ref = (
db.collection('users').document(uid).collection('memories')
Expand Down Expand Up @@ -136,3 +136,28 @@ def store_model_segments_result(uid: str, memory_id: str, model_name: str, segme
batch.commit()
batch = db.batch()
batch.commit()

def store_model_emotion_predictions_result(uid: str, memory_id: str, model_name: str, predictions: List[hume.HumeJobModelPredictionResponseModel]):
now = datetime.now()
user_ref = db.collection('users').document(uid)
memory_ref = user_ref.collection('memories').document(memory_id)
predictions_ref = memory_ref.collection(model_name)
batch = db.batch()
count = 1
for prediction in predictions:
prediction_id = str(uuid.uuid4())
prediction_ref = predictions_ref.document(prediction_id)
batch.set(prediction_ref, {
"created_at": now,
"start": prediction.time[0],
"end": prediction.time[1],
})
emotions_ref = prediction_ref.collection("emotions")
for emotion in prediction.emotions:
emotion_ref = emotions_ref.document(emotion.name)
batch.set(emotion_ref, emotion.to_dict())
count += 1
if count % 400 == 0:
batch.commit()
batch = db.batch()
batch.commit()
4 changes: 4 additions & 0 deletions backend/utils/memories/process_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ def process_user_expression_measurement_callback(provider: str, request_id: str,

uid = task.user_uid

# Save predictions
if len(callback.predictions) > 0:
memories_db.store_model_emotion_predictions_result(task.user_uid, task.memory_id, provider, callback.predictions)

# Memory
memory_data = memories_db.get_memory(uid, task.memory_id)
if memory_data is None:
Expand Down
8 changes: 7 additions & 1 deletion backend/utils/other/hume.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@ def from_dict(cls, data: dict) -> "HumePredictionEmotionResponseModel":
model = cls(data["name"], data["score"])
return model

def to_dict(self):
return {
'name': self.name,
'score': self.score,
}


class HumeJobModelPredictionResponseModel:
def __init__(
self,
time,
emotions=[],
emotions: [HumePredictionEmotionResponseModel] = [],
) -> None:
self.emotions = emotions
self.time = time
Expand Down

0 comments on commit 9733f02

Please sign in to comment.