Skip to content

Commit 04b16cd

Browse files
committed
Generic Memoization
1 parent 8086b04 commit 04b16cd

File tree

7 files changed

+213
-1
lines changed

7 files changed

+213
-1
lines changed

.vscode/launch.json

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"name": "Run Tests",
9+
"type": "go",
10+
"request": "launch",
11+
"mode": "test",
12+
"program": "${workspaceFolder}"
13+
}
14+
]
15+
}

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
# memoize2
1+
# Memoize
2+
3+
A generic memoization library.
4+
5+
Needs testing.

errors.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package memoize
2+
3+
import "github.com/pkg/errors"
4+
5+
var (
6+
ErrNotAFunc = errors.New("not a function")
7+
ErrMissingArgs = errors.New("target function must accept at least 1 argument")
8+
ErrMissingReturns = errors.New("target function must return at least 1 value")
9+
)

go.mod

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module github.com/coreyog/memoize
2+
3+
go 1.18
4+
5+
require (
6+
github.com/pkg/errors v0.9.1
7+
github.com/stretchr/testify v1.8.1
8+
)
9+
10+
require (
11+
github.com/davecgh/go-spew v1.1.1 // indirect
12+
github.com/pmezard/go-difflib v1.0.0 // indirect
13+
gopkg.in/yaml.v3 v3.0.1 // indirect
14+
)

go.sum

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
2+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
3+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4+
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
5+
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
6+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
7+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
8+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
9+
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
10+
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
11+
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
12+
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
13+
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
14+
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
15+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
16+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
17+
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
18+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
19+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

memoize.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package memoize
2+
3+
import "reflect"
4+
5+
func Memo[T any](fn T) (m T, err error) {
6+
fnt := reflect.TypeOf(fn)
7+
8+
if fnt.Kind() != reflect.Func {
9+
return fn, ErrNotAFunc
10+
}
11+
12+
if fnt.NumIn() == 0 {
13+
return fn, ErrMissingArgs
14+
}
15+
16+
if fnt.NumOut() == 0 {
17+
return fn, ErrMissingReturns
18+
}
19+
20+
fnv := reflect.ValueOf(fn)
21+
22+
cacheRoot := map[interface{}]interface{}{}
23+
24+
ret := reflect.MakeFunc(fnt, func(args []reflect.Value) (results []reflect.Value) {
25+
cResults := fillAndCheck(cacheRoot, args)
26+
if cResults == nil {
27+
results = fnv.Call(args)
28+
fillAndSet(cacheRoot, args, results)
29+
} else {
30+
results = cResults
31+
}
32+
33+
return results
34+
})
35+
36+
return ret.Interface().(T), nil
37+
}
38+
39+
func fillAndCheck(cacheRoot map[interface{}]interface{}, args []reflect.Value) (results []reflect.Value) {
40+
var m, next interface{}
41+
var ok bool
42+
m = cacheRoot
43+
44+
for _, arg := range args {
45+
next, ok = m.(map[interface{}]interface{})[arg.Interface()]
46+
if !ok {
47+
next = map[interface{}]interface{}{}
48+
m.(map[interface{}]interface{})[arg.Interface()] = next
49+
}
50+
51+
m = next
52+
}
53+
54+
results, ok = m.([]reflect.Value)
55+
56+
if !ok || results == nil {
57+
return nil
58+
}
59+
60+
return results
61+
}
62+
63+
func fillAndSet(cacheRoot map[interface{}]interface{}, args []reflect.Value, results []reflect.Value) {
64+
var m interface{}
65+
var prev map[interface{}]interface{}
66+
67+
m = cacheRoot
68+
69+
for _, arg := range args {
70+
prev = m.(map[interface{}]interface{})
71+
m = prev[arg.Interface()]
72+
}
73+
74+
prev[args[len(args)-1].Interface()] = results
75+
}

memoize_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package memoize
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestBadFuncErrors(t *testing.T) {
10+
_, err := Memo("not a func")
11+
assert.Equal(t, err, ErrNotAFunc)
12+
13+
noArgs := func() bool { return true }
14+
_, err = Memo(noArgs)
15+
assert.Equal(t, err, ErrMissingArgs)
16+
17+
noReturns := func(x int) {}
18+
_, err = Memo(noReturns)
19+
assert.Equal(t, err, ErrMissingReturns)
20+
}
21+
22+
func TestMultiCall(t *testing.T) {
23+
called := 0
24+
square := func(x int) int {
25+
called++
26+
return x * x
27+
}
28+
29+
m, err := Memo(square)
30+
assert.NoError(t, err)
31+
32+
for i := 0; i < 100; i++ {
33+
mResult := m(i)
34+
m2Result := m(i)
35+
squareResult := square(i)
36+
37+
assert.Equal(t, mResult, m2Result)
38+
assert.Equal(t, mResult, squareResult)
39+
}
40+
41+
assert.Equal(t, called, 200)
42+
}
43+
44+
func TestSimple(t *testing.T) {
45+
called := 0
46+
work := func(x int) int {
47+
called++
48+
return x
49+
}
50+
51+
m, err := Memo(work)
52+
assert.NoError(t, err)
53+
54+
for i := 0; i < 1000; i++ {
55+
m(0)
56+
}
57+
58+
assert.Equal(t, 1, called)
59+
}
60+
61+
func TestManyArgsManyRets(t *testing.T) {
62+
called := 0
63+
work := func(x int, y string, z float64) (int, string, float64) {
64+
called++
65+
return x, y, z
66+
}
67+
68+
m, err := Memo(work)
69+
assert.NoError(t, err)
70+
71+
for i := 0; i < 1000; i++ {
72+
m(0, "x", 3.14)
73+
}
74+
75+
assert.Equal(t, 1, called)
76+
}

0 commit comments

Comments
 (0)