@@ -378,5 +378,61 @@ def index():
378
378
self .assertEqual (resp .status_code , 200 )
379
379
380
380
381
+ class AppExtensionPlusInPath (FlaskCorsTestCase ):
382
+ '''
383
+ Regression test for CVE-2024-6844:
384
+ Ensures that we correctly differentiate '+' from ' ' in URL paths.
385
+ '''
386
+
387
+ def setUp (self ):
388
+ self .app = Flask (__name__ )
389
+ CORS (self .app , resources = {
390
+ r'/service\+path' : {'origins' : ['http://foo.com' ]},
391
+ r'/service path' : {'origins' : ['http://bar.com' ]},
392
+ })
393
+
394
+ @self .app .route ('/service+path' )
395
+ def plus_path ():
396
+ return 'plus'
397
+
398
+ @self .app .route ('/service path' )
399
+ def space_path ():
400
+ return 'space'
401
+
402
+ self .client = self .app .test_client ()
403
+
404
+ def test_plus_path_origin_allowed (self ):
405
+ '''
406
+ Ensure that CORS matches + literally and allows the correct origin
407
+ '''
408
+ response = self .client .get ('/service+path' , headers = {'Origin' : 'http://foo.com' })
409
+ self .assertEqual (response .status_code , 200 )
410
+ self .assertEqual (response .headers .get (ACL_ORIGIN ), 'http://foo.com' )
411
+
412
+ def test_space_path_origin_allowed (self ):
413
+ '''
414
+ Ensure that CORS treats /service path differently and allows correct origin
415
+ '''
416
+ response = self .client .get ('/service%20path' , headers = {'Origin' : 'http://bar.com' })
417
+ self .assertEqual (response .status_code , 200 )
418
+ self .assertEqual (response .headers .get (ACL_ORIGIN ), 'http://bar.com' )
419
+
420
+ def test_plus_path_rejects_other_origin (self ):
421
+ '''
422
+ Origin not allowed for + path should be rejected
423
+ '''
424
+ response = self .client .get ('/service+path' , headers = {'Origin' : 'http://bar.com' })
425
+ self .assertEqual (response .status_code , 200 )
426
+ self .assertIsNone (response .headers .get (ACL_ORIGIN ))
427
+
428
+ def test_space_path_rejects_other_origin (self ):
429
+ '''
430
+ Origin not allowed for space path should be rejected
431
+ '''
432
+ response = self .client .get ('/service%20path' , headers = {'Origin' : 'http://foo.com' })
433
+ self .assertEqual (response .status_code , 200 )
434
+ self .assertIsNone (response .headers .get (ACL_ORIGIN ))
435
+
436
+
381
437
if __name__ == "__main__" :
382
438
unittest .main ()
0 commit comments