Skip to content

Commit

Permalink
Invoke UserSocket connect/2 and merge socket assigns
Browse files Browse the repository at this point in the history
Send 403 from WS upgrade and LP start if non-ok connection.
Update socket params to be serialized and sent on conn open.
  • Loading branch information
chrismccord committed Jul 13, 2015
1 parent 4f18811 commit 1f5b8e4
Show file tree
Hide file tree
Showing 14 changed files with 235 additions and 109 deletions.
17 changes: 1 addition & 16 deletions lib/phoenix/channel.ex
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ defmodule Phoenix.Channel do
quote do
@behaviour unquote(__MODULE__)
import unquote(__MODULE__)
import Phoenix.Socket, only: [assign: 3]

def handle_in(_event, _message, socket) do
{:noreply, socket}
Expand Down Expand Up @@ -314,20 +315,4 @@ defmodule Phoenix.Channel do
end
"""
end

@doc """
Adds key/value pair to socket assigns.
## Examples
iex> socket.assigns[:token]
nil
iex> socket = assign(socket, :token, "bar")
iex> socket.assigns[:token]
"bar"
"""
def assign(socket = %Socket{}, key, value) do
update_in socket.assigns, &Map.put(&1, key, value)
end
end
15 changes: 8 additions & 7 deletions lib/phoenix/channel/transport.ex
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ defmodule Phoenix.Channel.Transport do
* `{:error, reason}` - Unauthorized or unmatched dispatch
"""
def dispatch(%Message{} = msg, sockets, transport_pid, socket_handler, endpoint, transport) do
def dispatch(%Message{} = msg, sockets, transport_pid, socket_handler, socket, endpoint, transport) do
sockets
|> HashDict.get(msg.topic)
|> dispatch(msg, transport_pid, socket_handler, endpoint, transport)
|> dispatch(msg, transport_pid, socket_handler, socket, endpoint, transport)
end

@doc """
Expand All @@ -84,15 +84,16 @@ defmodule Phoenix.Channel.Transport do
The server will respond to heartbeats with the same message
"""
def dispatch(_, %{ref: ref, topic: "phoenix", event: "heartbeat"}, transport_pid, _socket_handler, _pubsub_server, _transport) do
def dispatch(_, %{ref: ref, topic: "phoenix", event: "heartbeat"}, transport_pid, _socket_handler, _socket, _pubsub_server, _transport) do
reply(transport_pid, ref, "phoenix", %{status: :ok, response: %{}})
:ok
end
def dispatch(nil, %{event: "phx_join"} = msg, transport_pid, socket_handler, endpoint, transport) do
def dispatch(nil, %{event: "phx_join"} = msg, transport_pid, socket_handler, base_socket, endpoint, transport) do
case socket_handler.channel_for_topic(msg.topic, transport) do
nil -> log_ignore(msg.topic, socket_handler)
channel ->
socket = %Socket{transport_pid: transport_pid,
assigns: base_socket.assigns,
endpoint: endpoint,
pubsub_server: endpoint.__pubsub_server__(),
topic: msg.topic,
Expand All @@ -119,14 +120,14 @@ defmodule Phoenix.Channel.Transport do
end
end
end
def dispatch(nil, msg, _transport_pid, socket_handler, _pubsub_server, _transport) do
def dispatch(nil, msg, _transport_pid, socket_handler, _socket, _pubsub_server, _transport) do
log_ignore(msg.topic, socket_handler)
end
def dispatch(socket_pid, %{event: "phx_leave", ref: ref}, _transport_pid, _socket_handler, _pubsub_server, _transport) do
def dispatch(socket_pid, %{event: "phx_leave", ref: ref}, _transport_pid, _socket_handler, _socket, _pubsub_server, _transport) do
Phoenix.Channel.Server.leave(socket_pid, ref)
:ok
end
def dispatch(socket_pid, msg, _transport_pid, _socket_handler, _pubsub_server, _transport) do
def dispatch(socket_pid, msg, _transport_pid, _socket_handler, _socket, _pubsub_server, _transport) do
send(socket_pid, msg)
:ok
end
Expand Down
6 changes: 4 additions & 2 deletions lib/phoenix/endpoint.ex
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,14 @@ defmodule Phoenix.Endpoint do
plugs = Module.get_attribute(env.module, :plugs)
{conn, body} = Plug.Builder.compile(env, plugs, [])

socket_intercepts = for {path, mod} <- sockets do
socket_intercepts = for {path, module} <- sockets do
path_info = Plug.Router.Utils.split(path)

quote do
defp phoenix_pipeline(%Plug.Conn{path_info: unquote(path_info)} = conn) do
Phoenix.Socket.Router.call(conn, Phoenix.Socket.Router.init(unquote(mod)))
conn
|> Plug.Conn.put_private(:phoenix_socket_handler, unquote(module))
|> Phoenix.Socket.Router.call(Phoenix.Socket.Router.init([]))
end
end
end
Expand Down
20 changes: 18 additions & 2 deletions lib/phoenix/socket.ex
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ defmodule Phoenix.Socket do
alias Phoenix.Socket
alias Phoenix.Socket.Helpers

defcallback connect(params :: map) :: {:ok, socket_assigns :: map} |
{:error, reason :: map}
defcallback connect(params :: map) :: {:ok, Socket.t} |
:error

defcallback id(socket_assigns :: map) :: String.t

Expand Down Expand Up @@ -92,6 +92,22 @@ defmodule Phoenix.Socket do
end
end

@doc """
Adds key/value pair to socket assigns.
## Examples
iex> socket.assigns[:token]
nil
iex> socket = assign(socket, :token, "bar")
iex> socket.assigns[:token]
"bar"
"""
def assign(socket = %Socket{}, key, value) do
update_in socket.assigns, &Map.put(&1, key, value)
end

@doc """
Defines a channel matching the given topic and transports.
Expand Down
24 changes: 13 additions & 11 deletions lib/phoenix/socket/router.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,16 @@ defmodule Phoenix.Socket.Router do
# Routes WebSocket and LongPoller requests.
@moduledoc false

import Plug.Conn
require Logger
use Plug.Builder
alias Phoenix.Transports.WebSocket
alias Phoenix.Transports.LongPoller

def init(opts), do: opts
plug Plug.Logger
plug :fetch_query_params
plug :dispatch

def call(conn, module) do
conn = conn |> fetch_query_params() |> put_private(:phoenix_socket, module)
transport = case conn.query_params["transport"] do
"poll" -> LongPoller
_ -> WebSocket
end

case {conn.method, transport} do
def dispatch(conn, _) do
case {conn.method, transport(conn)} do
{"GET", WebSocket} -> WebSocket.call(conn, [])
{"POST", WebSocket} -> WebSocket.call(conn, [])
{"OPTIONS", LongPoller} -> LongPoller.call(conn, :options)
Expand All @@ -25,4 +20,11 @@ defmodule Phoenix.Socket.Router do
_ -> conn |> send_resp(:bad_request, "") |> halt()
end
end

defp transport(conn) do
case conn.query_params["transport"] do
"poll" -> LongPoller
_ -> WebSocket
end
end
end
20 changes: 13 additions & 7 deletions lib/phoenix/transports/long_poller.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,17 @@ defmodule Phoenix.Transports.LongPoller do
end

defp new_session(conn) do
{conn, priv_topic, sig, _server_pid} = start_session(conn)
case conn.private.phoenix_socket_handler.connect(conn.params, %Phoenix.Socket{}) do
{:ok, socket} ->
{conn, priv_topic, sig, _server_pid} = start_session(conn, socket)

conn
|> put_status(:gone)
|> status_json(%{token: priv_topic, sig: sig})
conn
|> put_status(:gone)
|> status_json(%{token: priv_topic, sig: sig})

:error ->
conn |> put_status(:forbidden) |> status_json(%{})
end
end

@doc """
Expand Down Expand Up @@ -109,14 +115,14 @@ defmodule Phoenix.Transports.LongPoller do
@doc """
Starts the `Phoenix.LongPoller.Server` and stores the serialized pid in the session.
"""
def start_session(conn) do
socket_handler = Map.fetch!(conn.private, :phoenix_socket)
def start_session(conn, socket) do
socket_handler = conn.private.phoenix_socket_handler
priv_topic =
"phx:lp:"
|> Kernel.<>(Base.encode64(:crypto.strong_rand_bytes(16)))
|> Kernel.<>(:os.timestamp() |> Tuple.to_list |> Enum.join(""))

child = [socket_handler, timeout_window_ms(conn), priv_topic, endpoint_module(conn)]
child = [socket_handler, socket, timeout_window_ms(conn), priv_topic, endpoint_module(conn)]
{:ok, server_pid} = Supervisor.start_child(LongPoller.Supervisor, child)

{conn, priv_topic, sign(conn, priv_topic), server_pid}
Expand Down
11 changes: 7 additions & 4 deletions lib/phoenix/transports/long_poller/server.ex
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,24 @@ defmodule Phoenix.Transports.LongPoller.Server do
Starts the Server.
* `socket_handler` - The socket handler module, ie. `MyApp.UserSocket`
* `socket` - The `%Phoenix.Socket{}` struct returend from `connect/2` of the
socket handler.
* `window_ms` - The longpoll session timeout, in milliseconds
If the server receives no message within `window_ms`, it terminates and
clients are responsible for opening a new session.
"""
def start_link(socket_handler, window_ms, priv_topic, endpoint) do
GenServer.start_link(__MODULE__, [socket_handler, window_ms, priv_topic, endpoint])
def start_link(socket_handler, socket, window_ms, priv_topic, endpoint) do
GenServer.start_link(__MODULE__, [socket_handler, socket, window_ms, priv_topic, endpoint])
end

@doc false
def init([socket_handler, window_ms, priv_topic, endpoint]) do
def init([socket_handler, socket, window_ms, priv_topic, endpoint]) do
Process.flag(:trap_exit, true)

state = %{buffer: [],
socket_handler: socket_handler,
socket: socket,
sockets: HashDict.new,
sockets_inverse: HashDict.new,
window_ms: trunc(window_ms * 1.5),
Expand All @@ -69,7 +72,7 @@ defmodule Phoenix.Transports.LongPoller.Server do
"""
def handle_info({:dispatch, msg, ref}, state) do
msg
|> Transport.dispatch(state.sockets, self, state.socket_handler, state.endpoint, LongPoller)
|> Transport.dispatch(state.sockets, self, state.socket_handler, state.socket, state.endpoint, LongPoller)
|> case do
{:ok, socket_pid} ->
:ok = broadcast_from(state, {:ok, :dispatch, ref})
Expand Down
25 changes: 18 additions & 7 deletions lib/phoenix/transports/websocket.ex
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,30 @@ defmodule Phoenix.Transports.WebSocket do
plug :upgrade

def upgrade(%Plug.Conn{method: "GET"} = conn, _) do
put_private(conn, :phoenix_upgrade, {:websocket, __MODULE__}) |> halt
case conn.private.phoenix_socket_handler.connect(conn.params, %Phoenix.Socket{}) do
{:ok, socket} ->
conn
|> put_private(:phoenix_upgrade, {:websocket, __MODULE__})
|> put_private(:phoenix_socket, socket)
|> halt()
:error ->
conn |> send_resp(403, "") |> halt()
end
end

@doc """
Handles initalization of the websocket.
"""
def ws_init(conn) do
Process.flag(:trap_exit, true)
endpoint = endpoint_module(conn)
serializer = Dict.fetch!(endpoint.config(:transports), :websocket_serializer)
timeout = Dict.fetch!(endpoint.config(:transports), :websocket_timeout)

{:ok, %{socket_handler: Map.fetch!(conn.private, :phoenix_socket),
endpoint = endpoint_module(conn)
serializer = Dict.fetch!(endpoint.config(:transports), :websocket_serializer)
timeout = Dict.fetch!(endpoint.config(:transports), :websocket_timeout)
socket_handler = conn.private.phoenix_socket_handler
socket = conn.private.phoenix_socket

{:ok, %{socket_handler: socket_handler,
socket: socket,
endpoint: endpoint,
sockets: HashDict.new,
sockets_inverse: HashDict.new,
Expand All @@ -63,7 +74,7 @@ defmodule Phoenix.Transports.WebSocket do
def ws_handle(opcode, payload, state) do
msg = state.serializer.decode!(payload, opcode)

case Transport.dispatch(msg, state.sockets, self, state.socket_handler, state.endpoint, __MODULE__) do
case Transport.dispatch(msg, state.sockets, self, state.socket_handler, state.socket, state.endpoint, __MODULE__) do
{:ok, socket_pid} ->
{:ok, put(state, msg.topic, socket_pid)}
:ok ->
Expand Down
Loading

0 comments on commit 1f5b8e4

Please sign in to comment.