Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions chats/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@


@admin.display(description="Пользователи чата")
def chat_users(obj):
def chat_users(obj) -> str:
return f"{obj.get_users_str()}"


@admin.display(description="Количество сообщений")
def chat_message_count(obj):
def chat_message_count(obj) -> int:
return obj.messages.count()


Expand Down
14 changes: 8 additions & 6 deletions chats/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from django.contrib.auth import get_user_model
from django.utils.translation import gettext_lazy as _
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.authtoken.models import Token
from users.models import CustomUser
from django.contrib.auth.models import AnonymousUser


User = get_user_model()

Expand All @@ -22,10 +26,9 @@ class TokenAuthentication:

model = None

def get_model(self):
def get_model(self) -> Token:
if self.model is not None:
return self.model
from rest_framework.authtoken.models import Token

return Token

Expand All @@ -36,7 +39,7 @@ def get_model(self):
* user -- The user to which the token belongs
"""

def authenticate_credentials(self, key):
def authenticate_credentials(self, key: str) -> CustomUser:
model = self.get_model()
try:
token = model.objects.select_related("user").get(key=key)
Expand All @@ -48,7 +51,7 @@ def authenticate_credentials(self, key):

return token.user

def authenticate(self, token):
def authenticate(self, token: Token) -> CustomUser:
"""
Returns a `User` if a correct username and password have been supplied
Args:
Expand All @@ -71,14 +74,13 @@ def authenticate(self, token):


@database_sync_to_async
def get_user(scope):
def get_user(scope: dict) -> CustomUser | AnonymousUser:
"""
Return the user model instance associated with the given scope.
If no user is retrieved, return an instance of `AnonymousUser`.
"""
# postpone model import to avoid ImproperlyConfigured error before Django
# setup is complete.
from django.contrib.auth.models import AnonymousUser

if "token" not in scope:
raise ValueError(
Expand Down
44 changes: 23 additions & 21 deletions chats/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from abc import abstractmethod
from typing import List

Expand All @@ -7,6 +8,7 @@

from files.models import UserFile
from projects.models import Project
from users.models import CustomUser

User = get_user_model()

Expand All @@ -21,10 +23,10 @@ class BaseChat(models.Model):

created_at = models.DateTimeField(auto_now_add=True)

def get_last_message(self):
def get_last_message(self) -> BaseMessage:
return self.messages.last()

def get_users_str(self):
def get_users_str(self) -> str:
"""Returns string of users separated by a comma, who are in chat

Returns:
Expand All @@ -34,7 +36,7 @@ def get_users_str(self):
return ", ".join([user.get_full_name() for user in users])

@abstractmethod
def get_users(self):
def get_users(self) -> List[CustomUser]:
"""
Returns all collaborators and leader of the project.

Expand All @@ -44,7 +46,7 @@ def get_users(self):
pass

@abstractmethod
def get_avatar(self, user):
def get_avatar(self, user: CustomUser) -> str:
"""
Returns avatar of the chat for given user

Expand All @@ -57,7 +59,7 @@ def get_avatar(self, user):
pass

@abstractmethod
def get_last_messages(self, message_count):
def get_last_messages(self, message_count: int) -> List[BaseMessage]:
"""
Returns last messages of the chat

Expand Down Expand Up @@ -90,18 +92,18 @@ class ProjectChat(BaseChat):
Project, on_delete=models.CASCADE, related_name="project_chats"
)

def get_users(self):
def get_users(self) -> List[CustomUser]:
collaborators = self.project.collaborator_set.all()
users = [collaborator.user for collaborator in collaborators]
return users + [self.project.leader]

def get_avatar(self, user):
def get_avatar(self, user: CustomUser) -> str:
return self.project.image_address

def get_last_messages(self, message_count) -> List["BaseMessage"]:
def get_last_messages(self, message_count: int) -> List[BaseMessage]:
return self.messages.order_by("-created_at")[:message_count]

def __str__(self):
def __str__(self) -> str:
return f"ProjectChat<{self.project.id}> - {self.project.name}"

def save(
Expand Down Expand Up @@ -129,15 +131,15 @@ class DirectChat(BaseChat):
id = models.CharField(primary_key=True, max_length=64)
users = models.ManyToManyField(User, related_name="direct_chats")

def get_users(self):
def get_users(self) -> List[CustomUser]:
return self.users.all()

def get_avatar(self, user):
def get_avatar(self, user) -> str:
other_user = self.get_users().exclude(pk=user.pk).first()
return other_user.avatar

@classmethod
def get_chat(cls, user1, user2) -> "DirectChat":
def get_chat(cls, user1: CustomUser, user2: CustomUser) -> "DirectChat":
"""
Returns chat between two users.

Expand All @@ -157,25 +159,25 @@ def get_chat(cls, user1, user2) -> "DirectChat":
chat.users.set([user1, user2])
return chat

def get_last_messages(self, message_count):
def get_last_messages(self, message_count: int) -> BaseMessage:
return self.messages.order_by("-created_at")[:message_count]

def get_other_user(self, user) -> User:
def get_other_user(self, user: CustomUser) -> User:
return self.users.exclude(pk=user.pk).first()

@classmethod
def create_from_two_users(cls, user1, user2):
def create_from_two_users(cls, user1: CustomUser, user2: CustomUser) -> DirectChat:
chat = cls.objects.create(pk=cls.get_chat_id_from_users(user1, user2))
chat.users.set([user1, user2])
return chat

@classmethod
def get_chat_id_from_users(cls, user1, user2):
def get_chat_id_from_users(cls, user1: CustomUser, user2: CustomUser) -> str:
first_user = user1 if user1.pk < user2.pk else user2
second_user = user2 if user1.pk < user2.pk else user1
return f"{first_user.pk}_{second_user.pk}"

def __str__(self):
def __str__(self) -> str:
return f"DirectChat with {self.get_users_str()}"

class Meta:
Expand All @@ -200,7 +202,7 @@ class BaseMessage(models.Model):
is_edited = models.BooleanField(default=False)
created_at = models.DateTimeField(auto_now_add=True)

def __str__(self):
def __str__(self) -> str:
return f"Message<{self.pk}>"

class Meta:
Expand Down Expand Up @@ -240,7 +242,7 @@ def clean(self):
if self.reply_to and self.reply_to.chat != self.chat:
raise ValidationError("Reply to message from another chat")

def __str__(self):
def __str__(self) -> str:
return f"ProjectChatMessage<{self.pk}>"

class Meta:
Expand Down Expand Up @@ -278,7 +280,7 @@ def clean(self):
if self.reply_to and self.reply_to.chat != self.chat:
raise ValidationError("Reply to message from another chat")

def __str__(self):
def __str__(self) -> str:
return f"DirectChatMessage<{self.pk}>"

class Meta:
Expand Down Expand Up @@ -307,7 +309,7 @@ class FileToMessage(models.Model):
null=True,
)

def __str__(self):
def __str__(self) -> str:
return f"FileToMessage<{self.file}>"

class Meta:
Expand Down
26 changes: 13 additions & 13 deletions chats/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@ class DirectChatListSerializer(serializers.ModelSerializer):
name = serializers.SerializerMethodField(read_only=True)
image_address = serializers.SerializerMethodField(read_only=True)

def get_opponent(self, chat: DirectChat):
def get_opponent(self, chat: DirectChat) -> dict:
user = self.context.get("opponent")
return UserDetailSerializer(
user, context={"request": self.context.get("request")}
).data

def get_name(self, chat: DirectChat):
def get_name(self, chat: DirectChat) -> str:
user = self.context.get("opponent")
return user.get_full_name()

def get_image_address(self, chat: DirectChat):
def get_image_address(self, chat: DirectChat) -> str:
user = self.context.get("opponent")
return user.avatar

@classmethod
def get_last_message(cls, chat: DirectChat):
def get_last_message(cls, chat: DirectChat) -> dict:
return DirectChatMessageListSerializer(chat.get_last_message()).data

class Meta:
Expand All @@ -42,7 +42,7 @@ class Meta:
class DirectChatDetailSerializer(serializers.ModelSerializer):
opponent = serializers.SerializerMethodField()

def get_opponent(self, chat: DirectChat):
def get_opponent(self, chat: DirectChat) -> dict:
user = self.context.get("opponent")
return UserDetailSerializer(
user, context={"request": self.context.get("request")}
Expand All @@ -62,15 +62,15 @@ class ProjectChatListSerializer(serializers.ModelSerializer):
image_address = serializers.SerializerMethodField(read_only=True)

@classmethod
def get_image_address(cls, chat: ProjectChat):
def get_image_address(cls, chat: ProjectChat) -> str:
return chat.project.image_address

@classmethod
def get_name(cls, chat: ProjectChat):
def get_name(cls, chat: ProjectChat) -> str:
return chat.project.name

@classmethod
def get_last_message(cls, chat: ProjectChat):
def get_last_message(cls, chat: ProjectChat) -> dict:
return ProjectChatMessageListSerializer(chat.get_last_message()).data

class Meta:
Expand All @@ -84,14 +84,14 @@ class ProjectChatDetailSerializer(serializers.ModelSerializer):
image_address = serializers.SerializerMethodField(read_only=True)

@classmethod
def get_image_address(cls, chat: ProjectChat):
def get_image_address(cls, chat: ProjectChat) -> str:
return chat.project.image_address

@classmethod
def get_name(cls, chat: ProjectChat):
def get_name(cls, chat: ProjectChat) -> str:
return chat.project.name

def get_users(self, chat: ProjectChat):
def get_users(self, chat: ProjectChat) -> dict:
return UserListSerializer(
chat.get_users(), context={"request": self.context.get("request")}, many=True
).data
Expand Down Expand Up @@ -133,7 +133,7 @@ class DirectChatMessageListSerializer(serializers.ModelSerializer):
files = serializers.SerializerMethodField()

@classmethod
def get_files(cls, message: DirectChatMessage):
def get_files(cls, message: DirectChatMessage) -> list[dict]:
data = []
for file_to_message in message.file_to_message.all():
file_data = UserFileSerializer(file_to_message.file).data
Expand Down Expand Up @@ -182,7 +182,7 @@ class ProjectChatMessageListSerializer(serializers.ModelSerializer):
files = serializers.SerializerMethodField()

@classmethod
def get_files(cls, message: DirectChatMessage):
def get_files(cls, message: DirectChatMessage) -> dict:
data = []
for file_to_message in message.file_to_message.all():
file_data = UserFileSerializer(file_to_message.file).data
Expand Down
8 changes: 5 additions & 3 deletions chats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
WrongChatIdException,
NonMatchingDirectChatIdException,
)
from chats.models import DirectChatMessage, ProjectChatMessage, FileToMessage
from chats.models import DirectChatMessage, ProjectChatMessage, FileToMessage, BaseMessage
from files.models import UserFile

User = get_user_model()
Expand Down Expand Up @@ -106,7 +106,9 @@ async def create_file_to_message(
)


async def match_files_and_messages(file_urls, messages):
async def match_files_and_messages(
file_urls: list[str], messages: dict[str, Union[str, None, ProjectChatMessage]]
):
for url in file_urls:
file = await sync_to_async(UserFile.objects.get)(pk=url)
# implicitly matches a file and a message
Expand All @@ -117,7 +119,7 @@ async def match_files_and_messages(file_urls, messages):
)


def get_all_files(messages):
def get_all_files(messages: list[BaseMessage]) -> list[str]:
# looks like something bad -
files = []
for message in messages:
Expand Down
Loading