diff --git a/src/hex_api.erl b/src/hex_api.erl index 6873421..dc85f29 100644 --- a/src/hex_api.erl +++ b/src/hex_api.erl @@ -17,7 +17,8 @@ -export_type([response/0]). -type response() :: {ok, {hex_http:status(), hex_http:headers(), body() | nil}} | {error, term()}. --type body() :: [body()] | #{binary() => body() | binary()}. +-type body() :: #{binary() => value()} | [#{binary() => value()}]. +-type value() :: binary() | boolean() | nil | number() | [value()] | #{binary() => value()}. %% @private get(Config, Path) -> diff --git a/src/hex_api_oauth.erl b/src/hex_api_oauth.erl index e4d903f..46e1e18 100644 --- a/src/hex_api_oauth.erl +++ b/src/hex_api_oauth.erl @@ -4,6 +4,8 @@ -export([ device_authorization/3, device_authorization/4, + device_auth_flow/4, + device_auth_flow/5, poll_device_token/3, refresh_token/3, revoke_token/3, @@ -11,6 +13,21 @@ client_credentials_token/5 ]). +-export_type([oauth_tokens/0, device_auth_error/0]). + +-type oauth_tokens() :: #{ + access_token := binary(), + refresh_token => binary() | undefined, + expires_at := integer() +}. + +-type device_auth_error() :: + timeout + | {access_denied, Status :: non_neg_integer(), Body :: term()} + | {device_auth_failed, Status :: non_neg_integer(), Body :: term()} + | {poll_failed, Status :: non_neg_integer(), Body :: term()} + | term(). + %% @doc %% Initiates the OAuth device authorization flow. %% @@ -26,7 +43,7 @@ device_authorization(Config, ClientId, Scope) -> %% Returns device code, user code, and verification URIs for user authentication. %% %% Options: -%% * `name' - A name to identify the token (e.g., hostname of the device) +%% * `name' - A name to identify the token (defaults to the machine's hostname) %% %% Examples: %% @@ -49,17 +66,141 @@ device_authorization(Config, ClientId, Scope) -> hex_api:response(). device_authorization(Config, ClientId, Scope, Opts) -> Path = <<"oauth/device_authorization">>, - Params0 = #{ - <<"client_id">> => ClientId, - <<"scope">> => Scope - }, - Params = + Name = case proplists:get_value(name, Opts) of - undefined -> Params0; - Name -> Params0#{<<"name">> => Name} + undefined -> get_hostname(); + N -> N end, + Params = #{ + <<"client_id">> => ClientId, + <<"scope">> => Scope, + <<"name">> => Name + }, hex_api:post(Config, Path, Params). +%% @doc +%% Runs the complete OAuth device authorization flow. +%% +%% @see device_auth_flow/5 +%% @end +-spec device_auth_flow( + hex_core:config(), + ClientId :: binary(), + Scope :: binary(), + PromptUser :: fun((VerificationUri :: binary(), UserCode :: binary()) -> ok) +) -> {ok, oauth_tokens()} | {error, device_auth_error()}. +device_auth_flow(Config, ClientId, Scope, PromptUser) -> + device_auth_flow(Config, ClientId, Scope, PromptUser, []). + +%% @doc +%% Runs the complete OAuth device authorization flow with options. +%% +%% This function handles the entire device authorization flow: +%% 1. Requests a device code from the server +%% 2. Calls `PromptUser' callback with the verification URI and user code +%% 3. Optionally opens the browser for the user (when `open_browser' is true) +%% 4. Polls the token endpoint until authorization completes or times out +%% +%% The `PromptUser' callback is responsible for displaying the verification URI +%% and user code to the user (e.g., printing to console). +%% +%% Options: +%% * `name' - A name to identify the token (defaults to the machine's hostname) +%% * `open_browser' - When `true', automatically opens the browser +%% to the verification URI. When `false' (default), only the callback is invoked. +%% +%% Returns: +%% - `{ok, Tokens}' - Authorization successful, returns access token and optional refresh token +%% - `{error, timeout}' - Device code expired before user completed authorization +%% - `{error, {access_denied, Status, Body}}' - User denied the authorization request +%% - `{error, {device_auth_failed, Status, Body}}' - Initial device authorization request failed +%% - `{error, {poll_failed, Status, Body}}' - Unexpected error during polling +%% +%% Examples: +%% +%% ``` +%% 1> Config = hex_core:default_config(). +%% 2> PromptUser = fun(Uri, Code) -> +%% io:format("Visit ~s and enter code: ~s~n", [Uri, Code]) +%% end. +%% 3> hex_api_oauth:device_auth_flow(Config, <<"cli">>, <<"api:write">>, PromptUser). +%% {ok, #{ +%% access_token => <<"...">>, +%% refresh_token => <<"...">>, +%% expires_at => 1234567890 +%% }} +%% ''' +%% @end +-spec device_auth_flow( + hex_core:config(), + ClientId :: binary(), + Scope :: binary(), + PromptUser :: fun((VerificationUri :: binary(), UserCode :: binary()) -> ok), + proplists:proplist() +) -> {ok, oauth_tokens()} | {error, device_auth_error()}. +device_auth_flow(Config, ClientId, Scope, PromptUser, Opts) -> + case device_authorization(Config, ClientId, Scope, Opts) of + {ok, {200, _, DeviceResponse}} when is_map(DeviceResponse) -> + #{ + <<"device_code">> := DeviceCode, + <<"user_code">> := UserCode, + <<"verification_uri_complete">> := VerificationUri, + <<"expires_in">> := ExpiresIn, + <<"interval">> := IntervalSeconds + } = DeviceResponse, + ok = PromptUser(VerificationUri, UserCode), + OpenBrowser = proplists:get_value(open_browser, Opts, false), + case OpenBrowser of + true -> open_browser(VerificationUri); + false -> ok + end, + ExpiresAt = erlang:system_time(second) + ExpiresIn, + poll_for_token_loop(Config, ClientId, DeviceCode, IntervalSeconds, ExpiresAt); + {ok, {Status, _, Body}} -> + {error, {device_auth_failed, Status, Body}}; + {error, Reason} -> + {error, Reason} + end. + +%% @private +poll_for_token_loop(Config, ClientId, DeviceCode, IntervalSeconds, ExpiresAt) -> + Now = erlang:system_time(second), + case Now >= ExpiresAt of + true -> + {error, timeout}; + false -> + timer:sleep(IntervalSeconds * 1000), + case poll_device_token(Config, ClientId, DeviceCode) of + {ok, {200, _, TokenResponse}} when is_map(TokenResponse) -> + #{ + <<"access_token">> := AccessToken, + <<"expires_in">> := ExpiresIn + } = TokenResponse, + RefreshToken = maps:get(<<"refresh_token">>, TokenResponse, undefined), + TokenExpiresAt = erlang:system_time(second) + ExpiresIn, + {ok, #{ + access_token => AccessToken, + refresh_token => RefreshToken, + expires_at => TokenExpiresAt + }}; + {ok, {400, _, #{<<"error">> := <<"authorization_pending">>}}} -> + poll_for_token_loop(Config, ClientId, DeviceCode, IntervalSeconds, ExpiresAt); + {ok, {400, _, #{<<"error">> := <<"slow_down">>}}} -> + %% Increase polling interval as requested by server + poll_for_token_loop( + Config, ClientId, DeviceCode, IntervalSeconds + 5, ExpiresAt + ); + {ok, {400, _, #{<<"error">> := <<"expired_token">>}}} -> + {error, timeout}; + {ok, {Status, _, #{<<"error">> := <<"access_denied">>} = Body}} -> + {error, {access_denied, Status, Body}}; + {ok, {Status, _, Body}} -> + {error, {poll_failed, Status, Body}}; + {error, Reason} -> + {error, Reason} + end + end. + %% @doc %% Polls the OAuth token endpoint for device authorization completion. %% @@ -199,3 +340,44 @@ revoke_token(Config, ClientId, Token) -> <<"client_id">> => ClientId }, hex_api:post(Config, Path, Params). + +%%==================================================================== +%% Internal functions +%%==================================================================== + +%% @private +%% Open a URL in the default browser. +%% Uses platform-specific commands: open (macOS), xdg-open (Linux), start (Windows). +-spec open_browser(binary()) -> ok. +open_browser(Url) when is_binary(Url) -> + ok = ensure_valid_http_url(Url), + UrlStr = binary_to_list(Url), + {Cmd, Args} = + case os:type() of + {unix, darwin} -> + {"open", [UrlStr]}; + {unix, _} -> + {"xdg-open", [UrlStr]}; + {win32, _} -> + {"cmd", ["/c", "start", "", UrlStr]} + end, + Port = open_port({spawn_executable, os:find_executable(Cmd)}, [{args, Args}]), + port_close(Port), + ok. + +%% @private +%% Validates that a URL uses http:// or https:// scheme. +-spec ensure_valid_http_url(binary()) -> ok. +ensure_valid_http_url(Url) when is_binary(Url) -> + case uri_string:parse(Url) of + #{scheme := <<"https">>} -> ok; + #{scheme := <<"http">>} -> ok; + _ -> throw({invalid_url, Url}) + end. + +%% @private +%% Get the hostname of the current machine. +-spec get_hostname() -> binary(). +get_hostname() -> + {ok, Hostname} = inet:gethostname(), + list_to_binary(Hostname). diff --git a/src/hex_cli_auth.erl b/src/hex_cli_auth.erl new file mode 100644 index 0000000..05143ab --- /dev/null +++ b/src/hex_cli_auth.erl @@ -0,0 +1,703 @@ +%% @doc +%% Authentication handling with callback functions for build-tool-specific operations. +%% +%% This module provides generic authentication handling that allows both rebar3 +%% and Elixir Hex (and future build tools) to share the common auth logic while +%% customizing prompting, persistence, and configuration retrieval. +%% +%% == Callbacks == +%% +%% The caller provides a callbacks map with these functions (all required): +%% +%% ``` +%% #{ +%% %% Auth configuration for a specific repo +%% get_auth_config => fun((RepoName :: binary()) -> +%% #{api_key => binary(), +%% auth_key => binary(), +%% oauth_exchange => boolean(), +%% oauth_exchange_url => binary()} | undefined), +%% +%% %% Global OAuth tokens - storage and retrieval +%% get_oauth_tokens => fun(() -> {ok, #{access_token := binary(), +%% refresh_token => binary(), +%% expires_at := integer()}} | error), +%% persist_oauth_tokens => fun((Scope :: global | binary(), +%% AccessToken :: binary(), +%% RefreshToken :: binary() | undefined, +%% ExpiresAt :: integer()) -> ok), +%% +%% %% User interaction +%% prompt_otp => fun((Message :: binary()) -> {ok, OtpCode :: binary()} | cancelled), +%% should_authenticate => fun((Reason :: no_credentials | token_refresh_failed) -> boolean()), +%% +%% %% OAuth client configuration +%% get_client_id => fun(() -> binary()) +%% } +%% ''' +%% +%% == Auth Resolution Order == +%% +%% For API calls: +%%
    +%%
  1. Per-repo `api_key' from config (with optional OAuth exchange for hex.pm)
  2. +%%
  3. Parent repo `api_key' (for "hexpm:org" organizations)
  4. +%%
  5. Global OAuth token (refreshed if expired)
  6. +%%
  7. Device auth flow (for write operations only)
  8. +%%
+%% +%% For repo calls: +%%
    +%%
  1. Per-repo `auth_key' with optional OAuth exchange (default true for hex.pm)
  2. +%%
  3. Parent repo `auth_key'
  4. +%%
  5. Global OAuth token
  6. +%%
+%% +%% == OAuth Exchange == +%% +%% For hex.pm URLs, `api_key' and `auth_key' are exchanged for short-lived OAuth +%% tokens via the client credentials grant. This behavior can be controlled per-repo +%% via the `oauth_exchange' option in the repo config (defaults to `true' for hex.pm). +%% +%% == Auth Context == +%% +%% Internally, authentication resolution tracks context via `auth_context()': +%% +%% +%% == Token Format == +%% +%% OAuth access tokens are automatically prefixed with `<<"Bearer ">>' when used +%% as `api_key' or `repo_key' in the config. +-module(hex_cli_auth). + +-export([ + with_api/4, + with_api/5, + with_repo/3, + with_repo/4, + resolve_api_auth/3, + resolve_repo_auth/2 +]). + +-export_type([ + callbacks/0, + permission/0, + auth_error/0, + auth_context/0, + repo_auth_config/0, + auth_prompt_reason/0, + opts/0 +]). + +%% 5 minute buffer before expiry +-define(EXPIRY_BUFFER_SECONDS, 300). + +%% Maximum OTP retry attempts +-define(MAX_OTP_RETRIES, 3). + +-type permission() :: read | write. + +-type callbacks() :: #{ + get_auth_config := fun((RepoName :: binary()) -> repo_auth_config() | undefined), + get_oauth_tokens := fun(() -> {ok, oauth_tokens()} | error), + persist_oauth_tokens := fun( + ( + Scope :: global | binary(), + AccessToken :: binary(), + RefreshToken :: binary() | undefined, + ExpiresAt :: integer() + ) -> ok + ), + prompt_otp := fun((Message :: binary()) -> {ok, OtpCode :: binary()} | cancelled), + should_authenticate := fun((Reason :: auth_prompt_reason()) -> boolean()), + get_client_id := fun(() -> binary()) +}. + +-type auth_prompt_reason() :: + no_credentials + | token_refresh_failed. + +-type repo_auth_config() :: #{ + api_key => binary(), + repo_key => binary(), + auth_key => binary(), + oauth_token => oauth_tokens() +}. + +-type oauth_tokens() :: #{ + access_token := binary(), + refresh_token => binary(), + expires_at := integer() +}. + +-type auth_error() :: + {auth_error, no_credentials} + | {auth_error, otp_cancelled} + | {auth_error, otp_max_retries} + | {auth_error, token_refresh_failed} + | {auth_error, device_auth_timeout} + | {auth_error, device_auth_denied} + | {auth_error, oauth_exchange_failed} + | {auth_error, term()}. + +-type auth_context() :: #{ + source => env | config | oauth, + has_refresh_token => boolean() +}. + +-type opts() :: [ + {optional, boolean()} + | {auth_inline, boolean()} + | {oauth_open_browser, boolean()} +]. + +%%==================================================================== +%% API functions +%%==================================================================== + +%% @doc +%% Execute a function with API authentication. +%% +%% Equivalent to `with_api(Callbacks, Permission, Config, Fun, [])'. +%% +%% @see with_api/5 +-spec with_api(callbacks(), permission(), hex_core:config(), fun((hex_core:config()) -> Result)) -> + Result | {error, auth_error()} +when + Result :: term(). +with_api(Callbacks, Permission, BaseConfig, Fun) -> + with_api(Callbacks, Permission, BaseConfig, Fun, []). + +%% @doc +%% Execute a function with API authentication. +%% +%% Resolves credentials in this order: +%%
    +%%
  1. Per-repo `api_key' from config (with optional OAuth exchange for hex.pm)
  2. +%%
  3. Parent repo `api_key' (for "hexpm:org" organizations)
  4. +%%
  5. Global OAuth token (refreshed if expired)
  6. +%%
  7. Device auth flow (when `should_authenticate' callback returns true)
  8. +%%
+%% +%% On 401 responses, handles OTP prompts and token refresh automatically. +%% +%% The repository name is taken from the config (`repo_name' or `repo_organization'). +%% +%% Options: +%% +%% +%% Example: +%% ``` +%% hex_cli_auth:with_api(Callbacks, write, Config, fun(C) -> +%% hex_api_release:publish(C, Tarball) +%% end, [{optional, false}, {auth_inline, true}]). +%% ''' +-spec with_api( + callbacks(), + permission(), + hex_core:config(), + fun((hex_core:config()) -> Result), + opts() +) -> + Result | {error, auth_error()} +when + Result :: term(). +with_api(Callbacks, Permission, BaseConfig, Fun, Opts) -> + Optional = proplists:get_value(optional, Opts, false), + AuthInline = proplists:get_value(auth_inline, Opts, true), + case resolve_api_auth(Callbacks, Permission, BaseConfig) of + {ok, ApiKey, AuthContext} -> + Config = BaseConfig#{api_key => ApiKey}, + execute_with_retry(Callbacks, Config, Fun, AuthContext, 0, undefined); + {error, no_auth} when Optional =:= true -> + %% Auth is optional, try without credentials first + execute_optional_with_retry(Callbacks, BaseConfig, Fun, Opts); + {error, no_auth} when AuthInline =:= true -> + %% No auth found, ask user if they want to authenticate + maybe_authenticate_and_retry(Callbacks, BaseConfig, Fun, no_credentials, Opts); + {error, no_auth} -> + %% auth_inline is false, just return error + {error, {auth_error, no_credentials}}; + {error, _} = Error -> + Error + end. + +%% @doc +%% Execute a function with repository authentication. +%% +%% Equivalent to `with_repo(Callbacks, Config, Fun, [])'. +%% +%% @see with_repo/4 +-spec with_repo(callbacks(), hex_core:config(), fun((hex_core:config()) -> Result)) -> + Result | {error, auth_error()} +when + Result :: term(). +with_repo(Callbacks, BaseConfig, Fun) -> + with_repo(Callbacks, BaseConfig, Fun, []). + +%% @doc +%% Execute a function with repository authentication. +%% +%% Resolves credentials in this order: +%%
    +%%
  1. `repo_key' in config - passthrough
  2. +%%
  3. `repo_key' from `get_auth_config' callback - passthrough
  4. +%%
  5. `auth_key' from `get_auth_config' when `trusted' is true and `oauth_exchange' is true - exchange for OAuth token
  6. +%%
  7. `auth_key' from `get_auth_config' when `trusted' is true - use directly
  8. +%%
  9. Global OAuth token from `get_oauth_tokens' callback
  10. +%%
  11. No auth when `optional' is true (with retry on 401)
  12. +%%
  13. Prompt via `should_authenticate' when `auth_inline' is true
  14. +%%
+%% +%% The repository name is taken from the config (`repo_name' or `repo_organization'). +%% +%% Options: +%% +%% +%% Example: +%% ``` +%% hex_cli_auth:with_repo(Callbacks, Config, fun(C) -> +%% hex_repo:get_tarball(C, <<"ecto">>, <<"3.0.0">>) +%% end). +%% ''' +-spec with_repo( + callbacks(), hex_core:config(), fun((hex_core:config()) -> Result), opts() +) -> + Result | {error, auth_error()} +when + Result :: term(). +with_repo(Callbacks, BaseConfig, Fun, Opts) -> + Optional = proplists:get_value(optional, Opts, true), + AuthInline = proplists:get_value(auth_inline, Opts, false), + case resolve_repo_auth(Callbacks, BaseConfig) of + {ok, RepoKey, _AuthContext} when is_binary(RepoKey) -> + Config = BaseConfig#{repo_key => RepoKey}, + Fun(Config); + no_auth when Optional =:= true -> + %% Auth is optional, try without credentials first + execute_optional_with_retry(Callbacks, BaseConfig, Fun, Opts); + no_auth when AuthInline =:= true -> + %% No auth found, ask user if they want to authenticate + maybe_authenticate_and_retry(Callbacks, BaseConfig, Fun, no_credentials, Opts); + no_auth -> + %% auth_inline is false, return error + {error, {auth_error, no_credentials}}; + {error, _} = Error -> + Error + end. + +%% @private +%% Extract repository name from config. +-spec repo_name(hex_core:config()) -> binary(). +repo_name(#{repo_name := Name, repo_organization := Org}) when is_binary(Name) and is_binary(Org) -> + <>; +repo_name(#{repo_name := Name}) when is_binary(Name) -> Name; +repo_name(_) -> + <<"hexpm">>. + +%% @private +%% Ask user if they want to authenticate, and if yes, initiate device auth. +maybe_authenticate_and_retry(Callbacks, BaseConfig, Fun, Reason, Opts) -> + case call_callback(Callbacks, should_authenticate, [Reason]) of + true -> + case device_auth(Callbacks, BaseConfig, <<"api repositories">>, Opts) of + {ok, #{access_token := Token}} -> + BearerToken = <<"Bearer ", Token/binary>>, + Config = BaseConfig#{api_key => BearerToken}, + AuthContext = #{source => oauth, has_refresh_token => true}, + execute_with_retry(Callbacks, Config, Fun, AuthContext, 0, undefined); + {error, _} = Error -> + Error + end; + false -> + {error, {auth_error, no_credentials}} + end. + +%% @private +%% Execute function without auth, but retry with auth if we get a 401. +execute_optional_with_retry(Callbacks, BaseConfig, Fun, Opts) -> + AuthInline = proplists:get_value(auth_inline, Opts, true), + case Fun(BaseConfig) of + {ok, {401, _Headers, _Body}} when AuthInline =:= true -> + %% Got 401, need auth - ask user if they want to authenticate + maybe_authenticate_and_retry(Callbacks, BaseConfig, Fun, no_credentials, Opts); + {ok, {401, _Headers, _Body}} -> + %% Got 401 but auth_inline is false, return error + {error, {auth_error, no_credentials}}; + Other -> + Other + end. + +%%==================================================================== +%% Internal functions - Device Auth +%%==================================================================== + +%% @private +%% Initiate OAuth device authorization flow. +%% Prompts user, optionally opens the browser for user authentication, +%% polls for token completion, and persists tokens via callback on success. +-spec device_auth(callbacks(), hex_core:config(), binary(), opts()) -> + {ok, oauth_tokens()} | {error, auth_error()}. +device_auth(Callbacks, Config, Scope, Opts) -> + ClientId = call_callback(Callbacks, get_client_id, []), + OpenBrowser = proplists:get_value(oauth_open_browser, Opts, true), + PromptUser = fun(VerificationUri, UserCode) -> + io:format("Open ~ts in your browser and enter code: ~ts~n", [VerificationUri, UserCode]) + end, + FlowOpts = [{open_browser, OpenBrowser}], + case hex_api_oauth:device_auth_flow(Config, ClientId, Scope, PromptUser, FlowOpts) of + {ok, #{access_token := AccessToken, refresh_token := RefreshToken, expires_at := ExpiresAt}} -> + ok = call_callback(Callbacks, persist_oauth_tokens, [ + global, AccessToken, RefreshToken, ExpiresAt + ]), + {ok, #{ + access_token => AccessToken, + refresh_token => RefreshToken, + expires_at => ExpiresAt + }}; + {error, timeout} -> + {error, {auth_error, device_auth_timeout}}; + {error, {access_denied, _Status, _Body}} -> + {error, {auth_error, device_auth_denied}}; + {error, {device_auth_failed, _Status, _Body} = Reason} -> + {error, {auth_error, Reason}}; + {error, {poll_failed, _Status, _Body} = Reason} -> + {error, {auth_error, Reason}}; + {error, Reason} -> + {error, {auth_error, Reason}} + end. + +%% @private +%% Check if a token is expired (within 5 minute buffer). +-spec is_token_expired(integer()) -> boolean(). +is_token_expired(ExpiresAt) -> + Now = erlang:system_time(second), + ExpiresAt - Now < ?EXPIRY_BUFFER_SECONDS. + +%%==================================================================== +%% Internal functions - Auth Resolution +%%==================================================================== + +%% @private +-spec resolve_api_auth(callbacks(), permission(), hex_core:config()) -> + {ok, binary(), auth_context()} | {error, no_auth} | {error, auth_error()}. +resolve_api_auth(_Callbacks, _Permission, #{api_key := ApiKey}) when is_binary(ApiKey) -> + %% api_key already in config, pass through directly + {ok, ApiKey, #{source => config, has_refresh_token => false}}; +resolve_api_auth(Callbacks, _Permission, Config) -> + RepoName = repo_name(Config), + %% 1. Check per-repo api_key + case call_callback(Callbacks, get_auth_config, [RepoName]) of + #{api_key := ApiKey} when is_binary(ApiKey) -> + {ok, ApiKey, #{source => config, has_refresh_token => false}}; + _ -> + %% 2. Check parent repo (for "hexpm:org" organizations) + case get_parent_repo_key(Callbacks, RepoName, api_key) of + {ok, ApiKey} -> + {ok, ApiKey, #{source => config, has_refresh_token => false}}; + error -> + %% 3. Try global OAuth token + resolve_oauth_token_with_context(Callbacks, Config) + end + end. + +%% @private +%% Resolve repo auth credentials in this order: +%% 0. repo_key in config => passthrough +%% 1. repo_key from get_auth_config => passthrough +%% 2. trusted + auth_key + oauth_exchange => exchange for OAuth token +%% 3. trusted + auth_key => use directly +%% 4. trusted + global OAuth tokens => use those +%% 5. Fallthrough to no_auth (handled by with_repo/4 for optional/auth_inline) +-spec resolve_repo_auth(callbacks(), hex_core:config()) -> + {ok, binary(), auth_context()} | no_auth | {error, auth_error()}. +resolve_repo_auth(_Callbacks, #{repo_key := RepoKey}) when is_binary(RepoKey) -> + %% repo_key already in config, pass through directly + {ok, RepoKey, #{source => config, has_refresh_token => false}}; +resolve_repo_auth(Callbacks, Config) -> + RepoName = repo_name(Config), + global:trans( + {{?MODULE, repo}, RepoName}, + fun() -> + do_resolve_repo_auth(Callbacks, RepoName, RepoName, Config) + end, + [], + infinity + ). + +do_resolve_repo_auth(Callbacks, RepoName, LookupRepo, Config) -> + Trusted = maps:get(trusted, Config, false), + OAuthExchange = maps:get(oauth_exchange, Config, false), + case call_callback(Callbacks, get_auth_config, [LookupRepo]) of + #{repo_key := RepoKey} when is_binary(RepoKey) -> + %% 1. repo_key from get_auth_config => passthrough + {ok, RepoKey, #{source => config, has_refresh_token => false}}; + #{oauth_token := OAuthToken, auth_key := AuthKey} when + is_binary(AuthKey) and OAuthExchange, Trusted + -> + %% 2. trusted + oauth_token + auth_key + oauth_exchange => use/refresh existing token + resolve_repo_oauth_token(Callbacks, RepoName, Config, AuthKey, OAuthToken); + #{auth_key := AuthKey} when is_binary(AuthKey) and OAuthExchange, Trusted -> + %% 3. trusted + auth_key + oauth_exchange => exchange for new OAuth token + exchange_for_oauth_token(Callbacks, RepoName, Config, AuthKey, <<"repositories">>); + #{auth_key := AuthKey} when is_binary(AuthKey), Trusted -> + %% 4. trusted + auth_key => use directly + {ok, AuthKey, #{source => config, has_refresh_token => false}}; + _ when Trusted -> + %% 5. Check parent repo (for "hexpm:org" organizations) + case binary:split(LookupRepo, <<":">>) of + [ParentName, _OrgName] -> + do_resolve_repo_auth(Callbacks, RepoName, ParentName, Config); + _ -> + %% 6. trusted + global OAuth tokens => use those + resolve_global_oauth_for_repo(Callbacks, Config) + end; + _ -> + %% 7. Not trusted, no auth + no_auth + end. + +%% @private +resolve_global_oauth_for_repo(Callbacks, Config) -> + case resolve_oauth_token_with_context(Callbacks, Config) of + {ok, Token, AuthContext} -> + {ok, Token, AuthContext}; + {error, no_auth} -> + no_auth; + {error, _} = Error -> + Error + end. + +%% @private +%% Resolve repo OAuth token: use if valid, re-exchange if expiring. +resolve_repo_oauth_token(Callbacks, RepoName, Config, AuthKey, #{ + access_token := AccessToken, expires_at := ExpiresAt +}) -> + case is_token_expired(ExpiresAt) of + false -> + %% Token is still valid, use it + BearerToken = <<"Bearer ", AccessToken/binary>>, + {ok, BearerToken, #{source => oauth, has_refresh_token => false}}; + true -> + %% Token expired, do a new exchange + exchange_for_oauth_token(Callbacks, RepoName, Config, AuthKey, <<"repositories">>) + end. + +%% @private +%% Exchange api_key/auth_key for OAuth token via client credentials grant. +%% Persists the token with the repo name for per-repo token storage. +exchange_for_oauth_token(Callbacks, RepoName, Config, AuthKey, Scope) -> + ClientId = call_callback(Callbacks, get_client_id, []), + ExchangeConfig = + case maps:get(oauth_exchange_url, Config, undefined) of + undefined -> Config; + OAuthUrl -> Config#{api_url => OAuthUrl} + end, + case hex_api_oauth:client_credentials_token(ExchangeConfig, ClientId, AuthKey, Scope) of + {ok, {200, _, #{<<"access_token">> := AccessToken, <<"expires_in">> := ExpiresIn}}} -> + ExpiresAt = erlang:system_time(second) + ExpiresIn, + ok = call_callback(Callbacks, persist_oauth_tokens, [ + RepoName, AccessToken, undefined, ExpiresAt + ]), + BearerToken = <<"Bearer ", AccessToken/binary>>, + {ok, BearerToken, #{source => oauth, has_refresh_token => false}}; + {ok, {_Status, _, _Body}} -> + {error, {auth_error, oauth_exchange_failed}}; + {error, _} -> + {error, {auth_error, oauth_exchange_failed}} + end. + +%% @private +get_parent_repo_key(Callbacks, RepoName, KeyType) -> + case binary:split(RepoName, <<":">>) of + [ParentName, _OrgName] -> + case call_callback(Callbacks, get_auth_config, [ParentName]) of + #{KeyType := Key} when is_binary(Key) -> + {ok, Key}; + _ -> + error + end; + _ -> + error + end. + +%% @private +%% Resolve OAuth token with global lock to prevent concurrent refresh attempts. +resolve_oauth_token_with_context(Callbacks, Config) -> + global:trans( + {{?MODULE, token_refresh}, self()}, + fun() -> + do_resolve_oauth_token_with_context(Callbacks, Config) + end, + [], + infinity + ). + +%% @private +do_resolve_oauth_token_with_context(Callbacks, Config) -> + case call_callback(Callbacks, get_oauth_tokens, []) of + {ok, #{access_token := AccessToken, expires_at := ExpiresAt} = Tokens} -> + HasRefreshToken = + maps:is_key(refresh_token, Tokens) andalso + is_binary(maps:get(refresh_token, Tokens)), + case is_token_expired(ExpiresAt) of + true -> + maybe_refresh_token_with_context(Callbacks, Config, Tokens); + false -> + BearerToken = <<"Bearer ", AccessToken/binary>>, + {ok, BearerToken, #{source => oauth, has_refresh_token => HasRefreshToken}} + end; + error -> + {error, no_auth} + end. + +%% @private +maybe_refresh_token_with_context(Callbacks, Config, #{refresh_token := RefreshToken}) when + is_binary(RefreshToken) +-> + ClientId = call_callback(Callbacks, get_client_id, []), + case hex_api_oauth:refresh_token(Config, ClientId, RefreshToken) of + {ok, {200, _, TokenResponse}} when is_map(TokenResponse) -> + #{ + <<"access_token">> := NewAccessToken, + <<"expires_in">> := ExpiresIn + } = TokenResponse, + NewRefreshToken = maps:get(<<"refresh_token">>, TokenResponse, RefreshToken), + ExpiresAt = erlang:system_time(second) + ExpiresIn, + ok = call_callback(Callbacks, persist_oauth_tokens, [ + global, NewAccessToken, NewRefreshToken, ExpiresAt + ]), + BearerToken = <<"Bearer ", NewAccessToken/binary>>, + HasRefreshToken = is_binary(NewRefreshToken), + {ok, BearerToken, #{source => oauth, has_refresh_token => HasRefreshToken}}; + {ok, {_Status, _, _Body}} -> + {error, {auth_error, token_refresh_failed}}; + {error, _Reason} -> + {error, {auth_error, token_refresh_failed}} + end; +maybe_refresh_token_with_context(_Callbacks, _Config, _Tokens) -> + {error, {auth_error, token_refresh_failed}}. + +%%==================================================================== +%% Internal functions - Retry Logic +%%==================================================================== + +%% @private +execute_with_retry(Callbacks, Config, Fun, AuthContext, OtpRetries, LastOtpError) -> + case Fun(Config) of + {error, otp_required} -> + handle_otp_retry( + Callbacks, Config, Fun, AuthContext, OtpRetries, <<"Enter OTP code:">> + ); + {error, invalid_totp} -> + handle_otp_retry( + Callbacks, + Config, + Fun, + AuthContext, + OtpRetries, + <<"Invalid OTP code. Please try again:">> + ); + {ok, {401, Headers, _Body}} = Response -> + case detect_auth_error(Headers) of + otp_required -> + handle_otp_retry( + Callbacks, Config, Fun, AuthContext, OtpRetries, <<"Enter OTP code:">> + ); + invalid_totp -> + Msg = + case LastOtpError of + invalid_totp -> <<"Invalid OTP code. Please try again:">>; + _ -> <<"Enter OTP code:">> + end, + handle_otp_retry(Callbacks, Config, Fun, AuthContext, OtpRetries, Msg); + token_expired -> + handle_token_refresh_retry(Callbacks, Config, Fun, AuthContext); + none -> + Response + end; + Other -> + Other + end. + +%% @private +handle_otp_retry(_Callbacks, _Config, _Fun, _AuthContext, OtpRetries, _Message) when + OtpRetries >= ?MAX_OTP_RETRIES +-> + {error, {auth_error, otp_max_retries}}; +handle_otp_retry(Callbacks, Config, Fun, AuthContext, OtpRetries, Message) -> + case call_callback(Callbacks, prompt_otp, [Message]) of + {ok, OtpCode} -> + NewConfig = Config#{api_otp => OtpCode}, + execute_with_retry( + Callbacks, NewConfig, Fun, AuthContext, OtpRetries + 1, invalid_totp + ); + cancelled -> + {error, {auth_error, otp_cancelled}} + end. + +%% @private +handle_token_refresh_retry(Callbacks, Config, Fun, AuthContext) -> + %% Only attempt refresh if we have a refresh token + case maps:get(has_refresh_token, AuthContext, false) of + true -> + case resolve_oauth_token_with_context(Callbacks, Config) of + {ok, NewBearerToken, NewAuthContext} -> + NewConfig = Config#{api_key => NewBearerToken}, + execute_with_retry(Callbacks, NewConfig, Fun, NewAuthContext, 0, undefined); + {error, _} -> + {error, {auth_error, token_refresh_failed}} + end; + false -> + {error, {auth_error, token_refresh_failed}} + end. + +%% @private +-spec detect_auth_error(hex_http:headers()) -> otp_required | invalid_totp | token_expired | none. +detect_auth_error(Headers) -> + case maps:get(<<"www-authenticate">>, Headers, undefined) of + undefined -> + none; + Value -> + parse_www_authenticate(Value) + end. + +%% @private +parse_www_authenticate(Value) when is_binary(Value) -> + case Value of + <<"Bearer realm=\"hex\", error=\"totp_required\"", _/binary>> -> + otp_required; + <<"Bearer realm=\"hex\", error=\"invalid_totp\"", _/binary>> -> + invalid_totp; + <<"Bearer realm=\"hex\", error=\"token_expired\"", _/binary>> -> + token_expired; + _ -> + none + end. + +%%==================================================================== +%% Internal functions - Utilities +%%==================================================================== + +%% @private +call_callback(Callbacks, Name, Args) -> + Fun = maps:get(Name, Callbacks), + erlang:apply(Fun, Args). diff --git a/src/hex_core.erl b/src/hex_core.erl index 1047fb0..221c463 100644 --- a/src/hex_core.erl +++ b/src/hex_core.erl @@ -111,7 +111,10 @@ tarball_max_size => pos_integer() | infinity, tarball_max_uncompressed_size => pos_integer() | infinity, docs_tarball_max_size => pos_integer() | infinity, - docs_tarball_max_uncompressed_size => pos_integer() | infinity + docs_tarball_max_uncompressed_size => pos_integer() | infinity, + trusted => boolean(), + oauth_exchange => boolean(), + oauth_exchange_url => binary() | undefined }. -spec default_config() -> config(). @@ -137,5 +140,8 @@ default_config() -> tarball_max_size => 16 * 1024 * 1024, tarball_max_uncompressed_size => 128 * 1024 * 1024, docs_tarball_max_size => 16 * 1024 * 1024, - docs_tarball_max_uncompressed_size => 128 * 1024 * 1024 + docs_tarball_max_uncompressed_size => 128 * 1024 * 1024, + trusted => true, + oauth_exchange => true, + oauth_exchange_url => undefined }. diff --git a/test/hex_api_SUITE.erl b/test/hex_api_SUITE.erl index 9503f5b..78412f3 100644 --- a/test/hex_api_SUITE.erl +++ b/test/hex_api_SUITE.erl @@ -29,6 +29,9 @@ all() -> auth_test, short_url_test, oauth_device_flow_test, + oauth_device_auth_flow_success_test, + oauth_device_auth_flow_denied_test, + oauth_device_auth_flow_timeout_test, oauth_refresh_token_test, oauth_revoke_test, oauth_client_credentials_test, @@ -145,6 +148,82 @@ oauth_device_flow_test(_Config) -> #{<<"error">> := <<"authorization_pending">>} = PollResponse, ok. +oauth_device_auth_flow_success_test(_Config) -> + ClientId = <<"cli">>, + Scope = <<"api:write">>, + Self = self(), + PromptUser = fun(VerificationUri, UserCode) -> + Self ! {prompt_called, VerificationUri, UserCode}, + ok + end, + + % Queue a success response for when polling happens + AccessToken = <<"test_access_token">>, + RefreshToken = <<"test_refresh_token">>, + SuccessPayload = #{ + <<"access_token">> => AccessToken, + <<"refresh_token">> => RefreshToken, + <<"token_type">> => <<"Bearer">>, + <<"expires_in">> => 3600 + }, + Headers = #{<<"content-type">> => <<"application/vnd.hex+erlang; charset=utf-8">>}, + Self ! + {hex_http_test, oauth_device_response, + {ok, {200, Headers, term_to_binary(SuccessPayload)}}}, + + {ok, Tokens} = hex_api_oauth:device_auth_flow(?CONFIG, ClientId, Scope, PromptUser), + + % Verify prompt was called + receive + {prompt_called, _Uri, _Code} -> ok + after 100 -> + error(prompt_not_called) + end, + + % Verify tokens + #{access_token := AccessToken, refresh_token := RefreshToken, expires_at := ExpiresAt} = Tokens, + ?assert(is_integer(ExpiresAt)), + ?assert(ExpiresAt > erlang:system_time(second)), + ok. + +oauth_device_auth_flow_denied_test(_Config) -> + ClientId = <<"cli">>, + Scope = <<"api:write">>, + Self = self(), + PromptUser = fun(_VerificationUri, _UserCode) -> ok end, + + % Queue an access denied response + ErrorPayload = #{ + <<"error">> => <<"access_denied">>, + <<"error_description">> => <<"User denied access">> + }, + Headers = #{<<"content-type">> => <<"application/vnd.hex+erlang; charset=utf-8">>}, + Self ! + {hex_http_test, oauth_device_response, {ok, {403, Headers, term_to_binary(ErrorPayload)}}}, + + {error, {access_denied, 403, _Body}} = hex_api_oauth:device_auth_flow( + ?CONFIG, ClientId, Scope, PromptUser + ), + ok. + +oauth_device_auth_flow_timeout_test(_Config) -> + ClientId = <<"cli">>, + Scope = <<"api:write">>, + Self = self(), + PromptUser = fun(_VerificationUri, _UserCode) -> ok end, + + % Queue an expired token response + ErrorPayload = #{ + <<"error">> => <<"expired_token">>, + <<"error_description">> => <<"Device code expired">> + }, + Headers = #{<<"content-type">> => <<"application/vnd.hex+erlang; charset=utf-8">>}, + Self ! + {hex_http_test, oauth_device_response, {ok, {400, Headers, term_to_binary(ErrorPayload)}}}, + + {error, timeout} = hex_api_oauth:device_auth_flow(?CONFIG, ClientId, Scope, PromptUser), + ok. + oauth_refresh_token_test(_Config) -> % Test token refresh ClientId = <<"cli">>, diff --git a/test/hex_cli_auth_SUITE.erl b/test/hex_cli_auth_SUITE.erl new file mode 100644 index 0000000..13099ab --- /dev/null +++ b/test/hex_cli_auth_SUITE.erl @@ -0,0 +1,684 @@ +-module(hex_cli_auth_SUITE). + +-compile([export_all]). + +-include_lib("eunit/include/eunit.hrl"). +-include_lib("common_test/include/ct.hrl"). + +-define(DEFAULT_HTTP_ADAPTER_CONFIG, #{profile => default}). + +-define(CONFIG, (hex_core:default_config())#{ + http_adapter => {hex_http_test, ?DEFAULT_HTTP_ADAPTER_CONFIG}, + http_user_agent_fragment => <<"(test)">>, + api_url => <<"https://api.test">>, + repo_url => <<"https://repo.test">>, + repo_name => <<"hexpm">>, + repo_public_key => ct:get_config({ssl_certs, test_pub}) +}). + +suite() -> + [{require, {ssl_certs, [test_pub, test_priv]}}]. + +all() -> + [ + %% resolve_api_auth tests + resolve_api_auth_config_passthrough_test, + resolve_api_auth_per_repo_test, + resolve_api_auth_parent_repo_test, + resolve_api_auth_oauth_test, + resolve_api_auth_oauth_expired_refresh_test, + resolve_api_auth_oauth_no_refresh_token_test, + resolve_api_auth_no_auth_test, + + %% resolve_repo_auth tests - trusted vs untrusted + resolve_repo_auth_config_passthrough_test, + resolve_repo_auth_callback_repo_key_test, + resolve_repo_auth_trusted_auth_key_test, + resolve_repo_auth_untrusted_ignores_auth_key_test, + resolve_repo_auth_oauth_fallback_test, + resolve_repo_auth_no_auth_test, + + %% resolve_repo_auth tests - token exchange + resolve_repo_auth_oauth_exchange_new_token_test, + resolve_repo_auth_oauth_exchange_existing_valid_test, + resolve_repo_auth_oauth_exchange_existing_expired_test, + resolve_repo_auth_parent_repo_auth_key_test, + + %% with_api tests - OTP handling + with_api_otp_required_test, + with_api_otp_invalid_retry_test, + with_api_otp_cancelled_test, + with_api_otp_max_retries_test, + + %% with_api tests - token refresh on 401 + with_api_token_expired_refresh_test, + + %% with_api tests - wrapper behavior + with_api_optional_test, + with_api_auth_inline_test, + with_api_device_auth_test, + + %% with_repo tests - wrapper behavior + with_repo_optional_test, + with_repo_trusted_with_auth_test + ]. + +%%==================================================================== +%% Test Cases - resolve_api_auth +%%==================================================================== + +resolve_api_auth_config_passthrough_test(_Config) -> + %% When api_key is already in config, it should be used directly + Callbacks = make_callbacks(#{}), + ConfigWithKey = ?CONFIG#{api_key => <<"config_api_key">>}, + + {ok, ApiKey, AuthContext} = hex_cli_auth:resolve_api_auth(Callbacks, read, ConfigWithKey), + ?assertEqual(<<"config_api_key">>, ApiKey), + ?assertEqual(#{source => config, has_refresh_token => false}, AuthContext), + ok. + +resolve_api_auth_per_repo_test(_Config) -> + %% Test per-repo api_key from callback + Callbacks = make_callbacks(#{ + auth_config => #{<<"hexpm">> => #{api_key => <<"repo_api_key">>}} + }), + + {ok, ApiKey, AuthContext} = hex_cli_auth:resolve_api_auth(Callbacks, write, ?CONFIG), + ?assertEqual(<<"repo_api_key">>, ApiKey), + ?assertEqual(#{source => config, has_refresh_token => false}, AuthContext), + ok. + +resolve_api_auth_parent_repo_test(_Config) -> + %% Test parent repo fallback for "hexpm:org" repos + Callbacks = make_callbacks(#{ + auth_config => #{<<"hexpm">> => #{api_key => <<"parent_api_key">>}} + }), + + {ok, ApiKey, _} = hex_cli_auth:resolve_api_auth( + Callbacks, write, ?CONFIG#{repo_name => <<"hexpm:myorg">>} + ), + ?assertEqual(<<"parent_api_key">>, ApiKey), + ok. + +resolve_api_auth_oauth_test(_Config) -> + %% Test OAuth token fallback with valid token + Now = erlang:system_time(second), + Callbacks = make_callbacks(#{ + oauth_tokens => + {ok, #{ + access_token => <<"oauth_token">>, + refresh_token => <<"refresh_token">>, + expires_at => Now + 3600 + }} + }), + + {ok, ApiKey, AuthContext} = hex_cli_auth:resolve_api_auth(Callbacks, read, ?CONFIG), + ?assertEqual(<<"Bearer oauth_token">>, ApiKey), + ?assertEqual(#{source => oauth, has_refresh_token => true}, AuthContext), + ok. + +resolve_api_auth_oauth_expired_refresh_test(_Config) -> + %% Test OAuth token refresh when expired + Now = erlang:system_time(second), + Self = self(), + Callbacks = make_callbacks(#{ + oauth_tokens => + {ok, #{ + access_token => <<"expired_token">>, + refresh_token => <<"refresh_token">>, + %% Expired + expires_at => Now - 100 + }}, + persist_oauth_tokens => fun(Scope, Access, Refresh, Expires) -> + Self ! {persisted, Scope, Access, Refresh, Expires}, + ok + end + }), + + {ok, ApiKey, AuthContext} = hex_cli_auth:resolve_api_auth(Callbacks, read, ?CONFIG), + %% Should have refreshed and got a new token + ?assertMatch(<<"Bearer ", _/binary>>, ApiKey), + ?assertEqual(#{source => oauth, has_refresh_token => true}, AuthContext), + + %% Verify token was persisted + receive + {persisted, global, _NewAccess, _NewRefresh, _NewExpires} -> ok + after 100 -> + error(token_not_persisted) + end, + ok. + +resolve_api_auth_oauth_no_refresh_token_test(_Config) -> + %% Test OAuth token without refresh token + Now = erlang:system_time(second), + Callbacks = make_callbacks(#{ + oauth_tokens => + {ok, #{ + access_token => <<"oauth_token">>, + expires_at => Now + 3600 + }} + }), + + {ok, ApiKey, AuthContext} = hex_cli_auth:resolve_api_auth(Callbacks, read, ?CONFIG), + ?assertEqual(<<"Bearer oauth_token">>, ApiKey), + ?assertEqual(#{source => oauth, has_refresh_token => false}, AuthContext), + ok. + +resolve_api_auth_no_auth_test(_Config) -> + %% Test when no auth is available + Callbacks = make_callbacks(#{ + auth_config => #{}, + oauth_tokens => error + }), + + Result = hex_cli_auth:resolve_api_auth(Callbacks, read, ?CONFIG), + ?assertEqual({error, no_auth}, Result), + ok. + +%%==================================================================== +%% Test Cases - resolve_repo_auth +%%==================================================================== + +resolve_repo_auth_config_passthrough_test(_Config) -> + %% When repo_key is already in config, it should be used directly + Callbacks = make_callbacks(#{}), + ConfigWithKey = ?CONFIG#{repo_key => <<"config_repo_key">>}, + + {ok, RepoKey, AuthContext} = hex_cli_auth:resolve_repo_auth(Callbacks, ConfigWithKey), + ?assertEqual(<<"config_repo_key">>, RepoKey), + ?assertEqual(#{source => config, has_refresh_token => false}, AuthContext), + ok. + +resolve_repo_auth_callback_repo_key_test(_Config) -> + %% Test repo_key from get_auth_config callback + Callbacks = make_callbacks(#{ + auth_config => #{<<"hexpm">> => #{repo_key => <<"callback_repo_key">>}} + }), + + {ok, RepoKey, _} = hex_cli_auth:resolve_repo_auth(Callbacks, ?CONFIG#{trusted => true}), + ?assertEqual(<<"callback_repo_key">>, RepoKey), + ok. + +resolve_repo_auth_trusted_auth_key_test(_Config) -> + %% Test trusted + auth_key (no oauth_exchange) uses auth_key directly + Callbacks = make_callbacks(#{ + auth_config => #{<<"hexpm">> => #{auth_key => <<"auth_key_value">>}} + }), + + Config = ?CONFIG#{trusted => true, oauth_exchange => false}, + {ok, RepoKey, _} = hex_cli_auth:resolve_repo_auth(Callbacks, Config), + ?assertEqual(<<"auth_key_value">>, RepoKey), + ok. + +resolve_repo_auth_untrusted_ignores_auth_key_test(_Config) -> + %% Test untrusted config ignores auth_key even when present + Callbacks = make_callbacks(#{ + auth_config => #{<<"hexpm">> => #{auth_key => <<"auth_key_value">>}}, + oauth_tokens => error + }), + + Config = ?CONFIG#{trusted => false}, + Result = hex_cli_auth:resolve_repo_auth(Callbacks, Config), + ?assertEqual(no_auth, Result), + ok. + +resolve_repo_auth_oauth_fallback_test(_Config) -> + %% Test fallback to global OAuth when trusted but no auth_key + Now = erlang:system_time(second), + Callbacks = make_callbacks(#{ + auth_config => #{}, + oauth_tokens => + {ok, #{ + access_token => <<"global_oauth">>, + expires_at => Now + 3600 + }} + }), + + Config = ?CONFIG#{trusted => true}, + {ok, RepoKey, _} = hex_cli_auth:resolve_repo_auth(Callbacks, Config), + ?assertEqual(<<"Bearer global_oauth">>, RepoKey), + ok. + +resolve_repo_auth_no_auth_test(_Config) -> + %% Test no_auth when untrusted and no credentials + Callbacks = make_callbacks(#{ + auth_config => #{}, + oauth_tokens => error + }), + + Config = ?CONFIG#{trusted => false}, + Result = hex_cli_auth:resolve_repo_auth(Callbacks, Config), + ?assertEqual(no_auth, Result), + ok. + +resolve_repo_auth_oauth_exchange_new_token_test(_Config) -> + %% Test oauth_exchange with auth_key but no existing oauth_token + Self = self(), + Callbacks = make_callbacks(#{ + auth_config => #{<<"hexpm">> => #{auth_key => <<"my_auth_key">>}}, + persist_oauth_tokens => fun(Scope, Access, _Refresh, Expires) -> + Self ! {persisted, Scope, Access, Expires}, + ok + end + }), + + Config = ?CONFIG#{trusted => true, oauth_exchange => true}, + {ok, RepoKey, AuthContext} = hex_cli_auth:resolve_repo_auth(Callbacks, Config), + ?assertMatch(<<"Bearer ", _/binary>>, RepoKey), + ?assertEqual(#{source => oauth, has_refresh_token => false}, AuthContext), + + %% Verify token was persisted with repo name + receive + {persisted, <<"hexpm">>, _AccessToken, _ExpiresAt} -> ok + after 100 -> + error(token_not_persisted) + end, + ok. + +resolve_repo_auth_oauth_exchange_existing_valid_test(_Config) -> + %% Test oauth_exchange with existing valid oauth_token reuses it + Now = erlang:system_time(second), + Callbacks = make_callbacks(#{ + auth_config => #{ + <<"hexpm">> => #{ + auth_key => <<"my_auth_key">>, + oauth_token => #{ + access_token => <<"existing_token">>, + expires_at => Now + 3600 + } + } + } + }), + + Config = ?CONFIG#{trusted => true, oauth_exchange => true}, + {ok, RepoKey, _} = hex_cli_auth:resolve_repo_auth(Callbacks, Config), + ?assertEqual(<<"Bearer existing_token">>, RepoKey), + ok. + +resolve_repo_auth_oauth_exchange_existing_expired_test(_Config) -> + %% Test oauth_exchange with expired oauth_token re-exchanges + Now = erlang:system_time(second), + Self = self(), + Callbacks = make_callbacks(#{ + auth_config => #{ + <<"hexpm">> => #{ + auth_key => <<"my_auth_key">>, + oauth_token => #{ + access_token => <<"expired_token">>, + expires_at => Now - 100 + } + } + }, + persist_oauth_tokens => fun(Scope, Access, _Refresh, Expires) -> + Self ! {persisted, Scope, Access, Expires}, + ok + end + }), + + Config = ?CONFIG#{trusted => true, oauth_exchange => true}, + {ok, RepoKey, _} = hex_cli_auth:resolve_repo_auth(Callbacks, Config), + ?assertMatch(<<"Bearer ", _/binary>>, RepoKey), + ?assertNotEqual(<<"Bearer expired_token">>, RepoKey), + + %% Verify new token was persisted + receive + {persisted, <<"hexpm">>, _NewAccessToken, _ExpiresAt} -> ok + after 100 -> + error(token_not_persisted) + end, + ok. + +resolve_repo_auth_parent_repo_auth_key_test(_Config) -> + %% Test trusted org repo falls back to parent repo auth_key + Callbacks = make_callbacks(#{ + auth_config => #{<<"hexpm">> => #{auth_key => <<"parent_auth_key">>}} + }), + + Config = ?CONFIG#{repo_name => <<"hexpm:myorg">>, trusted => true, oauth_exchange => false}, + {ok, RepoKey, _} = hex_cli_auth:resolve_repo_auth(Callbacks, Config), + ?assertEqual(<<"parent_auth_key">>, RepoKey), + ok. + +%%==================================================================== +%% Test Cases - with_api OTP handling +%%==================================================================== + +with_api_otp_required_test(_Config) -> + %% Test OTP prompt when server returns otp_required + Now = erlang:system_time(second), + Callbacks = make_callbacks(#{ + oauth_tokens => + {ok, #{ + access_token => <<"token">>, + expires_at => Now + 3600 + }}, + prompt_otp => fun(_Msg) -> {ok, <<"123456">>} end + }), + + CallCount = counters:new(1, []), + Result = hex_cli_auth:with_api( + Callbacks, + write, + ?CONFIG, + fun(Config) -> + Count = counters:get(CallCount, 1), + counters:add(CallCount, 1, 1), + case Count of + 0 -> + %% First call: return 401 with otp_required + {ok, + {401, + #{ + <<"www-authenticate">> => + <<"Bearer realm=\"hex\", error=\"totp_required\"">> + }, + <<>>}}; + _ -> + %% Second call: should have OTP, return success + ?assertEqual(<<"123456">>, maps:get(api_otp, Config)), + {ok, {200, #{}, <<"success">>}} + end + end + ), + ?assertEqual({ok, {200, #{}, <<"success">>}}, Result), + ?assertEqual(2, counters:get(CallCount, 1)), + ok. + +with_api_otp_invalid_retry_test(_Config) -> + %% Test OTP retry when server returns invalid_totp + Now = erlang:system_time(second), + OtpAttempts = counters:new(1, []), + Callbacks = make_callbacks(#{ + oauth_tokens => + {ok, #{ + access_token => <<"token">>, + expires_at => Now + 3600 + }}, + prompt_otp => fun(_Msg) -> + Count = counters:get(OtpAttempts, 1), + counters:add(OtpAttempts, 1, 1), + case Count of + 0 -> {ok, <<"wrong_otp">>}; + _ -> {ok, <<"correct_otp">>} + end + end + }), + + CallCount = counters:new(1, []), + Result = hex_cli_auth:with_api( + Callbacks, + write, + ?CONFIG, + fun(Config) -> + Count = counters:get(CallCount, 1), + counters:add(CallCount, 1, 1), + case Count of + 0 -> + {ok, + {401, + #{ + <<"www-authenticate">> => + <<"Bearer realm=\"hex\", error=\"totp_required\"">> + }, + <<>>}}; + 1 -> + ?assertEqual(<<"wrong_otp">>, maps:get(api_otp, Config)), + {ok, + {401, + #{ + <<"www-authenticate">> => + <<"Bearer realm=\"hex\", error=\"invalid_totp\"">> + }, + <<>>}}; + _ -> + ?assertEqual(<<"correct_otp">>, maps:get(api_otp, Config)), + {ok, {200, #{}, <<"success">>}} + end + end + ), + ?assertEqual({ok, {200, #{}, <<"success">>}}, Result), + ?assertEqual(3, counters:get(CallCount, 1)), + ok. + +with_api_otp_cancelled_test(_Config) -> + %% Test OTP cancellation returns error + Now = erlang:system_time(second), + Callbacks = make_callbacks(#{ + oauth_tokens => + {ok, #{ + access_token => <<"token">>, + expires_at => Now + 3600 + }}, + prompt_otp => fun(_Msg) -> cancelled end + }), + + Result = hex_cli_auth:with_api( + Callbacks, + write, + ?CONFIG, + fun(_Cfg) -> + {ok, + {401, + #{ + <<"www-authenticate">> => + <<"Bearer realm=\"hex\", error=\"totp_required\"">> + }, + <<>>}} + end + ), + ?assertEqual({error, {auth_error, otp_cancelled}}, Result), + ok. + +with_api_otp_max_retries_test(_Config) -> + %% Test OTP max retries returns error + Now = erlang:system_time(second), + Callbacks = make_callbacks(#{ + oauth_tokens => + {ok, #{ + access_token => <<"token">>, + expires_at => Now + 3600 + }}, + prompt_otp => fun(_Msg) -> {ok, <<"wrong_otp">>} end + }), + + Result = hex_cli_auth:with_api( + Callbacks, + write, + ?CONFIG, + fun(_Cfg) -> + {ok, + {401, + #{<<"www-authenticate">> => <<"Bearer realm=\"hex\", error=\"invalid_totp\"">>}, + <<>>}} + end + ), + ?assertEqual({error, {auth_error, otp_max_retries}}, Result), + ok. + +%%==================================================================== +%% Test Cases - with_api token refresh on 401 +%%==================================================================== + +with_api_token_expired_refresh_test(_Config) -> + %% Test token refresh when server returns token_expired + Now = erlang:system_time(second), + Self = self(), + Callbacks = make_callbacks(#{ + oauth_tokens => + {ok, #{ + access_token => <<"initial_token">>, + refresh_token => <<"refresh_token">>, + %% Within EXPIRY_BUFFER_SECONDS, will trigger refresh + expires_at => Now + 100 + }}, + persist_oauth_tokens => fun(_Scope, Access, Refresh, Expires) -> + Self ! {persisted, Access, Refresh, Expires}, + ok + end + }), + + CallCount = counters:new(1, []), + Result = hex_cli_auth:with_api( + Callbacks, + write, + ?CONFIG, + fun(Config) -> + Count = counters:get(CallCount, 1), + counters:add(CallCount, 1, 1), + ApiKey = maps:get(api_key, Config), + case Count of + 0 -> + %% First call gets refreshed token (initial was within expiry buffer) + ?assertMatch(<<"Bearer ", _/binary>>, ApiKey), + ?assertNotEqual(<<"Bearer initial_token">>, ApiKey), + {ok, + {401, + #{ + <<"www-authenticate">> => + <<"Bearer realm=\"hex\", error=\"token_expired\"">> + }, + <<>>}}; + _ -> + %% Second refresh after 401 + ?assertMatch(<<"Bearer ", _/binary>>, ApiKey), + {ok, {200, #{}, <<"success">>}} + end + end + ), + ?assertEqual({ok, {200, #{}, <<"success">>}}, Result), + ?assertEqual(2, counters:get(CallCount, 1)), + + %% Verify tokens were persisted (at least once for initial refresh) + receive + {persisted, _NewAccess, _NewRefresh, _NewExpires} -> ok + after 100 -> + error(token_not_persisted) + end, + ok. + +%%==================================================================== +%% Test Cases - with_api (wrapper behavior) +%%==================================================================== + +with_api_optional_test(_Config) -> + %% Test optional => true allows requests without auth + Callbacks = make_callbacks(#{oauth_tokens => error}), + + %% Function is called without api_key + Result = hex_cli_auth:with_api( + Callbacks, + read, + ?CONFIG, + fun(Config) -> maps:get(api_key, Config, undefined) end, + [{optional, true}] + ), + ?assertEqual(undefined, Result), + ok. + +with_api_auth_inline_test(_Config) -> + %% Test auth_inline => false returns error instead of prompting + Callbacks = make_callbacks(#{oauth_tokens => error}), + + Result = hex_cli_auth:with_api( + Callbacks, + read, + ?CONFIG, + fun(_) -> error(should_not_be_called) end, + [{optional, false}, {auth_inline, false}] + ), + ?assertEqual({error, {auth_error, no_credentials}}, Result), + ok. + +with_api_device_auth_test(_Config) -> + %% Test device auth flow when should_authenticate returns true + Self = self(), + Callbacks = make_callbacks(#{ + oauth_tokens => error, + should_authenticate => fun(no_credentials) -> true end, + persist_oauth_tokens => fun(Scope, Access, Refresh, Expires) -> + Self ! {persisted, Scope, Access, Refresh, Expires}, + ok + end + }), + + %% Queue success response for device auth polling + AccessToken = <<"device_token">>, + RefreshToken = <<"device_refresh">>, + SuccessPayload = #{ + <<"access_token">> => AccessToken, + <<"refresh_token">> => RefreshToken, + <<"token_type">> => <<"Bearer">>, + <<"expires_in">> => 3600 + }, + Headers = #{<<"content-type">> => <<"application/vnd.hex+erlang; charset=utf-8">>}, + Self ! + {hex_http_test, oauth_device_response, + {ok, {200, Headers, term_to_binary(SuccessPayload)}}}, + + Result = hex_cli_auth:with_api( + Callbacks, + write, + ?CONFIG, + fun(Config) -> maps:get(api_key, Config) end, + [{oauth_open_browser, false}] + ), + ?assertEqual(<<"Bearer device_token">>, Result), + + %% Verify token was persisted + receive + {persisted, global, AccessToken, RefreshToken, _} -> ok + after 100 -> + error(token_not_persisted) + end, + ok. + +%%==================================================================== +%% Test Cases - with_repo (wrapper behavior) +%%==================================================================== + +with_repo_optional_test(_Config) -> + %% Test that with_repo with optional => true (default) proceeds without auth + Callbacks = make_callbacks(#{oauth_tokens => error}), + + Result = hex_cli_auth:with_repo( + Callbacks, + ?CONFIG#{trusted => false}, + fun(Config) -> maps:get(repo_key, Config, undefined) end + ), + ?assertEqual(undefined, Result), + ok. + +with_repo_trusted_with_auth_test(_Config) -> + %% Test with_repo with trusted config and auth_key + Callbacks = make_callbacks(#{ + auth_config => #{<<"hexpm">> => #{auth_key => <<"my_auth_key">>}} + }), + + Result = hex_cli_auth:with_repo( + Callbacks, + ?CONFIG#{trusted => true, oauth_exchange => false}, + fun(Config) -> maps:get(repo_key, Config) end + ), + ?assertEqual(<<"my_auth_key">>, Result), + ok. + +%%==================================================================== +%% Helper Functions +%%==================================================================== + +make_callbacks(Opts) -> + AuthConfig = maps:get(auth_config, Opts, #{}), + PromptOtp = maps:get(prompt_otp, Opts, fun(_) -> cancelled end), + ShouldAuthenticate = maps:get(should_authenticate, Opts, fun(_) -> false end), + PersistFn = maps:get(persist_oauth_tokens, Opts, fun(_, _, _, _) -> ok end), + DefaultGetOAuthTokens = fun() -> maps:get(oauth_tokens, Opts, error) end, + GetOAuthTokensFn = maps:get(get_oauth_tokens, Opts, DefaultGetOAuthTokens), + + #{ + get_auth_config => fun(RepoName) -> maps:get(RepoName, AuthConfig, undefined) end, + get_oauth_tokens => GetOAuthTokensFn, + persist_oauth_tokens => PersistFn, + prompt_otp => PromptOtp, + should_authenticate => ShouldAuthenticate, + get_client_id => fun() -> <<"test_client">> end + }. diff --git a/test/support/hex_http_test.erl b/test/support/hex_http_test.erl index ecd6c1d..e92d205 100644 --- a/test/support/hex_http_test.erl +++ b/test/support/hex_http_test.erl @@ -306,7 +306,7 @@ fixture(post, <>, _, {_, Body}) -> <<"verification_uri">> => <<"https://hex.pm/oauth/device">>, <<"verification_uri_complete">> => <<"https://hex.pm/oauth/device?user_code=", UserCode/binary>>, <<"expires_in">> => 600, - <<"interval">> => 5 + <<"interval">> => 0 }, {ok, {200, api_headers(), term_to_binary(Payload)}}; @@ -314,12 +314,17 @@ fixture(post, <>, _, {_, Body}) -> DecodedBody = binary_to_term(Body), case maps:get(<<"grant_type">>, DecodedBody) of <<"urn:ietf:params:oauth:grant-type:device_code">> -> - % Simulate pending authorization - ErrorPayload = #{ - <<"error">> => <<"authorization_pending">>, - <<"error_description">> => <<"Authorization pending">> - }, - {ok, {400, api_headers(), term_to_binary(ErrorPayload)}}; + receive + {hex_http_test, oauth_device_response, Response} -> + Response + after 0 -> + % Default: simulate pending authorization + ErrorPayload = #{ + <<"error">> => <<"authorization_pending">>, + <<"error_description">> => <<"Authorization pending">> + }, + {ok, {400, api_headers(), term_to_binary(ErrorPayload)}} + end; <<"urn:ietf:params:oauth:grant-type:token-exchange">> -> % Simulate successful token exchange AccessToken = base64:encode(crypto:strong_rand_bytes(32)),