@@ -73,6 +73,14 @@ def login_id_thirdparty_from_phone(identifier):
73
73
return {"type" : "m.id.thirdparty" , "medium" : "msisdn" , "address" : msisdn }
74
74
75
75
76
+ def build_service_param (cas_service_url , client_redirect_url ):
77
+ return "%s%s?redirectUrl=%s" % (
78
+ cas_service_url ,
79
+ "/_matrix/client/r0/login/cas/ticket" ,
80
+ urllib .parse .quote (client_redirect_url , safe = "" ),
81
+ )
82
+
83
+
76
84
class LoginRestServlet (RestServlet ):
77
85
PATTERNS = client_patterns ("/login$" , v1 = True )
78
86
CAS_TYPE = "m.login.cas"
@@ -428,18 +436,15 @@ def get_sso_url(self, client_redirect_url):
428
436
class CasRedirectServlet (BaseSSORedirectServlet ):
429
437
def __init__ (self , hs ):
430
438
super (CasRedirectServlet , self ).__init__ ()
431
- self .cas_server_url = hs .config .cas_server_url . encode ( "ascii" )
432
- self .cas_service_url = hs .config .cas_service_url . encode ( "ascii" )
439
+ self .cas_server_url = hs .config .cas_server_url
440
+ self .cas_service_url = hs .config .cas_service_url
433
441
434
442
def get_sso_url (self , client_redirect_url ):
435
- client_redirect_url_param = urllib .parse .urlencode (
436
- {b"redirectUrl" : client_redirect_url }
437
- ).encode ("ascii" )
438
- hs_redirect_url = self .cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
439
- service_param = urllib .parse .urlencode (
440
- {b"service" : b"%s?%s" % (hs_redirect_url , client_redirect_url_param )}
441
- ).encode ("ascii" )
442
- return b"%s/login?%s" % (self .cas_server_url , service_param )
443
+ args = urllib .parse .urlencode (
444
+ {"service" : build_service_param (self .cas_service_url , client_redirect_url )}
445
+ )
446
+
447
+ return "%s/login?%s" % (self .cas_server_url , args )
443
448
444
449
445
450
class CasTicketServlet (RestServlet ):
@@ -448,10 +453,7 @@ class CasTicketServlet(RestServlet):
448
453
def __init__ (self , hs ):
449
454
super (CasTicketServlet , self ).__init__ ()
450
455
self .cas_server_url = hs .config .cas_server_url
451
- self .cas_service_url = (
452
- hs .config .cas_service_url .encode ("ascii" )
453
- + b"/_matrix/client/r0/login/cas/ticket?redirectUrl="
454
- )
456
+ self .cas_service_url = hs .config .cas_service_url
455
457
self .cas_displayname_attribute = hs .config .cas_displayname_attribute
456
458
self .cas_required_attributes = hs .config .cas_required_attributes
457
459
self ._sso_auth_handler = SSOAuthHandler (hs )
@@ -460,12 +462,9 @@ def __init__(self, hs):
460
462
async def on_GET (self , request ):
461
463
client_redirect_url = parse_string (request , "redirectUrl" , required = True )
462
464
uri = self .cas_server_url + "/proxyValidate"
463
- service_url = self .cas_service_url + urllib .parse .quote (
464
- client_redirect_url , safe = ""
465
- ).encode ("ascii" )
466
465
args = {
467
466
"ticket" : parse_string (request , "ticket" , required = True ),
468
- "service" : service_url ,
467
+ "service" : build_service_param ( self . cas_service_url , client_redirect_url ) ,
469
468
}
470
469
try :
471
470
body = await self ._http_client .get_raw (uri , args )
0 commit comments