diff --git a/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationParameter.java b/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationParameter.java index d60efafe9ec..ca80f66103b 100644 --- a/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationParameter.java +++ b/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationParameter.java @@ -34,6 +34,9 @@ public class ValidationParameter implements Serializable { private static final long serialVersionUID = 7158911668568000392L; + @NotNull(groups = ValidationService.Update.class) + private Integer id; + @NotNull // 不允许为空 @Size(min = 2, max = 20) // 长度或大小范围 private String name; @@ -52,6 +55,14 @@ public class ValidationParameter implements Serializable { @Future // 必须为一个未来的时间 private Date expiryDate; + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + public String getName() { return name; } diff --git a/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationService.java b/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationService.java index 519d2801406..eb9ea2d35ed 100644 --- a/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationService.java +++ b/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationService.java @@ -15,6 +15,8 @@ */ package com.alibaba.dubbo.config.validation; +import com.alibaba.dubbo.validation.MethodValidated; + import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; import javax.validation.constraints.Pattern; @@ -28,16 +30,31 @@ */ public interface ValidationService { // 缺省可按服务接口区分验证场景,如:@NotNull(groups = ValidationService.class) + /** + * 没有加上“@MethodValidated(ValidationService.Save.class)”这句代码时, + * 现在的检查逻辑不会去检验groups = ValidationService.Save.class这个分组 + * + * @param parameter + */ + @MethodValidated(Save.class) void save(ValidationParameter parameter); void update(ValidationParameter parameter); void delete(@Min(1) long id, @NotNull @Size(min = 2, max = 16) @Pattern(regexp = "^[a-zA-Z]+$") String operator); + /** + * 假设关联查询的时候需要同时传id和email的值。这时需要检查Sava分组和Update分组。 + * @param parameter + */ + @MethodValidated({Save.class, Update.class}) + void relatedQuery(ValidationParameter parameter); + @interface Save { } // 与方法同名接口,首字母大写,用于区分验证场景,如:@NotNull(groups = ValidationService.Save.class),可选 @interface Update { } // 与方法同名接口,首字母大写,用于区分验证场景,如:@NotNull(groups = ValidationService.Update.class),可选 + } diff --git a/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationServiceImpl.java b/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationServiceImpl.java index be45b4d9217..2fc0a106ae7 100644 --- a/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationServiceImpl.java +++ b/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationServiceImpl.java @@ -31,4 +31,8 @@ public void update(ValidationParameter parameter) { public void delete(long id, String operator) { } + public void relatedQuery(ValidationParameter parameter){ + + } + } diff --git a/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationTest.java b/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationTest.java index 0c075934db4..7a90dc9af42 100644 --- a/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationTest.java +++ b/dubbo-config/dubbo-config-api/src/test/java/com/alibaba/dubbo/config/validation/ValidationTest.java @@ -81,6 +81,36 @@ public void testValidation() { Assert.assertNotNull(violations); } + //检查Save分组 save error + try { + parameter = new ValidationParameter(); + parameter.setName("liangfei"); + parameter.setAge(50); + parameter.setLoginDate(new Date(System.currentTimeMillis() - 1000000)); + parameter.setExpiryDate(new Date(System.currentTimeMillis() + 1000000)); + validationService.save(parameter); + Assert.fail(); + } catch (RpcException e) { + ConstraintViolationException ve = (ConstraintViolationException) e.getCause(); + Set> violations = ve.getConstraintViolations(); + Assert.assertNotNull(violations); + } + + // relatedQuery error 不传id和email的值,触发Save和Update的检查异常 + try { + parameter = new ValidationParameter(); + parameter.setName("liangfei"); + parameter.setAge(50); + parameter.setLoginDate(new Date(System.currentTimeMillis() - 1000000)); + parameter.setExpiryDate(new Date(System.currentTimeMillis() + 1000000)); + validationService.relatedQuery(parameter); + Assert.fail(); + } catch (RpcException e) { + ConstraintViolationException ve = (ConstraintViolationException) e.getCause(); + Set> violations = ve.getConstraintViolations(); + Assert.assertEquals(violations.size(),2); + } + // Save Error try { parameter = new ValidationParameter(); diff --git a/dubbo-filter/dubbo-filter-validation/src/main/java/com/alibaba/dubbo/validation/MethodValidated.java b/dubbo-filter/dubbo-filter-validation/src/main/java/com/alibaba/dubbo/validation/MethodValidated.java new file mode 100644 index 00000000000..2302ec8be3c --- /dev/null +++ b/dubbo-filter/dubbo-filter-validation/src/main/java/com/alibaba/dubbo/validation/MethodValidated.java @@ -0,0 +1,23 @@ +package com.alibaba.dubbo.validation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * 方法分组验证注解 + *

使用场景:当调用某个方法时,需要检查多个分组,可以在接口方法上加上该注解


+ * 用法:
   @MethodValidated({Save.class, Update.class})
+ *  void relatedQuery(ValidationParameter parameter);
+ * 在接口方法上增加注解,表示relatedQuery这个方法需要同时检查Save和Update这两个分组 + * + * @author: zhangyinyue + */ +@Target({ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface MethodValidated { + Class[] value() default {}; +} diff --git a/dubbo-filter/dubbo-filter-validation/src/main/java/com/alibaba/dubbo/validation/support/jvalidation/JValidator.java b/dubbo-filter/dubbo-filter-validation/src/main/java/com/alibaba/dubbo/validation/support/jvalidation/JValidator.java index affe8629cbf..e92ad65bc82 100644 --- a/dubbo-filter/dubbo-filter-validation/src/main/java/com/alibaba/dubbo/validation/support/jvalidation/JValidator.java +++ b/dubbo-filter/dubbo-filter-validation/src/main/java/com/alibaba/dubbo/validation/support/jvalidation/JValidator.java @@ -20,6 +20,7 @@ import com.alibaba.dubbo.common.logger.Logger; import com.alibaba.dubbo.common.logger.LoggerFactory; import com.alibaba.dubbo.common.utils.ReflectUtils; +import com.alibaba.dubbo.validation.MethodValidated; import com.alibaba.dubbo.validation.Validator; import javassist.ClassPool; @@ -55,9 +56,12 @@ import java.lang.reflect.Array; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Date; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -231,51 +235,60 @@ else if (memberValue instanceof ArrayMemberValue) { } public void validate(String methodName, Class[] parameterTypes, Object[] arguments) throws Exception { + List> groups = new ArrayList>(); String methodClassName = clazz.getName() + "$" + toUpperMethoName(methodName); Class methodClass = null; try { methodClass = Class.forName(methodClassName, false, Thread.currentThread().getContextClassLoader()); + groups.add(methodClass); } catch (ClassNotFoundException e) { } Set> violations = new HashSet>(); Method method = clazz.getMethod(methodName, parameterTypes); + Class[] methodClasses = null; + if (method.isAnnotationPresent(MethodValidated.class)){ + methodClasses = method.getAnnotation(MethodValidated.class).value(); + groups.addAll(Arrays.asList(methodClasses)); + } + //加入默认分组 + groups.add(0, Default.class); + groups.add(1, clazz); + + //将list转换为数组 + Class[] classgroups = groups.toArray(new Class[0]); + Object parameterBean = getMethodParameterBean(clazz, method, arguments); if (parameterBean != null) { - if (methodClass != null) { - violations.addAll(validator.validate(parameterBean, Default.class, clazz, methodClass)); - } else { - violations.addAll(validator.validate(parameterBean, Default.class, clazz)); - } + violations.addAll(validator.validate(parameterBean, classgroups )); } + for (Object arg : arguments) { - validate(violations, arg, clazz, methodClass); + validate(violations, arg, classgroups); } + if (violations.size() > 0) { + logger.error("Failed to validate service: " + clazz.getName() + ", method: " + methodName + ", cause: " + violations); throw new ConstraintViolationException("Failed to validate service: " + clazz.getName() + ", method: " + methodName + ", cause: " + violations, violations); } } - private void validate(Set> violations, Object arg, Class clazz, Class methodClass) { + private void validate(Set> violations, Object arg, Class... groups) { if (arg != null && !isPrimitives(arg.getClass())) { if (Object[].class.isInstance(arg)) { for (Object item : (Object[]) arg) { - validate(violations, item, clazz, methodClass); + validate(violations, item, groups); } } else if (Collection.class.isInstance(arg)) { for (Object item : (Collection) arg) { - validate(violations, item, clazz, methodClass); + validate(violations, item, groups); } } else if (Map.class.isInstance(arg)) { for (Map.Entry entry : ((Map) arg).entrySet()) { - validate(violations, entry.getKey(), clazz, methodClass); - validate(violations, entry.getValue(), clazz, methodClass); + validate(violations, entry.getKey(), groups); + validate(violations, entry.getValue(), groups); } } else { - if (methodClass != null) { - violations.addAll(validator.validate(arg, Default.class, clazz, methodClass)); - } else { - violations.addAll(validator.validate(arg, Default.class, clazz)); - } + violations.addAll(validator.validate(arg, groups)); } } }