Skip to content

Commit

Permalink
api client reauthentication
Browse files Browse the repository at this point in the history
  • Loading branch information
ennasus4sun committed Aug 12, 2024
1 parent eeeadb0 commit 8709091
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 108 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
- NEXT-37316 - Added `index` command, to trigger the indexing of the Shopware shop
- NEXT-37315 - Trigger indexing of the shop by default at the end of an import (can be disabled with flag `-d` `--disable-index`)
- NEXT-37303 - [BREAKING] changed `sync` command argument `--schema` `-s` to `--profile` `-p`
- NEXT-37303 - [BREAKING] Fixed an issue where `row` values were always provided as strings in the deserialize script.
Now they are converted into their proper types before passed to the script.
- NEXT-37303 - [BREAKING] Fixed an issue where `row` values were always provided as strings in the deserialize script.
Now they are converted into their proper types before passed to the script.
- NEXT-37313 - Implemented re-authentication for API calls to handle expired bearer tokens

# v0.7.1

Expand Down
235 changes: 129 additions & 106 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ pub mod filter;

use crate::api::filter::{Criteria, CriteriaFilter};
use crate::config_file::Credentials;
use reqwest::blocking::{Client, Response};
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{header, StatusCode};
use reqwest::{header, Method, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
Expand Down Expand Up @@ -52,24 +52,23 @@ impl SwClient {

let total = self.get_total("language", &[])?;

let access_token = self.access_token.lock().unwrap().clone();
while language_list.len() < total as usize {
let mut criteria = Criteria {
let criteria = Criteria {
page,
limit: Some(Criteria::MAX_LIMIT),
fields: vec!["id".to_string(), "locale.code".to_string()],
..Default::default()
};

criteria.add_association("locale");
let request_builder = self
.client
.request(
Method::POST,
format!("{}/api/search/language", self.credentials.base_url),
)
.json(&criteria);

let response = {
self.client
.post(format!("{}/api/search/language", self.credentials.base_url))
.bearer_auth(&access_token)
.json(&criteria)
.send()?
};
let response = self.handle_authenticated_request(request_builder)?;

let value: LanguageLocaleSearchResponse = Self::deserialize(response)?;
for item in value.data {
Expand All @@ -84,15 +83,13 @@ impl SwClient {
})
}

pub fn sync<S: Into<String>, T: Serialize>(
pub fn sync<S: Into<String>, T: Serialize + Debug>(
&self,
entity: S,
action: SyncAction,
payload: &[T],
) -> Result<(), SwApiError> {
let entity: String = entity.into();
// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();
let body = SyncBody {
write_data: SyncOperation {
entity: entity.clone(),
Expand All @@ -101,29 +98,25 @@ impl SwClient {
},
};

let response = {
let start_instant = Instant::now();
println!(
"sync {:?} '{}' with payload size {}",
action,
&entity,
payload.len()
);
let res = self
.client
.post(format!("{}/api/_action/sync", self.credentials.base_url))
.bearer_auth(access_token)
.header("single-operation", 1)
.header("indexing-behavior", "disable-indexing")
.header("sw-skip-trigger-flow", 1)
.json(&body)
.send()?;
println!(
"sync request finished after {} ms",
start_instant.elapsed().as_millis()
);
res
};
println!(
"sync {:?} '{}' with payload size {}",
action,
&entity,
payload.len()
);

let request_builder = self
.client
.request(
Method::POST,
format!("{}/api/_action/sync", self.credentials.base_url),
)
.header("single-operation", "1")
.header("indexing-behavior", "disable-indexing")
.header("sw-skip-trigger-flow", "1")
.json(&body);

let response= self.handle_authenticated_request(request_builder)?;

if !response.status().is_success() {
let status = response.status();
Expand All @@ -135,17 +128,12 @@ impl SwClient {
}

pub fn entity_schema(&self) -> Result<Entity, SwApiError> {
// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();
let response = {
self.client
.get(format!(
"{}/api/_info/entity-schema.json",
self.credentials.base_url
))
.bearer_auth(access_token)
.send()?
};
let request_builder = self.client.request(
Method::GET,
format!("{}/api/_info/entity-schema.json", self.credentials.base_url),
);

let response= self.handle_authenticated_request(request_builder)?;

if !response.status().is_success() {
let status = response.status();
Expand All @@ -160,30 +148,27 @@ impl SwClient {
pub fn get_total(&self, entity: &str, filter: &[CriteriaFilter]) -> Result<u64, SwApiError> {
// entity needs to be provided as kebab-case instead of snake_case
let entity = entity.replace('_', "-");
let body = json!({
"limit": 1,
"filter": filter,
"aggregations": [
{
"name": "count",
"type": "count",
"field": "id"
}
]
});

// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();

let response = {
self.client
.post(format!(
"{}/api/search/{}",
self.credentials.base_url, entity
))
.bearer_auth(access_token)
.json(&json!({
"limit": 1,
"filter": filter,
"aggregations": [
{
"name": "count",
"type": "count",
"field": "id"
}
]
}))
.send()?
};
let request_builder = self
.client
.request(
Method::POST,
format!("{}/api/search/{}", self.credentials.base_url, entity),
)
.json(&body);

let response = self.handle_authenticated_request(request_builder)?;

if !response.status().is_success() {
let status = response.status();
Expand All @@ -207,35 +192,24 @@ impl SwClient {
// entity needs to be provided as kebab-case instead of snake_case
let entity = entity.replace('_', "-");

// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();
let response = {
let start_instant = Instant::now();

if let Some(limit) = criteria.limit {
println!(
"fetching page {} of '{}' with limit {}",
criteria.page, entity, limit
);
} else {
println!("fetching page {} of '{}'", criteria.page, entity);
}

let res = self
.client
.post(format!(
"{}/api/search/{}",
self.credentials.base_url, entity
))
.bearer_auth(access_token)
.json(criteria)
.send()?;
if let Some(limit) = criteria.limit {
println!(
"search request finished after {} ms",
start_instant.elapsed().as_millis()
"fetching page {} of '{}' with limit {}",
criteria.page, entity, limit
);
res
};
} else {
println!("fetching page {} of '{}'", criteria.page, entity);
}

let request_builder = self
.client
.request(
Method::POST,
format!("{}/api/search/{}", self.credentials.base_url, entity),
)
.json(criteria);

let response = self.handle_authenticated_request(request_builder)?;

if !response.status().is_success() {
let status = response.status();
Expand Down Expand Up @@ -276,14 +250,15 @@ impl SwClient {
}

pub fn index(&self, skip: Vec<String>) -> Result<(), SwApiError> {
let access_token = self.access_token.lock().unwrap().clone();

let response = self
let request_builder = self
.client
.post(format!("{}/api/_action/index", self.credentials.base_url))
.bearer_auth(access_token)
.json(&IndexBody { skip })
.send()?;
.request(
Method::POST,
format!("{}/api/_action/index", self.credentials.base_url),
)
.json(&IndexBody { skip });

let response = self.handle_authenticated_request(request_builder)?;

if !response.status().is_success() {
let status = response.status();
Expand Down Expand Up @@ -320,6 +295,54 @@ impl SwClient {
};
result
}

fn handle_authenticated_request(
&self,
request_builder: RequestBuilder,
) -> Result<Response, SwApiError> {
let mut retry_count = 0;
const MAX_RETRIES: u8 = 1;
let binding = request_builder.try_clone().unwrap().build().unwrap();
let path = binding.url().path();

loop {
let access_token = self.access_token.lock().unwrap().clone();
let request = request_builder
.try_clone()
.unwrap()
.bearer_auth(&access_token);

let start_time = Instant::now();
let response = request.send()?;

if response.status() == StatusCode::UNAUTHORIZED && retry_count < MAX_RETRIES {
// lock the access token
let mut access_token_guard = self.access_token.lock().unwrap();
// compare the access token with the one we used to make the request
if *access_token_guard != access_token {
// Another thread has already re-authenticated
continue;
}

// Perform re-authentication
let auth_response = Self::authenticate(&self.client, &self.credentials)?;
let new_token = auth_response.access_token;
*access_token_guard = new_token;

retry_count += 1;
continue;
}

let duration = start_time.elapsed();
println!(
"{} request finished after {} ms",
path,
duration.as_millis()
);

return Ok(response);
}
}
}

#[derive(Debug, Serialize)]
Expand Down

0 comments on commit 8709091

Please sign in to comment.