1616
1717package  org .springframework .security .saml2 .provider .service .registration ;
1818
19- import  java .io .IOException ;
20- import  java .sql .PreparedStatement ;
2119import  java .sql .ResultSet ;
2220import  java .sql .SQLException ;
2321import  java .sql .Types ;
24- import  java .util .ArrayList ;
2522import  java .util .Collection ;
2623import  java .util .Iterator ;
2724import  java .util .List ;
28- import  java .util .function .Function ;
25+ import  java .util .function .Consumer ;
2926
3027import  org .apache .commons .logging .Log ;
3128import  org .apache .commons .logging .LogFactory ;
32- import  org .slf4j .Logger ;
33- import  org .slf4j .LoggerFactory ;
3429import  org .springframework .core .log .LogMessage ;
3530import  org .springframework .core .serializer .DefaultDeserializer ;
36- import  org .springframework .core .serializer .DefaultSerializer ;
3731import  org .springframework .core .serializer .Deserializer ;
38- import  org .springframework .core .serializer .Serializer ;
3932import  org .springframework .jdbc .core .ArgumentPreparedStatementSetter ;
4033import  org .springframework .jdbc .core .JdbcOperations ;
4134import  org .springframework .jdbc .core .PreparedStatementSetter ;
4437import  org .springframework .security .saml2 .core .Saml2X509Credential ;
4538import  org .springframework .security .saml2 .provider .service .registration .RelyingPartyRegistration .AssertingPartyDetails ;
4639import  org .springframework .util .Assert ;
40+ import  org .springframework .util .StringUtils ;
4741
4842/** 
4943 * A JDBC implementation of {@link AssertingPartyMetadataRepository}. 
@@ -58,13 +52,9 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
5852	private  RowMapper <AssertingPartyMetadata > assertingPartyMetadataRowMapper  =
5953			new  AssertingPartyMetadataRowMapper (ResultSet ::getBytes );
6054
61- 	private  Function <AssertingPartyMetadata , List <SqlParameterValue >> assertingPartyMetadataParametersMapper  =
62- 			new  AssertingPartyMetadataParametersMapper ();
63- 
64- 	private  final  SetBytes  setBytes  = PreparedStatement ::setBytes ;
65- 
6655	// @formatter:off 
6756	static  final  String  COLUMN_NAMES  = "entity_id, " 
57+ 			+ "metadata_uri, " 
6858			+ "singlesignon_url, " 
6959			+ "singlesignon_binding, " 
7060			+ "singlesignon_sign_request, " 
@@ -87,26 +77,6 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
8777
8878	private  static  final  String  LOAD_ALL_SQL  = "SELECT "  + COLUMN_NAMES 
8979			+ " FROM "  + TABLE_NAME ;
90- 
91- 	private  static  final  String  SAVE_SQL  = "INSERT INTO "  + TABLE_NAME  + " (" 
92- 			+ COLUMN_NAMES 
93- 			+ ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ;
94- 	// @formatter:on 
95- 
96- 	private  static  final  String  DELETE_SQL  = "DELETE FROM "  + TABLE_NAME  + " WHERE "  + ENTITY_ID_FILTER ;
97- 
98- 	// @formatter:off 
99- 	private  static  final  String  UPDATE_SQL  = "UPDATE "  + TABLE_NAME 
100- 			+ " SET singlesignon_url = ?, "  +
101- 			"singlesignon_binding = ?, "  +
102- 			"singlesignon_sign_request = ?, "  +
103- 			"signing_algorithms = ?, "  +
104- 			"verification_credentials = ?, "  +
105- 			"encryption_credentials = ?, "  +
106- 			"singlelogout_url = ? ,"  +
107- 			"singlelogout_response_url = ?, "  +
108- 			"singlelogout_binding = ?" 
109- 			+ " WHERE "  + ENTITY_ID_FILTER ;
11080	// @formatter:on 
11181
11282	/** 
@@ -134,41 +104,6 @@ public void setAssertingPartyMetadataRowMapper(
134104		this .assertingPartyMetadataRowMapper  = assertingPartyMetadataRowMapper ;
135105	}
136106
137- 	public  void  setAssertingPartyMetadataParametersMapper (Function <AssertingPartyMetadata , List <SqlParameterValue >> assertingPartyMetadataParametersMapper ) {
138- 		Assert .notNull (assertingPartyMetadataParametersMapper , "assertingPartyMetadataParametersMapper cannot be null" );
139- 		this .assertingPartyMetadataParametersMapper  = assertingPartyMetadataParametersMapper ;
140- 	}
141- 
142- 	public  void  save (AssertingPartyMetadata  metadata ) {
143- 		Assert .notNull (metadata , "metadata cannot be null" );
144- 		int  rows  = update (metadata );
145- 		if  (rows  == 0 ) {
146- 			insert (metadata );
147- 		}
148- 	}
149- 
150- 	private  void  insert (AssertingPartyMetadata  metadata ) {
151- 		List <SqlParameterValue > parameters  = this .assertingPartyMetadataParametersMapper .apply (metadata );
152- 		PreparedStatementSetter  pss  = new  BlobArgumentPreparedStatementSetter (this .setBytes , parameters .toArray ());
153- 		this .jdbcOperations .update (SAVE_SQL , pss );
154- 	}
155- 
156- 	private  int  update (AssertingPartyMetadata  metadata ) {
157- 		List <SqlParameterValue > parameters  = this .assertingPartyMetadataParametersMapper .apply (metadata );
158- 		SqlParameterValue  credentialId  = parameters .remove (0 );
159- 		parameters .add (credentialId );
160- 		PreparedStatementSetter  pss  = new  BlobArgumentPreparedStatementSetter (this .setBytes , parameters .toArray ());
161- 		return  this .jdbcOperations .update (UPDATE_SQL , pss );
162- 	}
163- 
164- 	public  void  delete (String  entityId ) {
165- 		Assert .notNull (entityId , "entityId cannot be null" );
166- 		SqlParameterValue [] parameters  = new  SqlParameterValue []{
167- 				new  SqlParameterValue (Types .VARCHAR , entityId ),};
168- 		PreparedStatementSetter  pss  = new  ArgumentPreparedStatementSetter (parameters );
169- 		this .jdbcOperations .update (DELETE_SQL , pss );
170- 	}
171- 
172107	@ Override 
173108	public  AssertingPartyMetadata  findByEntityId (String  entityId ) {
174109		Assert .hasText (entityId , "entityId cannot be empty" );
@@ -187,75 +122,6 @@ public Iterator<AssertingPartyMetadata> iterator() {
187122		return  result .iterator ();
188123	}
189124
190- 	private  static  class  AssertingPartyMetadataParametersMapper 
191- 			implements  Function <AssertingPartyMetadata , List <SqlParameterValue >> {
192- 
193- 		private  final  Logger  logger  = LoggerFactory .getLogger (AssertingPartyMetadataParametersMapper .class );
194- 
195- 		private  final  Serializer <Object > serializer  = new  DefaultSerializer ();
196- 
197- 		@ Override 
198- 		public  List <SqlParameterValue > apply (AssertingPartyMetadata  record ) {
199- 			List <SqlParameterValue > parameters  = new  ArrayList <>();
200- 			parameters .add (new  SqlParameterValue (Types .VARCHAR , record .getEntityId ()));
201- 			parameters .add (new  SqlParameterValue (Types .VARCHAR , record .getSingleSignOnServiceLocation ()));
202- 			parameters .add (new  SqlParameterValue (Types .VARCHAR , record .getSingleSignOnServiceBinding ().getUrn ()));
203- 			parameters .add (new  SqlParameterValue (Types .BOOLEAN , record .getWantAuthnRequestsSigned ()));
204- 			try  {
205- 				parameters .add (new  SqlParameterValue (Types .BLOB ,
206- 						this .serializer .serializeToByteArray (record .getSigningAlgorithms ())));
207- 			} catch  (IOException  ex ) {
208- 				this .logger .debug ("Failed to serialize signing algorithms" , ex );
209- 				throw  new  IllegalArgumentException (ex );
210- 			}
211- 			try  {
212- 				parameters .add (new  SqlParameterValue (Types .BLOB ,
213- 						this .serializer .serializeToByteArray (record .getVerificationX509Credentials ())));
214- 			} catch  (IOException  ex ) {
215- 				this .logger .debug ("Failed to serialize verification credentials" , ex );
216- 				throw  new  IllegalArgumentException (ex );
217- 			}
218- 			try  {
219- 				parameters .add (new  SqlParameterValue (Types .BLOB ,
220- 						this .serializer .serializeToByteArray (record .getEncryptionX509Credentials ())));
221- 			} catch  (IOException  ex ) {
222- 				this .logger .debug ("Failed to serialize encryption credentials" , ex );
223- 				throw  new  IllegalArgumentException (ex );
224- 			}
225- 			parameters .add (new  SqlParameterValue (Types .VARCHAR , record .getSingleLogoutServiceLocation ()));
226- 			parameters .add (new  SqlParameterValue (Types .VARCHAR , record .getSingleLogoutServiceResponseLocation ()));
227- 			parameters .add (new  SqlParameterValue (Types .VARCHAR , record .getSingleLogoutServiceBinding ().getUrn ()));
228- 			return  parameters ;
229- 		}
230- 	}
231- 
232- 	private  static  final  class  BlobArgumentPreparedStatementSetter  extends  ArgumentPreparedStatementSetter  {
233- 
234- 		private  final  SetBytes  setBytes ;
235- 
236- 		private  BlobArgumentPreparedStatementSetter (SetBytes  setBytes , Object [] args ) {
237- 			super (args );
238- 			this .setBytes  = setBytes ;
239- 		}
240- 
241- 		@ Override 
242- 		protected  void  doSetValue (PreparedStatement  ps , int  parameterPosition , Object  argValue ) throws  SQLException  {
243- 			if  (argValue  instanceof  SqlParameterValue  paramValue ) {
244- 				if  (paramValue .getSqlType () == Types .BLOB ) {
245- 					if  (paramValue .getValue () != null ) {
246- 						Assert .isInstanceOf (byte [].class , paramValue .getValue (),
247- 								"Value of blob parameter must be byte[]" );
248- 					}
249- 					byte [] valueBytes  = (byte []) paramValue .getValue ();
250- 					this .setBytes .setBytes (ps , parameterPosition , valueBytes );
251- 					return ;
252- 				}
253- 			}
254- 			super .doSetValue (ps , parameterPosition , argValue );
255- 		}
256- 
257- 	}
258- 
259125	/** 
260126	 * The default {@link RowMapper} that maps the current row in 
261127	 * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}. 
@@ -275,61 +141,68 @@ private final static class AssertingPartyMetadataRowMapper implements RowMapper<
275141		@ Override 
276142		public  AssertingPartyMetadata  mapRow (ResultSet  rs , int  rowNum ) throws  SQLException  {
277143			String  entityId  = rs .getString ("entity_id" );
144+ 			String  metadataUri  = rs .getString ("metadata_uri" );
278145			String  singleSignOnUrl  = rs .getString ("singlesignon_url" );
279- 			Saml2MessageBinding  singleSignOnBinding  = Saml2MessageBinding 
280- 					.from (rs .getString ("singlesignon_binding" ));
146+ 			Saml2MessageBinding  singleSignOnBinding  = Saml2MessageBinding .from (rs .getString ("singlesignon_binding" ));
281147			boolean  singleSignOnSignRequest  = rs .getBoolean ("singlesignon_sign_request" );
282- 			List <String > signingAlgorithms ;
283- 			try  {
284- 				signingAlgorithms  = (List <String >) deserializer .deserializeFromByteArray (
285- 						this .getBytes .getBytes (rs , "signing_algorithms" ));
286- 			} catch  (IOException  ex ) {
287- 				this .logger .debug (
288- 						LogMessage .format ("Verification credentials of %s could not be parsed." , entityId ), ex );
289- 				return  null ;
290- 			}
291- 			Collection <Saml2X509Credential > verificationCredentials ;
292- 			try  {
293- 				verificationCredentials  = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (
294- 						this .getBytes .getBytes (rs , "verification_credentials" ));
295- 			} catch  (IOException  ex ) {
296- 				this .logger .debug (
297- 						LogMessage .format ("Verification credentials of %s could not be parsed." , entityId ), ex );
298- 				return  null ;
299- 			}
300- 			Collection <Saml2X509Credential > encryptionCredentials ;
148+ 			String  singleLogoutUrl  = rs .getString ("singlelogout_url" );
149+ 			String  singleLogoutResponseUrl  = rs .getString ("singlelogout_response_url" );
150+ 			Saml2MessageBinding  singleLogoutBinding  = Saml2MessageBinding .from (rs .getString ("singlelogout_binding" ));
151+ 			byte [] signingAlgorithmsBytes  = this .getBytes .getBytes (rs , "signing_algorithms" );
152+ 			byte [] verificationCredentialsBytes  = this .getBytes .getBytes (rs , "verification_credentials" );
153+ 			byte [] encryptionCredentialsBytes  = this .getBytes .getBytes (rs , "encryption_credentials" );
154+ 
155+ 			boolean  usingMetadata  = StringUtils .hasText (metadataUri );
156+ 			AssertingPartyMetadata .Builder <?> builder  = (!usingMetadata ) ? new  AssertingPartyDetails .Builder ().entityId (entityId )
157+ 					: createBuilderUsingMetadata (entityId , metadataUri );
301158			try  {
302- 				encryptionCredentials  = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (
303- 						this .getBytes .getBytes (rs , "encryption_credentials" ));
304- 			} catch  (IOException  ex ) {
159+ 				if  (signingAlgorithmsBytes  != null ) {
160+ 					List <String > signingAlgorithms  = (List <String >) deserializer .deserializeFromByteArray (signingAlgorithmsBytes );
161+ 					builder .signingAlgorithms (algorithms  -> algorithms .addAll (signingAlgorithms ));
162+ 				}
163+ 				if  (verificationCredentialsBytes  != null ) {
164+ 					Collection <Saml2X509Credential > verificationCredentials  = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (verificationCredentialsBytes );
165+ 					builder .verificationX509Credentials (credentials  -> credentials .addAll (verificationCredentials ));
166+ 				}
167+ 				if  (encryptionCredentialsBytes  != null ) {
168+ 					Collection <Saml2X509Credential > encryptionCredentials  = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (encryptionCredentialsBytes );
169+ 					builder .encryptionX509Credentials (credentials  -> credentials .addAll (encryptionCredentials ));
170+ 				}
171+ 			} catch  (Exception  ex ) {
305172				this .logger .debug (
306- 						LogMessage .format ("Encryption  credentials of  %s could not be parsed. " , entityId ), ex );
173+ 						LogMessage .format ("Parsing serialized  credentials for entity  %s failed " , entityId ), ex );
307174				return  null ;
308175			}
309- 			String  singleLogoutUrl  = rs .getString ("singlelogout_url" );
310- 			String  singleLogoutResponseUrl  = rs .getString ("singlelogout_response_url" );
311- 			Saml2MessageBinding  singleLogoutBinding  = Saml2MessageBinding 
312- 					.from (rs .getString ("singlelogout_binding" ));
313176
314- 			return  new  AssertingPartyDetails .Builder ()
315- 					.entityId (entityId )
316- 					.wantAuthnRequestsSigned (singleSignOnSignRequest )
317- 					.signingAlgorithms (algorithms  -> algorithms .addAll (signingAlgorithms ))
318- 					.verificationX509Credentials (credentials  -> credentials .addAll (verificationCredentials ))
319- 					.encryptionX509Credentials (credentials  -> credentials .addAll (encryptionCredentials ))
320- 					.singleSignOnServiceLocation (singleSignOnUrl )
321- 					.singleSignOnServiceBinding (singleSignOnBinding )
322- 					.singleLogoutServiceLocation (singleLogoutUrl )
323- 					.singleLogoutServiceBinding (singleLogoutBinding )
324- 					.singleLogoutServiceResponseLocation (singleLogoutResponseUrl )
325- 					.build ();
177+ 			applyingWhenNonNull (singleSignOnUrl , builder ::singleSignOnServiceLocation );
178+ 			applyingWhenNonNull (singleSignOnBinding , builder ::singleSignOnServiceBinding );
179+ 			applyingWhenNonNull (singleSignOnSignRequest , builder ::wantAuthnRequestsSigned );
180+ 			applyingWhenNonNull (singleLogoutUrl , builder ::singleLogoutServiceLocation );
181+ 			applyingWhenNonNull (singleLogoutResponseUrl , builder ::singleLogoutServiceResponseLocation );
182+ 			applyingWhenNonNull (singleLogoutBinding , builder ::singleLogoutServiceBinding );
183+ 			return  builder .build ();
326184		}
327- 	}
328185
329- 	private  interface  SetBytes  {
186+ 		private  <T > void  applyingWhenNonNull (T  value , Consumer <T > consumer ) {
187+ 			if  (value  != null ) {
188+ 				consumer .accept (value );
189+ 			}
190+ 		}
330191
331- 		void  setBytes (PreparedStatement  ps , int  index , byte [] bytes ) throws  SQLException ;
192+ 		private  AssertingPartyMetadata .Builder <?> createBuilderUsingMetadata (String  entityId , String  metadataUri ) {
193+ 			Collection <AssertingPartyMetadata .Builder <?>> candidates  = AssertingPartyMetadata 
194+ 					.collectionFromMetadataLocation (metadataUri );
195+ 			for  (AssertingPartyMetadata .Builder <?> candidate  : candidates ) {
196+ 				if  (entityId  == null  || entityId .equals (getEntityId (candidate ))) {
197+ 					return  candidate ;
198+ 				}
199+ 			}
200+ 			throw  new  IllegalStateException ("No asserting party metadata with Entity ID '"  + entityId  + "' found" );
201+ 		}
332202
203+ 		private  Object  getEntityId (AssertingPartyMetadata .Builder <?> candidate ) {
204+ 			return  candidate .build ().getEntityId ();
205+ 		}
333206	}
334207
335208	private  interface  GetBytes  {
0 commit comments