Skip to content

Commit

Permalink
Add type checking for class statements
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams committed Jul 10, 2022
1 parent 31565bf commit 1a755b7
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 25 deletions.
15 changes: 10 additions & 5 deletions interpret/env.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package interpret

import (
"fmt"
"errors"

"github.com/chidiwilliams/glox/ast"
)

// TODO: move this to an env package

// ErrUndefined is returned when retrieving or assigning to an undefined variable
var ErrUndefined = errors.New("undefined variable")

type Environment struct {
Enclosing *Environment
values map[string]interface{}
Expand All @@ -29,14 +34,14 @@ func (e *Environment) Has(name string) bool {
return false
}

func (e *Environment) Get(name ast.Token) (interface{}, error) {
if val, ok := e.values[name.Lexeme]; ok {
func (e *Environment) Get(name string) (interface{}, error) {
if val, ok := e.values[name]; ok {
return val, nil
}
if e.Enclosing != nil {
return e.Enclosing.Get(name)
}
return nil, runtimeError{name, fmt.Sprintf("Undefined variable '%s'", name.Lexeme)}
return nil, ErrUndefined
}

func (e *Environment) assign(name ast.Token, value interface{}) error {
Expand All @@ -47,7 +52,7 @@ func (e *Environment) assign(name ast.Token, value interface{}) error {
if e.Enclosing != nil {
return e.Enclosing.assign(name, value)
}
return runtimeError{name, fmt.Sprintf("Undefined variable '%s'", name.Lexeme)}
return ErrUndefined
}

func (e *Environment) GetAt(distance int, name string) interface{} {
Expand Down
2 changes: 1 addition & 1 deletion interpret/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ func (in *Interpreter) lookupVariable(name ast.Token, expr ast.Expr) (interface{
if distance, ok := in.GetLocalDistance(expr); ok {
return in.environment.GetAt(distance, name.Lexeme), nil
}
return in.globals.Get(name)
return in.globals.Get(name.Lexeme)
}

func (in *Interpreter) VisitBinaryExpr(expr ast.BinaryExpr) interface{} {
Expand Down
2 changes: 1 addition & 1 deletion parse/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ func (p *Parser) error(token ast.Token, message string) {
where = " at '" + token.Lexeme + "'"
}

err := parseError{msg: fmt.Sprintf("[line %d] Error%s: %s\n", token.Line, where, message)}
err := parseError{msg: fmt.Sprintf("[line %d] Error%s: %s\n", token.Line+1, where, message)}
_, _ = p.stdErr.Write([]byte(err.Error()))
panic(err)
}
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions typechecker/testdata/class.input
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class Greeter {
greet(): string {
return "hello";
}
}

var greeter: Greeter = Greeter();
greeter;
129 changes: 111 additions & 18 deletions typechecker/typechecker.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ var (
TypeNil = newPrimitiveType("nil")
)

func newFunctionType(name string, paramTypes []Type, returnType Type) Type {
func newFunctionType(name string, paramTypes []Type, returnType Type) functionType {
return functionType{name: name, paramTypes: paramTypes, returnType: returnType}
}

Expand Down Expand Up @@ -103,7 +103,7 @@ func (t functionType) String() string {
return name
}

func newAliasType(name string, parent Type) aliasType {
func newAliasType(name string, parent Type) Type {
return aliasType{name: name, parent: parent}
}

Expand All @@ -123,6 +123,62 @@ func (t aliasType) Equals(t2 Type) bool {
return t.parent.Equals(t2)
}

type classType struct {
name string
superClass Type
env *interpret.Environment
}

func (c classType) String() string {
return c.name
}

func (c classType) Equals(t Type) bool {
if c == t {
return true
}

if alias, ok := t.(aliasType); ok {
return alias.Equals(c)
}

if c.superClass != nil {
return c.superClass.Equals(t)
}

return false
}

func (c classType) getField(name string) (Type, error) {
fieldType, err := c.env.Get(name)
if err != nil {
return nil, err
}
return fieldType.(Type), nil
}

func (c classType) getConstructor() (functionType, error) {
constructor, err := c.getField("init")
if err == interpret.ErrUndefined {
return newFunctionType("", []Type{}, c), nil
} else if err != nil {
return functionType{}, err
}
return constructor.(functionType), nil
}

func newClassType(name string, superClass Type) classType {
var enclosingEnv *interpret.Environment
if superClassAsClassType, ok := superClass.(classType); ok {
enclosingEnv = superClassAsClassType.env
}
return classType{
name: name,
superClass: superClass,
env: &interpret.Environment{Enclosing: enclosingEnv},
}
}

func NewTypeChecker(interpreter *interpret.Interpreter) *TypeChecker {
globals := interpret.Environment{}
globals.Define("clock", newFunctionType("", []Type{}, TypeNumber))
Expand All @@ -143,6 +199,7 @@ type TypeChecker struct {
declaredFnReturnType Type
inferredFnReturnType Type
types map[string]Type
enclosingClass classType
}

func (c *TypeChecker) VisitTypeDeclStmt(stmt ast.TypeDeclStmt) interface{} {
Expand Down Expand Up @@ -179,8 +236,34 @@ func (c *TypeChecker) checkBlock(stmts []ast.Stmt, env interpret.Environment) {
}

func (c *TypeChecker) VisitClassStmt(stmt ast.ClassStmt) interface{} {
// TODO implement me
panic("implement me")
var superclassType Type
if stmt.Superclass != nil {
superclassType = c.check(stmt.Superclass)
}

classType := newClassType(stmt.Name.Lexeme, superclassType)

c.types[stmt.Name.Lexeme] = classType
c.env.Define(stmt.Name.Lexeme, classType)

previous := c.env
previousClass := c.enclosingClass
defer func() {
c.env = previous
c.enclosingClass = previousClass
}()

c.env = classType.env
c.enclosingClass = classType

// What should the actual type of classType be? Shouldn't it be a function
// callable by its initializer's args to return the class type?

for _, method := range stmt.Methods {
c.checkStmt(method)
}

return nil
}

func (c *TypeChecker) VisitExpressionStmt(stmt ast.ExpressionStmt) interface{} {
Expand Down Expand Up @@ -312,23 +395,31 @@ func (c *TypeChecker) VisitBinaryExpr(expr ast.BinaryExpr) interface{} {
}

func (c *TypeChecker) VisitCallExpr(expr ast.CallExpr) interface{} {
calleeType := c.check(expr.Callee)

fnType, ok := calleeType.(functionType)
if !ok {
panic(TypeError{message: "Cannot call a value that's not a function"})
switch calleeType := c.check(expr.Callee).(type) {
case functionType:
return c.checkFunctionCall(calleeType, expr)
case classType:
constructor, err := calleeType.getConstructor()
if err != nil {
panic(TypeError{message: err.Error()})
}
return c.checkFunctionCall(constructor, expr)
default:
panic(TypeError{message: "Cannot call a value that is not a function or method"})
}
}

if len(fnType.paramTypes) != len(expr.Arguments) {
panic(TypeError{message: fmt.Sprintf("function of type %s expects %d arguments, got %d", fnType, len(fnType.paramTypes), len(expr.Arguments))})
func (c *TypeChecker) checkFunctionCall(calleeType functionType, expr ast.CallExpr) Type {
if len(calleeType.paramTypes) != len(expr.Arguments) {
panic(TypeError{message: fmt.Sprintf("function of type %s expects %d arguments, got %d", calleeType, len(calleeType.paramTypes), len(expr.Arguments))})
}

for i, arg := range expr.Arguments {
argType := c.check(arg)
c.expect(argType, fnType.paramTypes[i], arg, expr)
c.expect(argType, calleeType.paramTypes[i], arg, expr)
}

return fnType.returnType
return calleeType.returnType
}

func (c *TypeChecker) VisitFunctionExpr(expr ast.FunctionExpr) interface{} {
Expand Down Expand Up @@ -402,8 +493,11 @@ func (c *TypeChecker) VisitLogicalExpr(expr ast.LogicalExpr) interface{} {
}

func (c *TypeChecker) VisitSetExpr(expr ast.SetExpr) interface{} {
// TODO implement me
panic("implement me")
// TODO: check that the object is an instance of a class
// and that object[name] has the same type as value
c.check(expr.Object)

return c.check(expr.Value)
}

func (c *TypeChecker) VisitSuperExpr(expr ast.SuperExpr) interface{} {
Expand All @@ -412,8 +506,7 @@ func (c *TypeChecker) VisitSuperExpr(expr ast.SuperExpr) interface{} {
}

func (c *TypeChecker) VisitThisExpr(expr ast.ThisExpr) interface{} {
// TODO implement me
panic("implement me")
return c.enclosingClass
}

func (c *TypeChecker) VisitTernaryExpr(expr ast.TernaryExpr) interface{} {
Expand Down Expand Up @@ -445,7 +538,7 @@ func (c *TypeChecker) lookupType(name ast.Token, expr ast.Expr) (Type, error) {
if distance, ok := c.interpreter.GetLocalDistance(expr); ok {
return c.env.GetAt(distance, name.Lexeme).(Type), nil
}
nameType, err := c.globals.Get(name)
nameType, err := c.globals.Get(name.Lexeme)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 1a755b7

Please sign in to comment.