diff --git a/extensions/record.go b/extensions/record.go index 7c7904e..5c8bc81 100644 --- a/extensions/record.go +++ b/extensions/record.go @@ -133,19 +133,42 @@ func (t *SexpRecord) MarshalJSON() ([]byte, error) { } func (t *SexpRecord) Explain(env *glisp.Environment, sym string, args []glisp.Sexp) (glisp.Sexp, error) { - if len(args) > 1 { + switch len(args) { + case 0: + return t.GetField(sym) + case 1: + return t.GetFieldDefault(sym, args[0]), nil + default: return glisp.SexpNull, fmt.Errorf("record field accessor need not more than one argument but got %v", len(args)) } - if _, ok := t.class.fieldsMeta[sym]; !ok { - return glisp.SexpNull, fmt.Errorf("record<%s> not have a field named %s", t.TypeName(), string(sym)) - } - fv, _ := t.value.HashGet(glisp.SexpStr(sym)) - if len(args) == 1 && fv == glisp.SexpNull { - fv = args[0] +} + +func (t *SexpRecord) GetField(name string) (glisp.Sexp, error) { + if _, ok := t.class.fieldsMeta[name]; !ok { + return glisp.SexpNull, fmt.Errorf("record<%s> not have a field named %s", t.TypeName(), name) } + fv, _ := t.value.HashGet(glisp.SexpStr(name)) return fv, nil } +func (t *SexpRecord) GetFieldDefault(name string, defaultVal glisp.Sexp) glisp.Sexp { + ret, err := t.GetField(name) + if err != nil || ret == glisp.SexpNull { + return defaultVal + } + return ret +} + +func (t *SexpRecord) SetField(name string, val glisp.Sexp) error { + if f, ok := t.class.fieldsMeta[name]; ok { + if err := checkTypeMatched(f.Type, val); err != nil { + return err + } + return t.value.HashSet(glisp.SexpStr(f.Name), val) + } + return fmt.Errorf("field %s not found", name) +} + func IsRecord(args glisp.Sexp) bool { _, ok := args.(*SexpRecord) return ok @@ -241,14 +264,7 @@ func AssocRecordField(name string) glisp.UserFunction { return glisp.SexpNull, fmt.Errorf("second argument must be symbol but got %v", glisp.InspectType(args[1])) } record, field := args[0].(*SexpRecord), args[1].(glisp.SexpSymbol) - if f, ok := record.class.fieldsMeta[field.Name()]; ok { - if err := checkTypeMatched(f.Type, args[2]); err != nil { - return glisp.SexpNull, err - } - record.value.HashSet(glisp.SexpStr(f.Name), args[2]) - return record, nil - } - return glisp.SexpNull, fmt.Errorf("field %s not found", field.Name()) + return glisp.SexpNull, record.SetField(field.Name(), args[2]) } } @@ -272,11 +288,15 @@ func CheckIsRecordOf(name string) glisp.UserFunction { if !IsRecordClass(args[1]) { return glisp.SexpNull, fmt.Errorf("second argument must be record class but got %s", glisp.InspectType(args[1])) } - r, cls := args[0].(*SexpRecord), args[1].(SexpRecordClass) - return glisp.SexpBool(r.class.TypeName() == cls.TypeName()), nil + cls := args[1].(SexpRecordClass) + return glisp.SexpBool(IsRecordOf(args[0], cls.TypeName())), nil } } +func IsRecordOf(r glisp.Sexp, typ string) bool { + return IsRecord(r) && r.(*SexpRecord).class.TypeName() == typ +} + /* (defrecord MyType (name type) (name2 type2) ) */ func DefineRecord(name string) glisp.UserFunction { return func(env *glisp.Environment, args []glisp.Sexp) (glisp.Sexp, error) { @@ -397,3 +417,40 @@ func paddingRight(str string, max int) string { } return str } + +type RecordClassBuilder struct { + cls *sexpRecordClass +} + +func NewRecordClassBuilder(className string) *RecordClassBuilder { + return &RecordClassBuilder{cls: &sexpRecordClass{typeName: className, fieldsMeta: make(map[string]sexpRecordField)}} +} + +func (b *RecordClassBuilder) AddField(name string, typ string) *RecordClassBuilder { + if _, ok := b.cls.fieldsMeta[name]; ok { + return b + } + b.cls.fieldNames = append(b.cls.fieldNames, name) + b.cls.fieldsMeta[name] = sexpRecordField{Name: name, Type: typ} + return b +} + +func (b *RecordClassBuilder) Build(env *glisp.Environment) SexpRecordClass { + env.Bind(b.cls.typeName, b.cls) + env.AddMacro(b.cls.constructorName(), func(_e *glisp.Environment, _args []glisp.Sexp) (glisp.Sexp, error) { + lb := glisp.NewListBuilder() + lb.Add(b.cls.getConstructor()) + for i := range _args { + if i%2 == 0 { + lb.Add(glisp.MakeList([]glisp.Sexp{ + _e.MakeSymbol("quote"), + _args[i], + })) + } else { + lb.Add(_args[i]) + } + } + return lb.Get(), nil + }) + return b.cls +} diff --git a/tests/glisp_test.go b/tests/glisp_test.go index 1bc0f61..97c4f33 100644 --- a/tests/glisp_test.go +++ b/tests/glisp_test.go @@ -661,3 +661,14 @@ func TestListFuzzyMacroName(t *testing.T) { testMacro(`(fuzzy-123)`, `number:123`) testMacro(`(fuzzy-456)`, `number:456`) } + +func TestDefineClassInGolang(t *testing.T) { + vm := loadAllExtensions(glisp.New()) + extensions.NewRecordClassBuilder("test/Options"). + AddField("Name", "string"). + AddField("Age", "int"). + Build(vm) + script := `(def p (->test/Options Name "hello" Age (+ 1 2))) (assert (= "hello" (:Name p))) (assert (= 3 (:Age p)))` + _, err := vm.EvalString(script) + ExpectSuccess(t, err) +}