diff --git a/src/escalus_session.erl b/src/escalus_session.erl index 90d5247..0708b01 100644 --- a/src/escalus_session.erl +++ b/src/escalus_session.erl @@ -78,16 +78,19 @@ authenticate(Client = #client{props = Props}) -> %% but as a default we use plain, as it incurrs lower load and better logs (no hashing) %% for common setups. If a different mechanism is required then it should be configured on the %% user specification. - {M, F} = proplists:get_value(auth, Props, {escalus_auth, auth_plain}), - PropsAfterAuth = case apply(M, F, [Client, Props]) of - ok -> Props; - {ok, P} when is_list(P) -> P - end, + PropsAfterAuth = apply_auth_method(Client, Props), escalus_connection:reset_parser(Client), Client1 = escalus_session:start_stream(Client#client{props = PropsAfterAuth}), escalus_session:stream_features(Client1, []), Client1. +apply_auth_method(Client, Props) -> + Fun = proplists:get_value(auth, Props, fun escalus_auth:auth_plain/2), + case apply(Fun, [Client, Props]) of + ok -> Props; + {ok, P} when is_list(P) -> P + end. + -spec bind(client()) -> client(). bind(Client = #client{props = Props0}) -> Resource = proplists:get_value(resource, Props0, ?DEFAULT_RESOURCE), diff --git a/src/escalus_users.erl b/src/escalus_users.erl index 6efa64b..51a6f38 100644 --- a/src/escalus_users.erl +++ b/src/escalus_users.erl @@ -154,46 +154,50 @@ get_server(Config, User) -> get_wspath(Config, User) -> get_user_option(wspath, User, escalus_wspath, Config, undefined). --spec get_auth_method(escalus:config(), user()) -> {module(), atom()}. +-spec get_auth_method(escalus:config(), user()) -> + fun((escalus_connection:client(), escalus_users:user_spec()) -> ok | {ok, escalus_users:user_spec()}). get_auth_method(Config, User) -> AuthMethod = get_user_option(auth_method, User, escalus_auth_method, Config, <<"PLAIN">>), get_auth_method(AuthMethod). --spec get_auth_method(binary() | {module(), atom()}) -> {module(), atom()}. +-spec get_auth_method(binary() | {module(), atom()}) -> + fun((escalus_connection:client(), escalus_users:user_spec()) -> ok | {ok, escalus_users:user_spec()}). get_auth_method(<<"PLAIN">>) -> - {escalus_auth, auth_plain}; + fun escalus_auth:auth_plain/2; get_auth_method(<<"DIGEST-MD5">>) -> - {escalus_auth, auth_digest_md5}; + fun escalus_auth:auth_digest_md5/2; get_auth_method(<<"SASL-ANON">>) -> - {escalus_auth, auth_sasl_anon}; + fun escalus_auth:auth_sasl_anon/2; %% SCRAM Regular get_auth_method(<<"SCRAM-SHA-1">>) -> - {escalus_auth, auth_sasl_scram_sha1}; + fun escalus_auth:auth_sasl_scram_sha1/2; get_auth_method(<<"SCRAM-SHA-224">>) -> - {escalus_auth, auth_sasl_scram_sha224}; + fun escalus_auth:auth_sasl_scram_sha224/2; get_auth_method(<<"SCRAM-SHA-256">>) -> - {escalus_auth, auth_sasl_scram_sha256}; + fun escalus_auth:auth_sasl_scram_sha256/2; get_auth_method(<<"SCRAM-SHA-384">>) -> - {escalus_auth, auth_sasl_scram_sha384}; + fun escalus_auth:auth_sasl_scram_sha384/2; get_auth_method(<<"SCRAM-SHA-512">>) -> - {escalus_auth, auth_sasl_scram_sha512}; + fun escalus_auth:auth_sasl_scram_sha512/2; %% SCRAM PLUS get_auth_method(<<"SCRAM-SHA-1-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha1_plus}; + fun escalus_auth:auth_sasl_scram_sha1_plus/2; get_auth_method(<<"SCRAM-SHA-224-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha224_plus}; + fun escalus_auth:auth_sasl_scram_sha224_plus/2; get_auth_method(<<"SCRAM-SHA-256-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha256_plus}; + fun escalus_auth:auth_sasl_scram_sha256_plus/2; get_auth_method(<<"SCRAM-SHA-384-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha384_plus}; + fun escalus_auth:auth_sasl_scram_sha384_plus/2; get_auth_method(<<"SCRAM-SHA-512-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha512_plus}; + fun escalus_auth:auth_sasl_scram_sha512_plus/2; get_auth_method(<<"X-OAUTH">>) -> - {escalus_auth, auth_sasl_oauth}; + fun escalus_auth:auth_sasl_oauth/2; get_auth_method({Mod, Fun}) when is_atom(Mod), is_atom(Fun) -> - {Mod, Fun}. + fun Mod:Fun/2; +get_auth_method(Fun) when is_function(Fun, 2) -> + Fun. -spec get_usp(escalus:config(), user()) -> [binary() | xmpp_domain()]. get_usp(Config, User) ->