diff --git a/src/web-utils.test.ts b/src/web-utils.test.ts index 3d2867a1..db214fb5 100644 --- a/src/web-utils.test.ts +++ b/src/web-utils.test.ts @@ -182,6 +182,22 @@ describe("Url param extraction", () => { expect(paramObj["state"]).toStrictEqual(state); }); + it('should use query flag and ignore hash flag', () => { + const random = WebUtils.randomString(); + const foo = WebUtils.randomString(); + const paramObj = WebUtils.getUrlParams(`https://app.example.com?random=${random}&foo=${foo}#ignored`); + expect(paramObj["random"]).toStrictEqual(random); + expect(paramObj["foo"]).toStrictEqual(`${foo}#ignored`); + }); + + it('should use hash flag and ignore query flag', () => { + const random = WebUtils.randomString(); + const foo = WebUtils.randomString(); + const paramObj = WebUtils.getUrlParams(`https://app.example.com#random=${random}&foo=${foo}?ignored`); + expect(paramObj["random"]).toStrictEqual(random); + expect(paramObj["foo"]).toStrictEqual(`${foo}?ignored`); + }); + }); describe("Random string gen", () => { diff --git a/src/web-utils.ts b/src/web-utils.ts index 9fe42f67..edc5766c 100644 --- a/src/web-utils.ts +++ b/src/web-utils.ts @@ -59,29 +59,37 @@ export class WebUtils { /** * Public only for testing */ - static getUrlParams(urlString: string): any | undefined { - if (urlString && urlString.trim().length > 0) { - urlString = urlString.trim(); - let idx = urlString.indexOf("#"); - if (idx === -1) { - idx = urlString.indexOf("?"); - } - if (idx !== -1 && urlString.length > (idx + 1)) { - const urlParamStr = urlString.slice(idx + 1); - const keyValuePairs: string[] = urlParamStr.split(`&`); - return keyValuePairs.reduce((acc, hash) => { - const [key, val] = hash.split(`=`); - if (key && key.length > 0) { - return { - ...acc, - [key]: decodeURIComponent(val) - } - } - }, {}); - } + static getUrlParams(url: string): any | undefined { + const urlString = `${url}`.trim(); + + if (urlString.length === 0) { + return; + } + + let hashIndex = urlString.indexOf("#"); + let queryIndex = urlString.indexOf("?"); + if (hashIndex === -1 && queryIndex === -1) { + return; } - return undefined; + + const paramsIndex = hashIndex > -1 && hashIndex < queryIndex ? hashIndex : queryIndex; + + if (urlString.length <= paramsIndex + 1) { + return; + } + + const urlParamStr = urlString.slice(paramsIndex + 1); + const keyValuePairs: string[] = urlParamStr.split(`&`); + return keyValuePairs.reduce((acc, hash) => { + const [key, val] = hash.split(`=`); + if (key && key.length > 0) { + return { + ...acc, + [key]: decodeURIComponent(val) + } + } + }, {}); } static randomString(length: number = 10) { @@ -126,7 +134,7 @@ export class WebUtils { if (!webOptions.state || webOptions.state.length === 0) { webOptions.state = this.randomString(20); } - let mapHelper = this.getOverwritableValue<{[key: string]: string}>(configOptions, "additionalParameters"); + let mapHelper = this.getOverwritableValue<{ [key: string]: string }>(configOptions, "additionalParameters"); if (mapHelper) { webOptions.additionalParameters = {}; for (const key in mapHelper) { @@ -173,7 +181,7 @@ export class CryptoUtils { static toBase64(bytes: Uint8Array): string { let len = bytes.length; let base64 = ""; - for (let i = 0; i < len; i+=3) { + for (let i = 0; i < len; i += 3) { base64 += this.BASE64_CHARS[bytes[i] >> 2]; base64 += this.BASE64_CHARS[((bytes[i] & 3) << 4) | (bytes[i + 1] >> 4)]; base64 += this.BASE64_CHARS[((bytes[i + 1] & 15) << 2) | (bytes[i + 2] >> 6)]; @@ -225,6 +233,6 @@ export class WebOptions { pkceCodeChallenge: string; pkceCodeChallengeMethod: string; - additionalParameters: {[key: string]: string}; + additionalParameters: { [key: string]: string }; }