Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Config:

#: Postgres credentials
POSTGRES = {}

#: Shared secret for LVSP
LVSP_SECRET = ""

Expand Down
2 changes: 1 addition & 1 deletion litecord/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def token_check(to_raise: Literal[False] = ...) -> Optional[int]:
...


async def token_check(to_raise = True) -> Optional[int]:
async def token_check(to_raise=True) -> Optional[int]:
"""Check token information."""
# first, check if the request info already has a uid
user_id = getattr(request, "user_id", None)
Expand Down
1 change: 0 additions & 1 deletion litecord/blueprints/admin_api/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from quart import current_app as app, request



bp = Blueprint("channels_admin", __name__)


Expand Down
8 changes: 2 additions & 6 deletions litecord/blueprints/admin_api/guilds.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ async def create_guild():
},
)
guild_id = j.get("id") or app.winter_factory.snowflake()
guild, extra = await handle_guild_create(
user_id, guild_id, {"features": j.get("features")}
)
guild, extra = await handle_guild_create(user_id, guild_id, {"features": j.get("features")})
return jsonify({**guild, **extra}), 201


Expand Down Expand Up @@ -166,9 +164,7 @@ async def update_guild(guild_id: int):
if old_unavailable and not new_unavailable:
# Guild became available
guild = await app.storage.get_guild_full(guild_id)
await app.dispatcher.guild.dispatch(
guild_id, ("GUILD_CREATE", {**guild, "unavailable": False})
)
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", {**guild, "unavailable": False}))
elif not old_unavailable and new_unavailable:
# Guild became unavailable
await app.dispatcher.guild.dispatch(
Expand Down
6 changes: 1 addition & 5 deletions litecord/blueprints/admin_api/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,7 @@ async def get_db_url():
if host in ("localhost", "0.0.0.0"):
host = app.config["MAIN_URL"]

return jsonify(
{
"url": f"postgres://{db['user']}:{db['password']}@{host}:5432/{db['database']}"
}
)
return jsonify({"url": f"postgres://{db['user']}:{db['password']}@{host}:5432/{db['database']}"})


@bp.route("/snowflake", methods=["GET"])
Expand Down
4 changes: 1 addition & 3 deletions litecord/blueprints/admin_api/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@
async def _create_user():
await admin_check()
j = validate(await request.get_json(), USER_CREATE)
user_id, _ = await create_user(
j["username"], j["email"], j["password"], j.get("date_of_birth"), id=j.get("id")
)
user_id, _ = await create_user(j["username"], j["email"], j["password"], j.get("date_of_birth"), id=j.get("id"))
return jsonify(await app.storage.get_user(user_id, True)), 201


Expand Down
4 changes: 1 addition & 3 deletions litecord/blueprints/attachments.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
ATTACHMENTS = Path.cwd() / "attachments"


async def _resize_gif(
attach_id: int, resized_path: Path, width: int, height: int
) -> str:
async def _resize_gif(attach_id: int, resized_path: Path, width: int, height: int) -> str:
"""Resize a GIF attachment."""

# get original gif bytes
Expand Down
24 changes: 11 additions & 13 deletions litecord/blueprints/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from datetime import datetime, date
import itsdangerous
import bcrypt
from quart import Blueprint, jsonify, request, current_app as app
from quart import Blueprint, jsonify
from typing import TYPE_CHECKING
from logbook import Logger


Expand All @@ -33,6 +34,11 @@
from litecord.pubsub.user import dispatch_user
from .invites import use_invite

if TYPE_CHECKING:
from litecord.typing_hax import app, request
else:
from quart import current_app as app, request

log = Logger(__name__)
bp = Blueprint("auth", __name__)

Expand All @@ -42,9 +48,7 @@ async def check_password(pwd_hash: str, given_password: str) -> bool:
pwd_encoded = pwd_hash.encode()
given_encoded = given_password.encode()

return await app.loop.run_in_executor(
None, bcrypt.checkpw, given_encoded, pwd_encoded
)
return await app.loop.run_in_executor(None, bcrypt.checkpw, given_encoded, pwd_encoded)


def make_token(user_id, user_pwd_hash) -> str:
Expand Down Expand Up @@ -86,9 +90,7 @@ async def register():
today = date.today()
date_of_birth = datetime.strptime(j["date_of_birth"], "%Y-%m-%d")
if (
today.year
- date_of_birth.year
- ((today.month, today.day) < (date_of_birth.month, date_of_birth.day))
today.year - date_of_birth.year - ((today.month, today.day) < (date_of_birth.month, date_of_birth.day))
) < 13:
raise ManualFormError(
date_of_birth={
Expand Down Expand Up @@ -145,9 +147,7 @@ async def _register_with_invite():
today = date.today()
date_of_birth = datetime.strptime(data["date_of_birth"], "%Y-%m-%d")
if (
today.year
- date_of_birth.year
- ((today.month, today.day) < (date_of_birth.month, date_of_birth.day))
today.year - date_of_birth.year - ((today.month, today.day) < (date_of_birth.month, date_of_birth.day))
) < 13:
raise ManualFormError(
date_of_birth={
Expand All @@ -165,9 +165,7 @@ async def _register_with_invite():
invcode,
)

user_id, pwd_hash = await create_user(
data["username"], data["email"], data["password"], date_of_birth
)
user_id, pwd_hash = await create_user(data["username"], data["email"], data["password"], date_of_birth)

return jsonify({"token": make_token(user_id, pwd_hash)})

Expand Down
83 changes: 19 additions & 64 deletions litecord/blueprints/channel/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,8 @@ async def around_message_search(

around_message = await app.storage.get_message(around_id, user_id)
around_message = [around_message] if around_message else []
before_messages = await message_search(
channel_id, halved_limit, before=around_id, order="DESC"
)
after_messages = await message_search(
channel_id, halved_limit, after=around_id, order="ASC"
)
before_messages = await message_search(channel_id, halved_limit, before=around_id, order="DESC")
after_messages = await message_search(channel_id, halved_limit, after=around_id, order="ASC")
return list(reversed(before_messages)) + around_message + after_messages


Expand All @@ -132,9 +128,7 @@ async def handle_get_messages(channel_id: int):
limit = extract_limit(request, default=50)

if "around" in request.args:
messages = await around_message_search(
channel_id, int(request.args["around"]), limit
)
messages = await around_message_search(channel_id, int(request.args["around"]), limit)
else:
before, after = query_tuple_from_args(request.args, limit)
messages = await message_search(channel_id, limit, before=before, after=after)
Expand Down Expand Up @@ -216,11 +210,7 @@ async def create_message(
mentions = []
mention_roles = []
if data.get("content"):
if (
allowed_mentions is None
or "users" in allowed_mentions.get("parse", [])
or allowed_mentions.get("users")
):
if allowed_mentions is None or "users" in allowed_mentions.get("parse", []) or allowed_mentions.get("users"):
allowed = (allowed_mentions.get("users") or []) if allowed_mentions else []
if ctype == ChannelType.GROUP_DM:
members = await app.db.fetch(
Expand Down Expand Up @@ -264,9 +254,7 @@ async def create_message(
mentions.append(found_id)

if actual_guild_id and (
allowed_mentions is None
or "roles" in allowed_mentions.get("parse", [])
or allowed_mentions.get("roles")
allowed_mentions is None or "roles" in allowed_mentions.get("parse", []) or allowed_mentions.get("roles")
):
guild_roles = await app.db.fetch(
"""
Expand Down Expand Up @@ -297,8 +285,7 @@ async def create_message(

if (
data.get("message_reference")
and not data.get("flags", 0) & MessageFlags.is_crosspost
== MessageFlags.is_crosspost
and not data.get("flags", 0) & MessageFlags.is_crosspost == MessageFlags.is_crosspost
and (allowed_mentions is None or allowed_mentions.get("replied_user", False))
):
reply_id = await app.db.fetchval(
Expand Down Expand Up @@ -330,9 +317,7 @@ async def create_message(
data["tts"],
data["everyone_mention"],
data["nonce"],
MessageType.DEFAULT.value
if not data.get("message_reference")
else MessageType.REPLY.value,
MessageType.DEFAULT.value if not data.get("message_reference") else MessageType.REPLY.value,
data.get("flags") or 0,
data.get("embeds") or [],
data.get("message_reference") or None,
Expand Down Expand Up @@ -464,25 +449,16 @@ async def _create_message(channel_id):
# guild_id is the dm's peer_id
await dm_pre_check(user_id, channel_id, guild_id)

can_everyone = (
await channel_perm_check(user_id, channel_id, "mention_everyone", False)
and ctype != ChannelType.DM
)
can_everyone = await channel_perm_check(user_id, channel_id, "mention_everyone", False) and ctype != ChannelType.DM

mentions_everyone = ("@everyone" in j["content"]) and can_everyone
mentions_here = ("@here" in j["content"]) and can_everyone

is_tts = j.get("tts", False) and await channel_perm_check(
user_id, channel_id, "send_tts_messages", False
)
is_tts = j.get("tts", False) and await channel_perm_check(user_id, channel_id, "send_tts_messages", False)

embeds = [
await fill_embed(embed)
for embed in (
(j.get("embeds") or []) or [j["embed"]]
if "embed" in j and j["embed"]
else []
)
for embed in ((j.get("embeds") or []) or [j["embed"]] if "embed" in j and j["embed"] else [])
]
message_id = await create_message(
channel_id,
Expand All @@ -500,10 +476,7 @@ async def _create_message(channel_id):
"allowed_mentions": j.get("allowed_mentions"),
"sticker_ids": j.get("sticker_ids"),
"flags": MessageFlags.suppress_embeds
if (
j.get("flags", 0) & MessageFlags.suppress_embeds
== MessageFlags.suppress_embeds
)
if (j.get("flags", 0) & MessageFlags.suppress_embeds == MessageFlags.suppress_embeds)
else 0,
},
recipient_id=guild_id if ctype == ChannelType.DM else None,
Expand Down Expand Up @@ -541,9 +514,7 @@ async def _create_message(channel_id):
)

if ctype not in (ChannelType.DM, ChannelType.GROUP_DM):
await msg_guild_text_mentions(
payload, guild_id, mentions_everyone, mentions_here
)
await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here)

return jsonify(message_view(payload))

Expand Down Expand Up @@ -578,9 +549,7 @@ async def edit_message(channel_id, message_id):
old_flags = MessageFlags.from_int(old_message.get("flags", 0))
new_flags = MessageFlags.from_int(int(j["flags"]))

toggle_flag(
old_flags, MessageFlags.suppress_embeds, new_flags.is_suppress_embeds
)
toggle_flag(old_flags, MessageFlags.suppress_embeds, new_flags.is_suppress_embeds)

if old_flags.value != old_message["flags"]:
await app.db.execute(
Expand Down Expand Up @@ -610,11 +579,7 @@ async def edit_message(channel_id, message_id):
updated = True
embeds = [
await fill_embed(embed)
for embed in (
(j.get("embeds") or []) or [j["embed"]]
if "embed" in j and j["embed"]
else []
)
for embed in ((j.get("embeds") or []) or [j["embed"]] if "embed" in j and j["embed"] else [])
]
await app.db.execute(
"""
Expand All @@ -637,9 +602,7 @@ async def edit_message(channel_id, message_id):
"channel_id": channel_id,
"content": j["content"],
"embeds": old_message["embeds"],
"flags": flags
if flags is not None
else old_message.get("flags", 0),
"flags": flags if flags is not None else old_message.get("flags", 0),
},
delay=0.2,
)
Expand All @@ -663,9 +626,7 @@ async def edit_message(channel_id, message_id):
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", message))

# now we handle crossposted messages
if updated and (
message.get("flags", 0) & MessageFlags.crossposted == MessageFlags.crossposted
):
if updated and (message.get("flags", 0) & MessageFlags.crossposted == MessageFlags.crossposted):
async with app.db.acquire() as conn:
await pg_set_json(conn)

Expand Down Expand Up @@ -708,17 +669,13 @@ async def edit_message(channel_id, message_id):
"id": id,
"channel_id": row["channel_id"],
"content": j["content"],
"embeds": embeds
if embeds is not None
else old_message["embeds"],
"embeds": embeds if embeds is not None else old_message["embeds"],
},
delay=0.2,
)

message = await app.storage.get_message(id)
await app.dispatcher.channel.dispatch(
row["channel_id"], ("MESSAGE_UPDATE", message)
)
await app.dispatcher.channel.dispatch(row["channel_id"], ("MESSAGE_UPDATE", message))

return jsonify(message_view(message))

Expand Down Expand Up @@ -778,9 +735,7 @@ async def _del_msg_fkeys(message_id: int, channel_id: int):
)

message = await app.storage.get_message(id)
await app.dispatcher.channel.dispatch(
channel_id, ("MESSAGE_UPDATE", message)
)
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", message))

# take the chance and delete all the data from the other tables too!

Expand Down
8 changes: 2 additions & 6 deletions litecord/blueprints/channel/pins.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ async def _dispatch_pins_update(channel_id: int) -> None:
channel_id,
)

timestamp = (
app.winter_factory.to_datetime(message_id) if message_id is not None else None
)
timestamp = app.winter_factory.to_datetime(message_id) if message_id is not None else None
await app.dispatcher.channel.dispatch(
channel_id,
(
Expand Down Expand Up @@ -114,9 +112,7 @@ async def add_pin(channel_id, message_id):

await _dispatch_pins_update(channel_id)

await send_sys_message(
channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
)
await send_sys_message(channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id)

return "", 204

Expand Down
Loading