diff --git a/rebar.config b/rebar.config index 3ecc346..b2e6741 100644 --- a/rebar.config +++ b/rebar.config @@ -12,7 +12,6 @@ {uuid, "2.0.7", {pkg, uuid_erl}}, {gun, "2.1.0"}, {worker_pool, "6.4.0"}, - {fast_tls, "1.1.21"}, {fast_scram, "0.6.1"} ]}. diff --git a/rebar.lock b/rebar.lock index 7189936..fa9442e 100644 --- a/rebar.lock +++ b/rebar.lock @@ -4,10 +4,8 @@ {<<"exml">>,{pkg,<<"hexml">>,<<"3.4.1">>},0}, {<<"fast_pbkdf2">>,{pkg,<<"fast_pbkdf2">>,<<"1.0.6">>},1}, {<<"fast_scram">>,{pkg,<<"fast_scram">>,<<"0.6.1">>},0}, - {<<"fast_tls">>,{pkg,<<"fast_tls">>,<<"1.1.21">>},0}, {<<"gun">>,{pkg,<<"gun">>,<<"2.1.0">>},0}, {<<"meck">>,{pkg,<<"meck">>,<<"1.0.0">>},0}, - {<<"p1_utils">>,{pkg,<<"p1_utils">>,<<"1.0.26">>},1}, {<<"quickrand">>,{pkg,<<"quickrand">>,<<"2.0.7">>},1}, {<<"uuid">>,{pkg,<<"uuid_erl">>,<<"2.0.7">>},0}, {<<"worker_pool">>,{pkg,<<"worker_pool">>,<<"6.4.0">>},0}]}. @@ -18,10 +16,8 @@ {<<"exml">>, <<"9581FE6512D9772C61BBE611CD4A8E5BB90B4D4481275325EC520F7A931A9393">>}, {<<"fast_pbkdf2">>, <<"199BCEC73A1A246941E9465D3DC41052953B638128841ED24B29ED03CF70AF27">>}, {<<"fast_scram">>, <<"BEEADB03D774640F0671681759CE53B2FF33CB58C86FD9BF2A793E2FC1ED0F5D">>}, - {<<"fast_tls">>, <<"65D7D547A09EEFB37A1C0D04D8601FAC4F3E6E2C1EDE859A7787081670F9648D">>}, {<<"gun">>, <<"B4E4CBBF3026D21981C447E9E7CA856766046EFF693720BA43114D7F5DE36E87">>}, {<<"meck">>, <<"24676CB6EE6951530093A93EDCD410CFE4CB59FE89444B875D35C9D3909A15D0">>}, - {<<"p1_utils">>, <<"67B0C4AC9FA3BA3EF563B31AA111B0A004439A37FAC85E027F1C3617E1C7EC6C">>}, {<<"quickrand">>, <<"D2BD76676A446E6A058D678444B7FDA1387B813710D1AF6D6E29BB92186C8820">>}, {<<"uuid">>, <<"B2078D2CC814F53AFA52D36C91E08962C7E7373585C623F4C0EA6DFB04B2AF94">>}, {<<"worker_pool">>, <<"0347B805A8E5804B5676A9885FB3B9B6C1627099C449C3C67C0E8E6AF79E9AA6">>}]}, @@ -31,10 +27,8 @@ {<<"exml">>, <<"D8E7894E2544402B4986EEB2443C15B51B14F686266F091DBF2777D1D99A2FA2">>}, {<<"fast_pbkdf2">>, <<"35EEC22629AAA739915843C7B7DE0D84657D1ECE972D8BBC86368747E9C14012">>}, {<<"fast_scram">>, <<"FE0650A309FDF97C75E1EA812CCFB40EB464ECAFD3783E83AA17C7F572EDAB0B">>}, - {<<"fast_tls">>, <<"131542913937025E48CD80AA81F00359686D5501B75621E72026A87B5229505B">>}, {<<"gun">>, <<"52FC7FC246BFC3B00E01AEA1C2854C70A366348574AB50C57DFE796D24A0101D">>}, {<<"meck">>, <<"680A9BCFE52764350BEB9FB0335FB75FEE8E7329821416CEE0A19FEC35433882">>}, - {<<"p1_utils">>, <<"D0379E8C1156B98BD64F8129C1DE022FCCA4F2FDB7486CE73BF0ED2C3376B04C">>}, {<<"quickrand">>, <<"B8ACBF89A224BC217C3070CA8BEBC6EB236DBE7F9767993B274084EA044D35F0">>}, {<<"uuid">>, <<"4E4C5CA3461DC47C5E157ED42AA3981A053B7A186792AF972A27B14A9489324E">>}, {<<"worker_pool">>, <<"59946FBCE1D331CDEB153EDD36A823DC1AAB4C2482662582B983C9C90EBC3461">>}]} diff --git a/src/escalus.app.src b/src/escalus.app.src index e43c97e..6530a26 100644 --- a/src/escalus.app.src +++ b/src/escalus.app.src @@ -12,7 +12,6 @@ meck, bbmustache, uuid, - fast_tls, fast_scram, worker_pool ]}, diff --git a/src/escalus_auth.erl b/src/escalus_auth.erl index ab850d5..bb930a5 100644 --- a/src/escalus_auth.erl +++ b/src/escalus_auth.erl @@ -23,14 +23,15 @@ auth_sasl_oauth/2]). %% Useful helpers for writing own mechanisms --export([get_challenge/2, +-export([auth_sasl_scram/3, + get_challenge/2, wait_for_success/2]). %% Some shorthands -type client() :: escalus_connection:client(). -type user_spec() :: escalus_users:user_spec(). -type hash_type() :: fast_scram:sha_type(). --type plus_variant() :: none | tls_unique. +-type plus_variant() :: undefined | none | tls_exporter. -type scram_options() :: #{plus_variant := plus_variant(), hash_type := hash_type(), xmpp_method := binary() @@ -38,6 +39,8 @@ -include_lib("exml/include/exml.hrl"). +-define(CB_LABEL, <<"EXPORTER-Channel-Binding">>). + %%-------------------------------------------------------------------- %% Public API %%-------------------------------------------------------------------- @@ -66,57 +69,57 @@ auth_digest_md5(Conn, Props) -> %% SCRAM Regular -spec auth_sasl_scram_sha1(client(), user_spec()) -> ok. auth_sasl_scram_sha1(Conn, Props) -> - Options = #{plus_variant => none, hash_type => sha, xmpp_method => <<"SCRAM-SHA-1">>}, + Options = #{plus_variant => undefined, hash_type => sha, xmpp_method => <<"SCRAM-SHA-1">>}, auth_sasl_scram(Options, Conn, Props). -spec auth_sasl_scram_sha224(client(), user_spec()) -> ok. auth_sasl_scram_sha224(Conn, Props) -> - Options = #{plus_variant => none, hash_type => sha224, xmpp_method => <<"SCRAM-SHA-224">>}, + Options = #{plus_variant => undefined, hash_type => sha224, xmpp_method => <<"SCRAM-SHA-224">>}, auth_sasl_scram(Options, Conn, Props). -spec auth_sasl_scram_sha256(client(), user_spec()) -> ok. auth_sasl_scram_sha256(Conn, Props) -> - Options = #{plus_variant => none, hash_type => sha256, xmpp_method => <<"SCRAM-SHA-256">>}, + Options = #{plus_variant => undefined, hash_type => sha256, xmpp_method => <<"SCRAM-SHA-256">>}, auth_sasl_scram(Options, Conn, Props). -spec auth_sasl_scram_sha384(client(), user_spec()) -> ok. auth_sasl_scram_sha384(Conn, Props) -> - Options = #{plus_variant => none, hash_type => sha384, xmpp_method => <<"SCRAM-SHA-384">>}, + Options = #{plus_variant => undefined, hash_type => sha384, xmpp_method => <<"SCRAM-SHA-384">>}, auth_sasl_scram(Options, Conn, Props). -spec auth_sasl_scram_sha512(client(), user_spec()) -> ok. auth_sasl_scram_sha512(Conn, Props) -> - Options = #{plus_variant => none, hash_type => sha512, xmpp_method => <<"SCRAM-SHA-512">>}, + Options = #{plus_variant => undefined, hash_type => sha512, xmpp_method => <<"SCRAM-SHA-512">>}, auth_sasl_scram(Options, Conn, Props). %% SCRAM PLUS -spec auth_sasl_scram_sha1_plus(client(), user_spec()) -> ok. auth_sasl_scram_sha1_plus(Conn, Props) -> - Options = #{plus_variant => tls_unique, hash_type => sha, + Options = #{plus_variant => tls_exporter, hash_type => sha, xmpp_method => <<"SCRAM-SHA-1-PLUS">>}, auth_sasl_scram(Options, Conn, Props). -spec auth_sasl_scram_sha224_plus(client(), user_spec()) -> ok. auth_sasl_scram_sha224_plus(Conn, Props) -> - Options = #{plus_variant => tls_unique, hash_type => sha224, + Options = #{plus_variant => tls_exporter, hash_type => sha224, xmpp_method => <<"SCRAM-SHA-224-PLUS">>}, auth_sasl_scram(Options, Conn, Props). -spec auth_sasl_scram_sha256_plus(client(), user_spec()) -> ok. auth_sasl_scram_sha256_plus(Conn, Props) -> - Options = #{plus_variant => tls_unique, hash_type => sha256, + Options = #{plus_variant => tls_exporter, hash_type => sha256, xmpp_method => <<"SCRAM-SHA-256-PLUS">>}, auth_sasl_scram(Options, Conn, Props). -spec auth_sasl_scram_sha384_plus(client(), user_spec()) -> ok. auth_sasl_scram_sha384_plus(Conn, Props) -> - Options = #{plus_variant => tls_unique, hash_type => sha384, + Options = #{plus_variant => tls_exporter, hash_type => sha384, xmpp_method => <<"SCRAM-SHA-384-PLUS">>}, auth_sasl_scram(Options, Conn, Props). -spec auth_sasl_scram_sha512_plus(client(), user_spec()) -> ok. auth_sasl_scram_sha512_plus(Conn, Props) -> - Options = #{plus_variant => tls_unique, hash_type => sha512, + Options = #{plus_variant => tls_exporter, hash_type => sha512, xmpp_method => <<"SCRAM-SHA-512-PLUS">>}, auth_sasl_scram(Options, Conn, Props). @@ -127,8 +130,7 @@ auth_sasl_scram(#{plus_variant := PlusVariant, Conn, Props) -> Username = get_property(username, Props), Password = get_property(password, Props), - ChannelBinding = scram_sha_auth_payload( - proplists:get_value(tls_module, Props, ssl), PlusVariant, Conn), + ChannelBinding = scram_sha_auth_payload(PlusVariant, Conn), {ok, ClientState1} = fast_scram:mech_new( #{entity => client, username => Username, hash_method => HashMethod, nonce_size => 16, channel_binding => ChannelBinding, auth_data => #{password => Password}}), @@ -220,14 +222,14 @@ md5_digest_response(ChallengeData, Props) -> {<<"authzid">>, FullJid} ])). -scram_sha_auth_payload(ssl, _, _) -> +scram_sha_auth_payload(undefined, _) -> {undefined, <<>>}; -scram_sha_auth_payload(fast_tls, none, _) -> +scram_sha_auth_payload(none, _) -> {none, <<>>}; -scram_sha_auth_payload(fast_tls, tls_unique, Conn) -> - {ok, FinishedTLS} = escalus_connection:get_tls_last_message(Conn), - {<<"tls-unique">>, FinishedTLS}. - +scram_sha_auth_payload(tls_exporter, Conn) -> + {ok, [Material | _]} = escalus_connection:export_key_materials( + Conn, [?CB_LABEL], [no_context], [32], true), + {<<"tls-exporter">>, Material}. hex_md5(Data) -> binary:encode_hex(crypto:hash(md5, Data), lowercase). diff --git a/src/escalus_connection.erl b/src/escalus_connection.erl index 1a69cbf..27508ea 100644 --- a/src/escalus_connection.erl +++ b/src/escalus_connection.erl @@ -31,7 +31,7 @@ get_sm_h/1, set_sm_h/2, set_filter_predicate/2, - get_tls_last_message/1, + export_key_materials/5, reset_parser/1, is_connected/1, wait_for_close/1, @@ -87,6 +87,16 @@ -callback set_filter_predicate(pid(), filter_pred()) -> ok. -callback stop(pid()) -> ok | already_stopped. -callback kill(pid()) -> ok | already_stopped. +-callback export_key_materials(pid(), Labels, Contexts, WantedLengths, ConsumeSecret) -> + {ok, ExportKeyMaterials} | + {error, undefined_tls_material | exporter_master_secret_already_consumed | bad_input} + when + Labels :: [binary()], + Contexts :: [binary() | no_context], + WantedLengths :: [non_neg_integer()], + ConsumeSecret :: boolean(), + ExportKeyMaterials :: binary() | [binary()]. +-optional_callbacks([export_key_materials/5]). -callback stream_start_req(user_spec()) -> exml_stream:element(). -callback stream_end_req(user_spec()) -> exml_stream:element(). @@ -390,11 +400,19 @@ set_sm_h(#client{module = Mod}, _) -> set_filter_predicate(#client{module = Module, rcv_pid = Pid}, Pred) -> Module:set_filter_predicate(Pid, Pred). --spec get_tls_last_message(client()) -> {ok, binary()} | {error, undefined_tls_message}. -get_tls_last_message(#client{module = escalus_tcp, rcv_pid = Pid}) -> - escalus_tcp:get_tls_last_message(Pid); -get_tls_last_message(#client{module = Mod}) -> - error({get_tls_last_message, {undefined_for_escalus_module, Mod}}). +-spec export_key_materials(client(), Labels, Contexts, WantedLengths, ConsumeSecret) -> + {ok, ExportKeyMaterials} | + {error, undefined_tls_material | exporter_master_secret_already_consumed | bad_input} + when + Labels :: [binary()], + Contexts :: [binary() | no_context], + WantedLengths :: [non_neg_integer()], + ConsumeSecret :: boolean(), + ExportKeyMaterials :: binary() | [binary()]. +export_key_materials(#client{module = escalus_tcp, rcv_pid = Pid}, Labels, Contexts, WantedLengths, ConsumeSecret) -> + escalus_tcp:export_key_materials(Pid, Labels, Contexts, WantedLengths, ConsumeSecret); +export_key_materials(#client{module = Mod}, _Labels, _Contexts, _WantedLengths, _ConsumeSecret) -> + error({export_key_materials, {undefined_for_escalus_module, Mod}}). -spec reset_parser(client()) -> ok. reset_parser(#client{module = Mod, rcv_pid = Pid}) -> diff --git a/src/escalus_session.erl b/src/escalus_session.erl index 90d5247..0708b01 100644 --- a/src/escalus_session.erl +++ b/src/escalus_session.erl @@ -78,16 +78,19 @@ authenticate(Client = #client{props = Props}) -> %% but as a default we use plain, as it incurrs lower load and better logs (no hashing) %% for common setups. If a different mechanism is required then it should be configured on the %% user specification. - {M, F} = proplists:get_value(auth, Props, {escalus_auth, auth_plain}), - PropsAfterAuth = case apply(M, F, [Client, Props]) of - ok -> Props; - {ok, P} when is_list(P) -> P - end, + PropsAfterAuth = apply_auth_method(Client, Props), escalus_connection:reset_parser(Client), Client1 = escalus_session:start_stream(Client#client{props = PropsAfterAuth}), escalus_session:stream_features(Client1, []), Client1. +apply_auth_method(Client, Props) -> + Fun = proplists:get_value(auth, Props, fun escalus_auth:auth_plain/2), + case apply(Fun, [Client, Props]) of + ok -> Props; + {ok, P} when is_list(P) -> P + end. + -spec bind(client()) -> client(). bind(Client = #client{props = Props0}) -> Resource = proplists:get_value(resource, Props0, ?DEFAULT_RESOURCE), diff --git a/src/escalus_tcp.erl b/src/escalus_tcp.erl index c0bc0aa..a2e0d4c 100644 --- a/src/escalus_tcp.erl +++ b/src/escalus_tcp.erl @@ -9,8 +9,6 @@ -behaviour(escalus_connection). -include_lib("exml/include/exml_stream.hrl"). --include_lib("exml/include/exml.hrl"). --include("escalus.hrl"). %% Escalus transport callbacks -export([connect/1, @@ -26,7 +24,7 @@ set_sm_h/2, is_using_compression/1, is_using_ssl/1, - get_tls_last_message/1 + export_key_materials/5 ]). %% Connection stream start and end callbacks -export([stream_start_req/1, @@ -57,7 +55,6 @@ -export_type([sm_state/0]). -define(WAIT_FOR_SOCKET_CLOSE_TIMEOUT, 1000). --define(SERVER, ?MODULE). -include("escalus_tcp.hrl"). -type state() :: #state{}. @@ -65,7 +62,6 @@ host => binary() | inet:ip_address() | inet:hostname(), port => pos_integer(), ssl => boolean(), - tls_module => ssl | fast_tls, stream_management => boolean(), manual_ack => boolean(), iface => inet:ip_address(), @@ -123,9 +119,17 @@ is_using_ssl(Pid) -> set_filter_predicate(Pid, Pred) -> gen_server:call(Pid, {set_filter_pred, Pred}). --spec get_tls_last_message(pid()) -> {ok, binary()} | {error, undefined_tls_message}. -get_tls_last_message(Pid) -> - gen_server:call(Pid, get_tls_last_message). +-spec export_key_materials(pid(), Labels, Contexts, WantedLengths, ConsumeSecret) -> + {ok, ExportKeyMaterials} | + {error, undefined_tls_material | exporter_master_secret_already_consumed | bad_input} + when + Labels :: [binary()], + Contexts :: [binary() | no_context], + WantedLengths :: [non_neg_integer()], + ConsumeSecret :: boolean(), + ExportKeyMaterials :: binary() | [binary()]. +export_key_materials(Pid, Labels, Contexts, WantedLengths, ConsumeSecret) -> + gen_server:call(Pid, {export_key_materials, {Labels, Contexts, WantedLengths, ConsumeSecret}}). -spec stop(pid()) -> ok | already_stopped. stop(Pid) -> @@ -201,7 +205,6 @@ set_active(Pid, Active) -> -spec init({opts(), pid()}) -> {ok, state()}. init({Opts, Owner}) -> #{ssl := IsSSLConnection, - tls_module := TLSMod, on_reply := OnReplyFun, on_request := OnRequestFun, parser_opts := ParserOpts, @@ -214,7 +217,6 @@ init({Opts, Owner}) -> socket = Socket, parser = Parser, ssl = IsSSLConnection, - tls_module = TLSMod, sm_state = SM, event_client = EventClient, on_reply = OnReplyFun, @@ -227,13 +229,11 @@ handle_call(get_sm_h, _From, #state{sm_state = {_, H, _}} = State) -> handle_call({set_sm_h, H}, _From, #state{sm_state = {A, _OldH, S}} = State) -> NewState = State#state{sm_state={A, H, S}}, {reply, {ok, H}, NewState}; -handle_call({upgrade_to_tls, SSLOpts}, _From, #state{socket = Socket, - tls_module = TLSMod} = State) -> - case tcp_to_tls(TLSMod, Socket, SSLOpts) of +handle_call({upgrade_to_tls, SSLOpts}, _From, #state{socket = Socket} = State) -> + case ssl:connect(Socket, SSLOpts) of {ok, TlsSocket} -> {ok, Parser} = exml_stream:new_parser(), - {reply, TlsSocket, State#state{socket = TlsSocket, parser = Parser, - ssl = true, tls_module = TLSMod}}; + {reply, TlsSocket, State#state{socket = TlsSocket, parser = Parser, ssl = true}}; {error, _} = E -> {reply, E, State} end; @@ -254,18 +254,14 @@ handle_call(get_ssl, _From, #state{ssl = false} = State) -> handle_call(get_ssl, _From, #state{ssl = _} = State) -> {reply, true, State}; handle_call({set_active, Active}, _From, State) -> - {reply, ok, set_active_opt(State,Active)}; + {reply, ok, set_active_opt(State, Active)}; handle_call({set_filter_pred, Pred}, _From, State) -> {reply, ok, State#state{filter_pred = Pred}}; -handle_call(get_tls_last_message, _From, - #state{socket = Socket, ssl = true, tls_module = fast_tls} = S) -> - Reply = fast_tls:get_tls_last_message(self, Socket), - {reply, Reply, S}; -handle_call(get_tls_last_message, _From, #state{} = S) -> - {reply, {error, undefined_tls_message}, S}; -handle_call(kill_connection, _, #state{socket = Socket, ssl = SSL, tls_module = TLSMod} = S) -> +handle_call({export_key_materials, Data}, _From, #state{socket = Socket, ssl = true} = S) -> + {reply, do_export_key_materials(Socket, Data), S}; +handle_call(kill_connection, _, #state{socket = Socket, ssl = SSL} = S) -> case SSL of - true -> TLSMod:close(Socket); + true -> ssl:close(Socket); false -> gen_tcp:close(Socket) end, close_compression_streams(S#state.compress), @@ -293,14 +289,9 @@ handle_cast(stop, State) -> handle_info({tcp, Socket, Data}, #state{socket = Socket, ssl = false} = State) -> NewState = handle_data(Socket, Data, State), {noreply, NewState}; -handle_info({ssl, Socket, Data}, #state{socket = Socket, ssl = true, tls_module = ssl} = State) -> +handle_info({ssl, Socket, Data}, #state{socket = Socket, ssl = true} = State) -> NewState = handle_data(Socket, Data, State), {noreply, NewState}; -handle_info({tcp, TcpSocket, Data}, #state{socket = {tlssock, TcpSocket, _} = TlsSocket, - ssl = true, tls_module = fast_tls} = State) -> - {ok, NewData} = fast_tls:recv_data(TlsSocket, Data), - NewState = handle_data(TlsSocket, NewData, State), - {noreply, NewState}; handle_info({tcp_closed, _Socket}, #state{} = State) -> {stop, normal, State}; handle_info({ssl_closed, _Socket}, #state{} = State) -> @@ -313,10 +304,12 @@ handle_info(_, State) -> {noreply, State}. -spec terminate(term(), state()) -> term(). -terminate(Reason, #state{socket = Socket, ssl = true, tls_module = TLSMod} = State) -> - common_terminate(TLSMod, Socket, Reason, State); -terminate(Reason, #state{socket = Socket} = State) -> - common_terminate(gen_tcp, Socket, Reason, State). +terminate(_Reason, #state{socket = Socket, ssl = true, parser = Parser}) -> + exml_stream:free_parser(Parser), + ssl:close(Socket); +terminate(_Reason, #state{socket = Socket, parser = Parser}) -> + exml_stream:free_parser(Parser), + gen_tcp:close(Socket). -spec code_change(term(), state(), term()) -> {ok, state()}. code_change(_OldVsn, State, _Extra) -> @@ -331,7 +324,6 @@ default_options() -> #{host => <<"localhost">>, port => 5222, ssl => false, - tls_module => ssl, stream_management => false, manual_ack => false, on_reply => fun(_) -> ok end, @@ -356,26 +348,25 @@ default_socket_options() -> %%% Helpers %%%=================================================================== set_active_opt( - #state{ssl = SSL, socket = Soc, tls_module = TLSMod} = State, Act) when is_boolean(Act) -> - set_active_opt(SSL, TLSMod, Soc, Act), + #state{ssl = SSL, socket = Soc} = State, Act) when is_boolean(Act) -> + set_active_opt(SSL, Soc, Act), State#state{active = Act}; set_active_opt( - #state{ssl = SSL, socket = Soc, active = Act, tls_module = TLSMod} = State, current_opt) -> - set_active_opt(SSL, TLSMod, Soc, Act), + #state{ssl = SSL, socket = Soc, active = Act} = State, current_opt) -> + set_active_opt(SSL, Soc, Act), State; -set_active_opt(#state{ssl = SSL, socket = Soc, tls_module = TLSMod} = State, once) -> - set_active_opt(SSL, TLSMod, Soc, true), +set_active_opt(#state{ssl = SSL, socket = Soc} = State, once) -> + set_active_opt(SSL, Soc, true), State#state{active = false}; -set_active_opt(#state{ssl = SSL, socket = Soc, tls_module = TLSMod} = State, at_least_once) -> - set_active_opt(SSL, TLSMod, Soc, true), +set_active_opt(#state{ssl = SSL, socket = Soc} = State, at_least_once) -> + set_active_opt(SSL, Soc, true), State. - -set_active_opt(true, TLSMod, Socket, true) -> - TLSMod:setopts(Socket, [{active, once}]); -set_active_opt(false, _, Socket, true) -> +set_active_opt(true, Socket, true) -> + ssl:setopts(Socket, [{active, once}]); +set_active_opt(false, Socket, true) -> inet:setopts(Socket, [{active, once}]); -set_active_opt(_,_,_,_) -> +set_active_opt(_, _, _) -> ok. handle_data(Socket, Data, #state{parser = Parser, @@ -435,15 +426,11 @@ maybe_compress_and_send(Data, #state{ compress = {zlib, {_, Zout}} } = State) -> maybe_compress_and_send(Data, State) -> raw_send(Data, State). -raw_send(Data, #state{socket = Socket, ssl = true, tls_module = TLSMod}) -> - TLSMod:send(Socket, Data); +raw_send(Data, #state{socket = Socket, ssl = true}) -> + ssl:send(Socket, Data); raw_send(Data, #state{socket = Socket}) -> gen_tcp:send(Socket, Data). -common_terminate(SocketModule, Socket, _Reason, #state{parser = Parser}) -> - exml_stream:free_parser(Parser), - SocketModule:close(Socket). - wait_until_closed(Socket) -> receive {tcp_closed, Socket} -> @@ -480,7 +467,6 @@ close_compression_streams({zlib, {Zin, Zout}}) -> end. do_connect(#{ssl := IsSSLConn, - tls_module := TLSMod, on_connect := OnConnectFun, host := Host, port := Port, @@ -489,7 +475,7 @@ do_connect(#{ssl := IsSSLConn, Address = host_to_inet(Host), SocketOpts = get_socket_opts(Opts), TimeB = erlang:system_time(microsecond), - Reply = maybe_ssl_connection(IsSSLConn, TLSMod, Address, Port, SocketOpts, SSLOpts, HibernateAfter), + Reply = maybe_ssl_connection(IsSSLConn, Address, Port, SocketOpts, SSLOpts, HibernateAfter), TimeA = erlang:system_time(microsecond), ConnectionTime = TimeA - TimeB, case Reply of @@ -500,30 +486,11 @@ do_connect(#{ssl := IsSSLConn, end, Reply. -maybe_ssl_connection(true, fast_tls, Address, Port, SocketOpts, SSLOpts, _) -> - {ok, GenTcpSocket} = gen_tcp:connect(Address, Port, SocketOpts), - tcp_to_tls(fast_tls, GenTcpSocket, SSLOpts); -maybe_ssl_connection(true, ssl, Address, Port, SocketOpts, SSLOpts, HibernateAfter) -> +maybe_ssl_connection(true, Address, Port, SocketOpts, SSLOpts, HibernateAfter) -> ssl:connect(Address, Port, SocketOpts ++ SSLOpts ++ [{hibernate_after, HibernateAfter}]); -maybe_ssl_connection(_, _, Address, Port, SocketOpts, _, _) -> +maybe_ssl_connection(_, Address, Port, SocketOpts, _, _) -> gen_tcp:connect(Address, Port, SocketOpts). -tcp_to_tls(fast_tls, GenTcpSocket, SSLOpts) -> - inet:setopts(GenTcpSocket, [{active, false}]), - case fast_tls:tcp_to_tls(GenTcpSocket, [connect | SSLOpts]) of - {ok, TlsSocket} -> - %% fast_tls requires dummy recv_data/2 call to accomplish TLS handshake - fast_tls:recv(TlsSocket, 0, 100), - fast_tls:recv_data(TlsSocket, <<>>), - fast_tls:recv_data(TlsSocket, <<>>), - fast_tls:setopts(TlsSocket, [{active, once}]), - {ok, TlsSocket}; - {error, _} = E -> - E - end; -tcp_to_tls(ssl, GenTcpSocket, SSLOpts) -> - ssl:connect(GenTcpSocket, SSLOpts). - %%=================================================================== %%% Init options parsing helpers %%%=================================================================== @@ -563,3 +530,20 @@ get_socket_opts(#{socket_opts := SocketOpts}) -> -spec opts_to_map([proplists:property()] | opts()) -> opts(). opts_to_map(Opts) when is_map(Opts) -> Opts; opts_to_map(Opts) when is_list(Opts) -> maps:from_list(Opts). + +-spec do_export_key_materials(ssl:sslsocket(), {Labels, Contexts, WantedLengths, ConsumeSecret}) -> + {ok, ExportKeyMaterials} | + {error, undefined_tls_material | exporter_master_secret_already_consumed | bad_input} + when + Labels :: [binary()], + Contexts :: [binary() | no_context], + WantedLengths :: [non_neg_integer()], + ConsumeSecret :: boolean(), + ExportKeyMaterials :: binary() | [binary()]. +-if(?OTP_RELEASE >= 27). +do_export_key_materials(SslSocket, {Labels, Contexts, WantedLengths, ConsumeSecret}) -> + ssl:export_key_materials(SslSocket, Labels, Contexts, WantedLengths, ConsumeSecret). +-else. +do_export_key_materials(_SslSocket, {_, _, _, _}) -> + {error, undefined_tls_material}. +-endif. diff --git a/src/escalus_tcp.hrl b/src/escalus_tcp.hrl index 9347549..a030079 100644 --- a/src/escalus_tcp.hrl +++ b/src/escalus_tcp.hrl @@ -15,11 +15,10 @@ %%============================================================================== -record(state, {owner, - socket :: gen_tcp:socket() | ssl:sslsocket() | fast_tls:tls_socket(), + socket :: gen_tcp:socket() | ssl:sslsocket(), parser, filter_pred, ssl = false, - tls_module = ssl, compress = false, event_client, on_reply, diff --git a/src/escalus_users.erl b/src/escalus_users.erl index 6efa64b..51a6f38 100644 --- a/src/escalus_users.erl +++ b/src/escalus_users.erl @@ -154,46 +154,50 @@ get_server(Config, User) -> get_wspath(Config, User) -> get_user_option(wspath, User, escalus_wspath, Config, undefined). --spec get_auth_method(escalus:config(), user()) -> {module(), atom()}. +-spec get_auth_method(escalus:config(), user()) -> + fun((escalus_connection:client(), escalus_users:user_spec()) -> ok | {ok, escalus_users:user_spec()}). get_auth_method(Config, User) -> AuthMethod = get_user_option(auth_method, User, escalus_auth_method, Config, <<"PLAIN">>), get_auth_method(AuthMethod). --spec get_auth_method(binary() | {module(), atom()}) -> {module(), atom()}. +-spec get_auth_method(binary() | {module(), atom()}) -> + fun((escalus_connection:client(), escalus_users:user_spec()) -> ok | {ok, escalus_users:user_spec()}). get_auth_method(<<"PLAIN">>) -> - {escalus_auth, auth_plain}; + fun escalus_auth:auth_plain/2; get_auth_method(<<"DIGEST-MD5">>) -> - {escalus_auth, auth_digest_md5}; + fun escalus_auth:auth_digest_md5/2; get_auth_method(<<"SASL-ANON">>) -> - {escalus_auth, auth_sasl_anon}; + fun escalus_auth:auth_sasl_anon/2; %% SCRAM Regular get_auth_method(<<"SCRAM-SHA-1">>) -> - {escalus_auth, auth_sasl_scram_sha1}; + fun escalus_auth:auth_sasl_scram_sha1/2; get_auth_method(<<"SCRAM-SHA-224">>) -> - {escalus_auth, auth_sasl_scram_sha224}; + fun escalus_auth:auth_sasl_scram_sha224/2; get_auth_method(<<"SCRAM-SHA-256">>) -> - {escalus_auth, auth_sasl_scram_sha256}; + fun escalus_auth:auth_sasl_scram_sha256/2; get_auth_method(<<"SCRAM-SHA-384">>) -> - {escalus_auth, auth_sasl_scram_sha384}; + fun escalus_auth:auth_sasl_scram_sha384/2; get_auth_method(<<"SCRAM-SHA-512">>) -> - {escalus_auth, auth_sasl_scram_sha512}; + fun escalus_auth:auth_sasl_scram_sha512/2; %% SCRAM PLUS get_auth_method(<<"SCRAM-SHA-1-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha1_plus}; + fun escalus_auth:auth_sasl_scram_sha1_plus/2; get_auth_method(<<"SCRAM-SHA-224-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha224_plus}; + fun escalus_auth:auth_sasl_scram_sha224_plus/2; get_auth_method(<<"SCRAM-SHA-256-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha256_plus}; + fun escalus_auth:auth_sasl_scram_sha256_plus/2; get_auth_method(<<"SCRAM-SHA-384-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha384_plus}; + fun escalus_auth:auth_sasl_scram_sha384_plus/2; get_auth_method(<<"SCRAM-SHA-512-PLUS">>) -> - {escalus_auth, auth_sasl_scram_sha512_plus}; + fun escalus_auth:auth_sasl_scram_sha512_plus/2; get_auth_method(<<"X-OAUTH">>) -> - {escalus_auth, auth_sasl_oauth}; + fun escalus_auth:auth_sasl_oauth/2; get_auth_method({Mod, Fun}) when is_atom(Mod), is_atom(Fun) -> - {Mod, Fun}. + fun Mod:Fun/2; +get_auth_method(Fun) when is_function(Fun, 2) -> + Fun. -spec get_usp(escalus:config(), user()) -> [binary() | xmpp_domain()]. get_usp(Config, User) ->