From 7a8deb4f9bf45fa814b12a14dcd32dfaf9fa969b Mon Sep 17 00:00:00 2001 From: c-bata Date: Sun, 12 Apr 2020 16:43:02 +0900 Subject: [PATCH] Fix IntUniformDistributions.Contains() --- distribution.go | 4 ++-- distribution_test.go | 20 ++++++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/distribution.go b/distribution.go index 2e0aa439..d3be716a 100644 --- a/distribution.go +++ b/distribution.go @@ -102,11 +102,11 @@ func (d *IntUniformDistribution) Single() bool { // Contains to check a parameter value is contained in the range of this distribution. func (d *IntUniformDistribution) Contains(ir float64) bool { - value := int(ir) + value := d.ToExternalRepr(ir).(int) if d.Single() { return value == d.Low } - return d.Low <= value && value < d.High + return d.Low <= value && value <= d.High } // StepIntUniformDistributionName is the identifier name of IntUniformDistribution diff --git a/distribution_test.go b/distribution_test.go index 5da9c101..de6acd24 100644 --- a/distribution_test.go +++ b/distribution_test.go @@ -150,7 +150,7 @@ func TestDistributionToExternalRepresentation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.distribution.ToExternalRepr(tt.args); !reflect.DeepEqual(got, tt.want) { - t.Errorf("UniformDistribution.ToExternalRepr() = %v, want %v", got, tt.want) + t.Errorf("Distribution.ToExternalRepr() = %v, want %v", got, tt.want) } }) } @@ -231,7 +231,7 @@ func TestDistributionSingle(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.distribution.Single(); got != tt.want { - t.Errorf("UniformDistribution.Single() = %v, want %v", got, tt.want) + t.Errorf("Distribution.Single() = %v, want %v", got, tt.want) } }) } @@ -281,11 +281,23 @@ func TestDistributionContains(t *testing.T) { want: false, }, { - name: "int uniform distribution true", + name: "int uniform distribution true 1", distribution: &goptuna.IntUniformDistribution{Low: 0, High: 10}, args: 3, want: true, }, + { + name: "int uniform distribution true 2", + distribution: &goptuna.IntUniformDistribution{Low: 0, High: 10}, + args: -0.4999, + want: true, + }, + { + name: "int uniform distribution true 3", + distribution: &goptuna.IntUniformDistribution{Low: 0, High: 10}, + args: 10.4999, + want: true, + }, { name: "int uniform distribution lower", distribution: &goptuna.IntUniformDistribution{Low: 0, High: 10}, @@ -374,7 +386,7 @@ func TestDistributionContains(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.distribution.Contains(tt.args); got != tt.want { - t.Errorf("UniformDistribution.ToInternalRepr() = %v, want %v", got, tt.want) + t.Errorf("Distribution.ToInternalRepr() = %v, want %v", got, tt.want) } }) }