diff --git a/distribution.go b/distribution.go index 2e0aa439..d3be716a 100644 --- a/distribution.go +++ b/distribution.go @@ -102,11 +102,11 @@ func (d *IntUniformDistribution) Single() bool { // Contains to check a parameter value is contained in the range of this distribution. func (d *IntUniformDistribution) Contains(ir float64) bool { - value := int(ir) + value := d.ToExternalRepr(ir).(int) if d.Single() { return value == d.Low } - return d.Low <= value && value < d.High + return d.Low <= value && value <= d.High } // StepIntUniformDistributionName is the identifier name of IntUniformDistribution diff --git a/distribution_test.go b/distribution_test.go index 5da9c101..de6acd24 100644 --- a/distribution_test.go +++ b/distribution_test.go @@ -150,7 +150,7 @@ func TestDistributionToExternalRepresentation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.distribution.ToExternalRepr(tt.args); !reflect.DeepEqual(got, tt.want) { - t.Errorf("UniformDistribution.ToExternalRepr() = %v, want %v", got, tt.want) + t.Errorf("Distribution.ToExternalRepr() = %v, want %v", got, tt.want) } }) } @@ -231,7 +231,7 @@ func TestDistributionSingle(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.distribution.Single(); got != tt.want { - t.Errorf("UniformDistribution.Single() = %v, want %v", got, tt.want) + t.Errorf("Distribution.Single() = %v, want %v", got, tt.want) } }) } @@ -281,11 +281,23 @@ func TestDistributionContains(t *testing.T) { want: false, }, { - name: "int uniform distribution true", + name: "int uniform distribution true 1", distribution: &goptuna.IntUniformDistribution{Low: 0, High: 10}, args: 3, want: true, }, + { + name: "int uniform distribution true 2", + distribution: &goptuna.IntUniformDistribution{Low: 0, High: 10}, + args: -0.4999, + want: true, + }, + { + name: "int uniform distribution true 3", + distribution: &goptuna.IntUniformDistribution{Low: 0, High: 10}, + args: 10.4999, + want: true, + }, { name: "int uniform distribution lower", distribution: &goptuna.IntUniformDistribution{Low: 0, High: 10}, @@ -374,7 +386,7 @@ func TestDistributionContains(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.distribution.Contains(tt.args); got != tt.want { - t.Errorf("UniformDistribution.ToInternalRepr() = %v, want %v", got, tt.want) + t.Errorf("Distribution.ToInternalRepr() = %v, want %v", got, tt.want) } }) }