From 1a755b73ebe895d24a17a5ece8fdd6355012a8ca Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Sun, 10 Jul 2022 22:33:01 +0100 Subject: [PATCH] Add type checking for class statements --- interpret/env.go | 15 ++-- interpret/interpreter.go | 2 +- parse/parser.go | 2 +- typechecker/testdata/class.golden | 0 typechecker/testdata/class.input | 8 ++ typechecker/typechecker.go | 129 +++++++++++++++++++++++++----- 6 files changed, 131 insertions(+), 25 deletions(-) create mode 100644 typechecker/testdata/class.golden create mode 100644 typechecker/testdata/class.input diff --git a/interpret/env.go b/interpret/env.go index 294b41d..b6b7c83 100644 --- a/interpret/env.go +++ b/interpret/env.go @@ -1,11 +1,16 @@ package interpret import ( - "fmt" + "errors" "github.com/chidiwilliams/glox/ast" ) +// TODO: move this to an env package + +// ErrUndefined is returned when retrieving or assigning to an undefined variable +var ErrUndefined = errors.New("undefined variable") + type Environment struct { Enclosing *Environment values map[string]interface{} @@ -29,14 +34,14 @@ func (e *Environment) Has(name string) bool { return false } -func (e *Environment) Get(name ast.Token) (interface{}, error) { - if val, ok := e.values[name.Lexeme]; ok { +func (e *Environment) Get(name string) (interface{}, error) { + if val, ok := e.values[name]; ok { return val, nil } if e.Enclosing != nil { return e.Enclosing.Get(name) } - return nil, runtimeError{name, fmt.Sprintf("Undefined variable '%s'", name.Lexeme)} + return nil, ErrUndefined } func (e *Environment) assign(name ast.Token, value interface{}) error { @@ -47,7 +52,7 @@ func (e *Environment) assign(name ast.Token, value interface{}) error { if e.Enclosing != nil { return e.Enclosing.assign(name, value) } - return runtimeError{name, fmt.Sprintf("Undefined variable '%s'", name.Lexeme)} + return ErrUndefined } func (e *Environment) GetAt(distance int, name string) interface{} { diff --git a/interpret/interpreter.go b/interpret/interpreter.go index 16c7b80..4acf210 100644 --- a/interpret/interpreter.go +++ b/interpret/interpreter.go @@ -309,7 +309,7 @@ func (in *Interpreter) lookupVariable(name ast.Token, expr ast.Expr) (interface{ if distance, ok := in.GetLocalDistance(expr); ok { return in.environment.GetAt(distance, name.Lexeme), nil } - return in.globals.Get(name) + return in.globals.Get(name.Lexeme) } func (in *Interpreter) VisitBinaryExpr(expr ast.BinaryExpr) interface{} { diff --git a/parse/parser.go b/parse/parser.go index 12a0394..24dc7d2 100644 --- a/parse/parser.go +++ b/parse/parser.go @@ -649,7 +649,7 @@ func (p *Parser) error(token ast.Token, message string) { where = " at '" + token.Lexeme + "'" } - err := parseError{msg: fmt.Sprintf("[line %d] Error%s: %s\n", token.Line, where, message)} + err := parseError{msg: fmt.Sprintf("[line %d] Error%s: %s\n", token.Line+1, where, message)} _, _ = p.stdErr.Write([]byte(err.Error())) panic(err) } diff --git a/typechecker/testdata/class.golden b/typechecker/testdata/class.golden new file mode 100644 index 0000000..e69de29 diff --git a/typechecker/testdata/class.input b/typechecker/testdata/class.input new file mode 100644 index 0000000..401a9d4 --- /dev/null +++ b/typechecker/testdata/class.input @@ -0,0 +1,8 @@ +class Greeter { + greet(): string { + return "hello"; + } +} + +var greeter: Greeter = Greeter(); +greeter; diff --git a/typechecker/typechecker.go b/typechecker/typechecker.go index 9079cfa..1d2f47a 100644 --- a/typechecker/typechecker.go +++ b/typechecker/typechecker.go @@ -46,7 +46,7 @@ var ( TypeNil = newPrimitiveType("nil") ) -func newFunctionType(name string, paramTypes []Type, returnType Type) Type { +func newFunctionType(name string, paramTypes []Type, returnType Type) functionType { return functionType{name: name, paramTypes: paramTypes, returnType: returnType} } @@ -103,7 +103,7 @@ func (t functionType) String() string { return name } -func newAliasType(name string, parent Type) aliasType { +func newAliasType(name string, parent Type) Type { return aliasType{name: name, parent: parent} } @@ -123,6 +123,62 @@ func (t aliasType) Equals(t2 Type) bool { return t.parent.Equals(t2) } +type classType struct { + name string + superClass Type + env *interpret.Environment +} + +func (c classType) String() string { + return c.name +} + +func (c classType) Equals(t Type) bool { + if c == t { + return true + } + + if alias, ok := t.(aliasType); ok { + return alias.Equals(c) + } + + if c.superClass != nil { + return c.superClass.Equals(t) + } + + return false +} + +func (c classType) getField(name string) (Type, error) { + fieldType, err := c.env.Get(name) + if err != nil { + return nil, err + } + return fieldType.(Type), nil +} + +func (c classType) getConstructor() (functionType, error) { + constructor, err := c.getField("init") + if err == interpret.ErrUndefined { + return newFunctionType("", []Type{}, c), nil + } else if err != nil { + return functionType{}, err + } + return constructor.(functionType), nil +} + +func newClassType(name string, superClass Type) classType { + var enclosingEnv *interpret.Environment + if superClassAsClassType, ok := superClass.(classType); ok { + enclosingEnv = superClassAsClassType.env + } + return classType{ + name: name, + superClass: superClass, + env: &interpret.Environment{Enclosing: enclosingEnv}, + } +} + func NewTypeChecker(interpreter *interpret.Interpreter) *TypeChecker { globals := interpret.Environment{} globals.Define("clock", newFunctionType("", []Type{}, TypeNumber)) @@ -143,6 +199,7 @@ type TypeChecker struct { declaredFnReturnType Type inferredFnReturnType Type types map[string]Type + enclosingClass classType } func (c *TypeChecker) VisitTypeDeclStmt(stmt ast.TypeDeclStmt) interface{} { @@ -179,8 +236,34 @@ func (c *TypeChecker) checkBlock(stmts []ast.Stmt, env interpret.Environment) { } func (c *TypeChecker) VisitClassStmt(stmt ast.ClassStmt) interface{} { - // TODO implement me - panic("implement me") + var superclassType Type + if stmt.Superclass != nil { + superclassType = c.check(stmt.Superclass) + } + + classType := newClassType(stmt.Name.Lexeme, superclassType) + + c.types[stmt.Name.Lexeme] = classType + c.env.Define(stmt.Name.Lexeme, classType) + + previous := c.env + previousClass := c.enclosingClass + defer func() { + c.env = previous + c.enclosingClass = previousClass + }() + + c.env = classType.env + c.enclosingClass = classType + + // What should the actual type of classType be? Shouldn't it be a function + // callable by its initializer's args to return the class type? + + for _, method := range stmt.Methods { + c.checkStmt(method) + } + + return nil } func (c *TypeChecker) VisitExpressionStmt(stmt ast.ExpressionStmt) interface{} { @@ -312,23 +395,31 @@ func (c *TypeChecker) VisitBinaryExpr(expr ast.BinaryExpr) interface{} { } func (c *TypeChecker) VisitCallExpr(expr ast.CallExpr) interface{} { - calleeType := c.check(expr.Callee) - - fnType, ok := calleeType.(functionType) - if !ok { - panic(TypeError{message: "Cannot call a value that's not a function"}) + switch calleeType := c.check(expr.Callee).(type) { + case functionType: + return c.checkFunctionCall(calleeType, expr) + case classType: + constructor, err := calleeType.getConstructor() + if err != nil { + panic(TypeError{message: err.Error()}) + } + return c.checkFunctionCall(constructor, expr) + default: + panic(TypeError{message: "Cannot call a value that is not a function or method"}) } +} - if len(fnType.paramTypes) != len(expr.Arguments) { - panic(TypeError{message: fmt.Sprintf("function of type %s expects %d arguments, got %d", fnType, len(fnType.paramTypes), len(expr.Arguments))}) +func (c *TypeChecker) checkFunctionCall(calleeType functionType, expr ast.CallExpr) Type { + if len(calleeType.paramTypes) != len(expr.Arguments) { + panic(TypeError{message: fmt.Sprintf("function of type %s expects %d arguments, got %d", calleeType, len(calleeType.paramTypes), len(expr.Arguments))}) } for i, arg := range expr.Arguments { argType := c.check(arg) - c.expect(argType, fnType.paramTypes[i], arg, expr) + c.expect(argType, calleeType.paramTypes[i], arg, expr) } - return fnType.returnType + return calleeType.returnType } func (c *TypeChecker) VisitFunctionExpr(expr ast.FunctionExpr) interface{} { @@ -402,8 +493,11 @@ func (c *TypeChecker) VisitLogicalExpr(expr ast.LogicalExpr) interface{} { } func (c *TypeChecker) VisitSetExpr(expr ast.SetExpr) interface{} { - // TODO implement me - panic("implement me") + // TODO: check that the object is an instance of a class + // and that object[name] has the same type as value + c.check(expr.Object) + + return c.check(expr.Value) } func (c *TypeChecker) VisitSuperExpr(expr ast.SuperExpr) interface{} { @@ -412,8 +506,7 @@ func (c *TypeChecker) VisitSuperExpr(expr ast.SuperExpr) interface{} { } func (c *TypeChecker) VisitThisExpr(expr ast.ThisExpr) interface{} { - // TODO implement me - panic("implement me") + return c.enclosingClass } func (c *TypeChecker) VisitTernaryExpr(expr ast.TernaryExpr) interface{} { @@ -445,7 +538,7 @@ func (c *TypeChecker) lookupType(name ast.Token, expr ast.Expr) (Type, error) { if distance, ok := c.interpreter.GetLocalDistance(expr); ok { return c.env.GetAt(distance, name.Lexeme).(Type), nil } - nameType, err := c.globals.Get(name) + nameType, err := c.globals.Get(name.Lexeme) if err != nil { return nil, err }