Skip to content

Commit ce84327

Browse files
NRHelmiquinnj
andauthored
Cache access token on disk (#85)
* cache access token on disk * cleanup * fix tests * Update src/rest.jl Co-authored-by: Jacob Quinn <quinn.jacobd@gmail.com> * addressing PR comments * Update src/rest.jl Co-authored-by: Jacob Quinn <quinn.jacobd@gmail.com> * fix read access token issue && add token caching test * init rai config folder * missing shell env Co-authored-by: Jacob Quinn <quinn.jacobd@gmail.com>
1 parent 85c2a17 commit ce84327

File tree

4 files changed

+97
-16
lines changed

4 files changed

+97
-16
lines changed

.github/actions/test/action.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ runs:
3838
- uses: julia-actions/cache@v1
3939
- uses: julia-actions/julia-buildpkg@v1
4040

41+
# this folder is required to test access token caching
42+
- name: Init rai config folder
43+
run: mkdir -p ~/.rai
44+
shell: bash
45+
4146
- name: Test
4247
uses: julia-actions/julia-runtest@v1
4348
env:

src/creds.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
using Dates: DateTime, Second
15+
using Dates: datetime2unix
1616

1717
"""
1818
AccessToken
@@ -23,21 +23,21 @@ but we can still print the secret if needed:
2323
2424
Example:
2525
```
26-
access_token.token
26+
t.access_token
2727
````
2828
"""
2929
struct AccessToken
30-
token::String
30+
access_token::String
3131
scope::String
3232
expires_in::Int # seconds
33-
created_on::DateTime
33+
created_on::Float64
3434
end
3535

3636
function Base.show(io::IO, t::AccessToken)
3737
print(
3838
io,
3939
"(",
40-
isempty(t.token) ? "" : "$(t.token[1:3])...",
40+
isempty(t.access_token) ? "" : "$(t.access_token[1:3])...",
4141
", ", t.scope,
4242
", ", t.expires_in,
4343
", ", t.created_on,
@@ -46,8 +46,8 @@ function Base.show(io::IO, t::AccessToken)
4646
end
4747

4848
function isexpired(access_token::AccessToken)::Bool
49-
expires_on = access_token.created_on + Second(access_token.expires_in)
50-
return expires_on - Second(5) < now() # anticipate token expiration by 5 seconds
49+
expires_on = access_token.created_on + access_token.expires_in
50+
return expires_on - 5 < datetime2unix(now()) # anticipate token expiration by 5 seconds
5151
end
5252

5353
abstract type Credentials end
@@ -78,8 +78,8 @@ function Base.show(io::IO, c::ClientCredentials)
7878
io,
7979
"(",
8080
c.client_id,
81-
c.client_secret == nothing ? "" : ", $(c.client_secret[1:3])...",
82-
c.access_token == nothing ? "" : ", $(c.access_token)",
81+
c.client_secret === nothing ? "" : ", $(c.client_secret[1:3])...",
82+
c.access_token === nothing ? "" : ", $(c.access_token)",
8383
", ",
8484
c.client_credentials_url,
8585
")"

src/rest.jl

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Low level HTTP interface to the RAI REST API. Handles authentication of
1616
# requests and other protocol level details.
1717

18-
using Dates: now
18+
using Dates: now, datetime2unix
1919
import HTTP
2020
import JSON3
2121

@@ -81,7 +81,64 @@ function get_access_token(ctx::Context, creds::ClientCredentials)::AccessToken
8181
opts = (redirect = false, retry_non_idempotent = true, connect_timeout = 30, readtimeout = 30, keepalive = true)
8282
rsp = HTTP.request("POST", url, h, body; opts...)
8383
data = JSON3.read(rsp.body)
84-
return AccessToken(data.access_token, data.scope, data.expires_in, now())
84+
return AccessToken(data.access_token, data.scope, data.expires_in, datetime2unix(now()))
85+
end
86+
87+
# cache name
88+
function _cache_file()
89+
return joinpath(homedir(), ".rai", "tokens.json")
90+
end
91+
92+
# read oauth cache
93+
function _read_cache()
94+
try
95+
if isfile(_cache_file())
96+
return copy(JSON3.read(read(_cache_file())))
97+
else
98+
return nothing
99+
end
100+
catch e
101+
@warn e
102+
return nothing
103+
end
104+
end
105+
106+
# Read access token from cache
107+
function _read_token_cache(creds::ClientCredentials)
108+
try
109+
cache = _read_cache()
110+
cache === nothing && return nothing
111+
112+
if haskey(cache, Symbol(creds.client_id))
113+
access_token = cache[Symbol(creds.client_id)]
114+
return AccessToken(
115+
access_token[:access_token],
116+
access_token[:scope],
117+
access_token[:expires_in],
118+
access_token[:created_on],
119+
)
120+
else
121+
return nothing
122+
end
123+
catch e
124+
@warn e
125+
return nothing
126+
end
127+
end
128+
129+
# Write access token to cache
130+
function _write_token_cache(creds::ClientCredentials)
131+
try
132+
cache = _read_cache()
133+
if cache === nothing
134+
cache = Dict(creds.client_id => creds.access_token)
135+
else
136+
cache[Symbol(creds.client_id)] = creds.access_token
137+
end
138+
write(_cache_file(), JSON3.write(cache))
139+
catch e
140+
@warn e
141+
end
85142
end
86143

87144
function _get_client_credentials_url(creds::ClientCredentials)
@@ -102,13 +159,19 @@ function _authenticate!(
102159
headers,
103160
)::Nothing
104161
if isnothing(creds.access_token)
105-
creds.access_token = get_access_token(ctx, creds)
162+
creds.access_token = _read_token_cache(creds)
163+
if isnothing(creds.access_token)
164+
creds.access_token = get_access_token(ctx, creds)
165+
_write_token_cache(creds)
166+
end
106167
end
107168

108169
if isexpired(creds.access_token)
109170
creds.access_token = get_access_token(ctx, creds)
171+
_write_token_cache(creds)
110172
end
111-
push!(headers, "Authorization" => "Bearer $(creds.access_token.token)")
173+
174+
push!(headers, "Authorization" => "Bearer $(creds.access_token.access_token)")
112175
return nothing
113176
end
114177

test/api.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using JSON3
55
using Mocking
66
using Dates
77
using RAI.protocol
8-
using RAI: _poll_with_specified_overhead
8+
using RAI: _poll_with_specified_overhead, _write_token_cache, _read_token_cache
99

1010
using RAI: TransactionResponse
1111

@@ -287,11 +287,24 @@ end
287287
end
288288

289289
@testset "hide client secrets in repl" begin
290-
access_token = AccessToken("abc_token", "run:transaction", 3600, DateTime("2022-08-12T17:49:51.365"))
290+
access_token = AccessToken("abc_token", "run:transaction", 3600, datetime2unix(DateTime("2022-08-12T17:49:51.365")))
291291
creds = ClientCredentials("client_id", "xyz_client_secret", "https://login.relationalai.com/oauth/token")
292292
creds.access_token = access_token
293293

294294
io = IOBuffer()
295295
show(io, creds)
296-
@test String(take!(io)) === "(client_id, xyz..., (abc..., run:transaction, 3600, 2022-08-12T17:49:51.365), https://login.relationalai.com/oauth/token)"
296+
@test String(take!(io)) === "(client_id, xyz..., (abc..., run:transaction, 3600, 1.660326591365e9), https://login.relationalai.com/oauth/token)"
297+
end
298+
299+
@testset "read write access token to cache" begin
300+
access_token = AccessToken("abc_token", "run:transaction", 3600, datetime2unix(DateTime("2022-08-12T17:49:51.365")))
301+
creds = ClientCredentials("client_id", "xyz_client_secret", "https://login.relationalai.com/oauth/token")
302+
creds.access_token = access_token
303+
304+
# write/read access token to cache
305+
_write_token_cache(creds)
306+
cached_token = _read_token_cache(creds)
307+
308+
# check if access token is serialized/de-serialized correctly from the cache
309+
@test cached_token === access_token
297310
end

0 commit comments

Comments
 (0)