Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 20 additions & 133 deletions lib/phoenix/endpoint.ex
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ defmodule Phoenix.Endpoint do

# Channels

@doc """
Provide module specs for sockets to start and drain with the endpoint.

See `Phoenix.Socket.Transport` for more information.
"""
@callback sockets() :: [{module(), keyword()} | module()]

@doc """
Subscribes the caller to the given topic.

Expand Down Expand Up @@ -406,6 +413,8 @@ defmodule Phoenix.Endpoint do
"""
@callback local_broadcast_from(from :: pid, topic, event, msg) :: :ok

@optional_callbacks sockets: 0

@doc false
defmacro __using__(opts) do
quote do
Expand All @@ -424,9 +433,11 @@ defmodule Phoenix.Endpoint do

# Compile-time configuration checking
# This ensures that if a compile-time configuration is overwritten at runtime the application won't boot.
var!(code_reloading?) = Application.compile_env(@otp_app, [__MODULE__, :code_reloader], false)
var!(debug_errors?) = Application.compile_env(@otp_app, [__MODULE__, :debug_errors], false)
var!(force_ssl) = Application.compile_env(@otp_app, [__MODULE__, :force_ssl])
var!(code_reloading?) =
Application.compile_env(@otp_app, [__MODULE__, :code_reloader], false)

var!(debug_errors?) = Application.compile_env(@otp_app, [__MODULE__, :debug_errors], false)
var!(force_ssl) = Application.compile_env(@otp_app, [__MODULE__, :force_ssl])

# Avoid unused variable warnings
_ = var!(code_reloading?)
Expand Down Expand Up @@ -480,14 +491,13 @@ defmodule Phoenix.Endpoint do
use Plug.Builder, init_mode: Phoenix.plug_init_mode()
import Phoenix.Endpoint

Module.register_attribute(__MODULE__, :phoenix_sockets, accumulate: true)

if force_ssl = Phoenix.Endpoint.__force_ssl__(__MODULE__, var!(force_ssl)) do
plug Plug.SSL, force_ssl
end

if var!(debug_errors?) do
logo = ""
logo =
""

use Plug.Debugger,
otp_app: @otp_app,
Expand All @@ -502,8 +512,6 @@ defmodule Phoenix.Endpoint do
]
end

plug :socket_dispatch

# Compile after the debugger so we properly wrap it.
@before_compile Phoenix.Endpoint
end
Expand Down Expand Up @@ -640,19 +648,7 @@ defmodule Phoenix.Endpoint do
end

@doc false
defmacro __before_compile__(%{module: module}) do
sockets = Module.get_attribute(module, :phoenix_sockets)

dispatches =
for {path, socket, socket_opts} <- sockets,
{path, plug, conn_ast, plug_opts} <- socket_paths(module, path, socket, socket_opts) do
quote do
defp do_socket_dispatch(unquote(path), conn) do
halt(unquote(plug).call(unquote(conn_ast), unquote(Macro.escape(plug_opts))))
end
end
end

defmacro __before_compile__(_env) do
quote do
defoverridable call: 2

Expand Down Expand Up @@ -687,118 +683,9 @@ defmodule Phoenix.Endpoint do
)
end
end

@doc false
def __sockets__, do: unquote(Macro.escape(sockets))

@doc false
def socket_dispatch(%{path_info: path} = conn, _opts), do: do_socket_dispatch(path, conn)
unquote(dispatches)
defp do_socket_dispatch(_path, conn), do: conn
end
end

defp socket_paths(endpoint, path, socket, opts) do
paths = []

common_config = [
:path,
:serializer,
:transport_log,
:check_origin,
:check_csrf,
:code_reloader,
:connect_info,
:auth_token
]

websocket =
opts
|> Keyword.get(:websocket, true)
|> maybe_validate_keys(
common_config ++
[
:timeout,
:max_frame_size,
:fullsweep_after,
:compress,
:subprotocols,
:error_handler
]
)

longpoll =
opts
|> Keyword.get(:longpoll, true)
|> maybe_validate_keys(
common_config ++
[
:window_ms,
:pubsub_timeout_ms,
:crypto
]
)

paths =
if websocket do
websocket = put_auth_token(websocket, opts[:auth_token])
config = Phoenix.Socket.Transport.load_config(websocket, Phoenix.Transports.WebSocket)
plug_init = {endpoint, socket, config}
{conn_ast, match_path} = socket_path(path, config)
[{match_path, Phoenix.Transports.WebSocket, conn_ast, plug_init} | paths]
else
paths
end

paths =
if longpoll do
longpoll = put_auth_token(longpoll, opts[:auth_token])
config = Phoenix.Socket.Transport.load_config(longpoll, Phoenix.Transports.LongPoll)
plug_init = {endpoint, socket, config}
{conn_ast, match_path} = socket_path(path, config)
[{match_path, Phoenix.Transports.LongPoll, conn_ast, plug_init} | paths]
else
paths
end

paths
end

defp put_auth_token(true, enabled), do: [auth_token: enabled]
defp put_auth_token(opts, enabled), do: Keyword.put(opts, :auth_token, enabled)

defp socket_path(path, config) do
end_path_fragment = Keyword.fetch!(config, :path)

{vars, path} =
String.split(path <> "/" <> end_path_fragment, "/", trim: true)
|> Enum.join("/")
|> Plug.Router.Utils.build_path_match()

conn_ast =
if vars == [] do
quote do
conn
end
else
params =
for var <- vars,
param = Atom.to_string(var),
not match?("_" <> _, param),
do: {param, Macro.var(var, nil)}

quote do
params = %{unquote_splicing(params)}
%{conn | path_params: params, params: params}
end
end

{conn_ast, path}
end

defp maybe_validate_keys(opts, keys) when is_list(opts), do: Keyword.validate!(opts, keys)
defp maybe_validate_keys(other, _), do: other

## API

@doc """
Expand Down Expand Up @@ -1060,11 +947,11 @@ defmodule Phoenix.Endpoint do
by `Phoenix.Token`. By default tokens are valid for 2 weeks

"""
defmacro socket(path, module, opts \\ []) do
module = Macro.expand(module, %{__CALLER__ | function: {:socket_dispatch, 2}})
defmacro socket(_path, _module, _opts \\ []) do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We most likely want to still support the old way as usual

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I removed stuff in here mostly to make sure I catch everything. Making this propely backwards compatibility preserving will make sense once other questions are dealt with.

# module = Macro.expand(module, %{__CALLER__ | function: {:socket_dispatch, 2}})

quote do
@phoenix_sockets {unquote(path), unquote(module), unquote(opts)}
# @phoenix_sockets {unquote(path), unquote(module), unquote(opts)}
end
end

Expand Down
23 changes: 18 additions & 5 deletions lib/phoenix/endpoint/supervisor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@
config_children(mod, secret_conf, default_conf) ++
warmup_children(mod) ++
pubsub_children(mod, conf) ++
socket_children(mod, conf, :child_spec) ++
socket_children(mod, :child_spec) ++
server_children(mod, conf, server?) ++
socket_children(mod, conf, :drainer_spec) ++
socket_children(mod, :drainer_spec) ++
watcher_children(mod, conf, server?)

Supervisor.init(children, strategy: :one_for_one)
end

Expand Down Expand Up @@ -117,15 +118,27 @@
end
end

defp socket_children(endpoint, conf, fun) do
for {_, socket, opts} <- Enum.uniq_by(endpoint.__sockets__(), &elem(&1, 1)),
_ = check_origin_or_csrf_checked!(conf, opts),
defp socket_children(endpoint, fun) do
sockets =
if function_exported?(endpoint, :sockets, 0) do
endpoint.sockets()
else
[]
end

for {socket, opts} <- Enum.map(sockets, &normalize_module_spec/1),
# TODO is this the correct place for this?
# Needs to know transport specific config
# _ = check_origin_or_csrf_checked!(conf, opts),
spec = apply_or_ignore(socket, fun, [[endpoint: endpoint] ++ opts]),
spec != :ignore do
spec
end
end

defp normalize_module_spec(module) when is_atom(module), do: {module, []}
defp normalize_module_spec({module, kw}) when is_atom(module) and is_list(kw), do: {module, kw}

defp apply_or_ignore(socket, fun, args) do
# If the module is not loaded, we want to invoke and crash
if not Code.ensure_loaded?(socket) or function_exported?(socket, fun, length(args)) do
Expand All @@ -135,7 +148,7 @@
end
end

defp check_origin_or_csrf_checked!(endpoint_conf, socket_opts) do

Check warning on line 151 in lib/phoenix/endpoint/supervisor.ex

View workflow job for this annotation

GitHub Actions / mix test (OTP 25.3.2.9 | Elixir 1.15.8)

function check_origin_or_csrf_checked!/2 is unused
check_origin = endpoint_conf[:check_origin]

for {transport, transport_opts} <- socket_opts, is_list(transport_opts) do
Expand Down
66 changes: 66 additions & 0 deletions lib/phoenix/socket/router.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
defmodule Phoenix.Socket.Router do
@moduledoc false

defmacro socket(path, user_socket, opts) do
websocket = Keyword.get(opts, :websocket, true)
longpoll = Keyword.get(opts, :longpoll, true)

ws_quote =
if websocket do
websocket = put_auth_token(websocket, opts[:auth_token])
{end_segment, websocket} = Keyword.pop(websocket, :path, "/websocket")
path = Path.join(path, end_segment)

quote do
websocket unquote(path), unquote(user_socket), unquote(websocket)
end
else
[]
end

lp_quote =
if longpoll do
longpoll = put_auth_token(longpoll, opts[:auth_token])
{end_segment, longpoll} = Keyword.pop(longpoll, :path, "/longpoll")
path = Path.join(path, end_segment)

quote do
longpoll unquote(path), unquote(user_socket), unquote(longpoll)
end
else
[]
end

quote do
unquote(ws_quote)
unquote(lp_quote)
end
end

defmacro websocket(path, user_socket, opts \\ []) do
quote do
match :*,
unquote(path),
Phoenix.Transports.WebSocket,
[
{:user_socket, unquote(user_socket)}
| unquote(opts)
]
end
end

defmacro longpoll(path, user_socket, opts \\ []) do
quote do
match :*,
unquote(path),
Phoenix.Transports.LongPoll,
[
{:user_socket, unquote(user_socket)}
| unquote(opts)
]
end
end

defp put_auth_token(true, enabled), do: [auth_token: enabled]
defp put_auth_token(opts, enabled), do: Keyword.put(opts, :auth_token, enabled)
end
Loading
Loading