Skip to content

Commit

Permalink
Make separate field for class initializer in ast node
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams committed Jul 11, 2022
1 parent 8a5455c commit 0973b12
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 43 deletions.
6 changes: 2 additions & 4 deletions ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,11 @@ type SetExpr struct {
}

func (b SetExpr) StartLine() int {
// TODO implement me
panic("implement me")
return b.Object.StartLine()
}

func (b SetExpr) EndLine() int {
// TODO implement me
panic("implement me")
return b.Value.EndLine()
}

func (b SetExpr) Accept(visitor ExprVisitor) interface{} {
Expand Down
1 change: 1 addition & 0 deletions ast/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type Field struct {
type ClassStmt struct {
Name Token
Superclass *VariableExpr
Init *FunctionStmt
Methods []FunctionStmt
Fields []Field
LineStart int
Expand Down
45 changes: 28 additions & 17 deletions interpret/class.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@ import (
)

type class struct {
name string
methods map[string]function
superclass *class
fields []ast.Field
env *env.Environment
name string
initializer *function
methods map[string]function
superclass *class
fields []ast.Field
env *env.Environment
}

// arity returns the arity of the class's constructor
func (c class) arity() int {
if initializer, ok := c.findMethod("init"); ok {
return initializer.arity()
if c.initializer == nil {
return 0
}
return 0
return c.initializer.arity()
}

// call-s the class's constructor and returns the new instance
Expand All @@ -37,30 +38,39 @@ func (c class) call(interpreter *Interpreter, arguments []interface{}) interface
interpreter.environment = previous

// initialize
if initializer, ok := c.findMethod("init"); ok {
initializer.bind(in).call(interpreter, arguments)
if c.initializer != nil {
c.initializer.bind(in).call(interpreter, arguments)
}

return in
}

// Get returns value of the static method with the given name
func (c class) Get(in *Interpreter, name ast.Token) (interface{}, error) {
if method, ok := c.findMethod(name.Lexeme); ok {
return method, nil
method := c.findMethod(name.Lexeme)
if method == nil {
return nil, runtimeError{token: name, msg: fmt.Sprintf("Undefined property '%s'.", name.Lexeme)}
}
return nil, runtimeError{token: name, msg: fmt.Sprintf("Undefined property '%s'.", name.Lexeme)}
return method, nil
}

func (c class) findMethod(name string) (function, bool) {
// todo: should probably return a pointer
func (c class) findMethod(name string) *function {
if name == "init" {
if c.initializer == nil {
return nil
}
return c.initializer
}

method, ok := c.methods[name]
if ok {
return method, true
return &method
}
if c.superclass != nil {
return c.superclass.findMethod(name)
}
return function{}, false
return nil
}

func (c class) String() string {
Expand All @@ -84,7 +94,8 @@ func (i *instance) Get(in *Interpreter, name ast.Token) (interface{}, error) {
return val, nil
}

if method, ok := i.class.findMethod(name.Lexeme); ok {
method := i.class.findMethod(name.Lexeme)
if method != nil {
// if the method is a getter, call and return its value
if method.isGetter {
return method.bind(i).call(in, nil), nil
Expand Down
25 changes: 21 additions & 4 deletions interpret/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,35 @@ func (in *Interpreter) VisitClassStmt(stmt ast.ClassStmt) interface{} {
in.environment.Define("super", superclass)
}

var initializer *function
if stmt.Init != nil {
initializer = &function{
declaration: *stmt.Init,
closure: in.environment,
isInitializer: true,
isGetter: false,
}
}

methods := make(map[string]function, len(stmt.Methods))
for _, method := range stmt.Methods {
fn := function{
declaration: method,
closure: in.environment,
isInitializer: method.Name.Lexeme == "init",
isInitializer: false,
isGetter: method.Params == nil, // is this the best way to know it's a getter?
}
methods[method.Name.Lexeme] = fn
}

class := class{name: stmt.Name.Lexeme, methods: methods, superclass: superclass, fields: stmt.Fields, env: in.environment}
class := class{
name: stmt.Name.Lexeme,
methods: methods,
superclass: superclass,
fields: stmt.Fields,
env: in.environment,
initializer: initializer,
}

if superclass != nil {
in.environment = in.environment.Enclosing
Expand Down Expand Up @@ -399,8 +416,8 @@ func (in *Interpreter) VisitSuperExpr(expr ast.SuperExpr) interface{} {
distance, _ := in.GetLocalDistance(expr)
superclass := in.environment.GetAt(distance, "super").(*class)
object := in.environment.GetAt(distance-1, "this").(*instance)
method, ok := superclass.findMethod(expr.Method.Lexeme)
if !ok {
method := superclass.findMethod(expr.Method.Lexeme)
if method == nil {
in.error(expr.Method, fmt.Sprintf("Undefined property '%s'.", expr.Method.Lexeme))
}
return method.bind(object)
Expand Down
8 changes: 7 additions & 1 deletion parse/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ func (p *Parser) classDeclaration() ast.Stmt {

p.consume(ast.TokenLeftBrace, "Expect '{' before class body.")

var initMethod *ast.FunctionStmt
methods := make([]ast.FunctionStmt, 0)
fields := make([]ast.Field, 0)
for !p.check(ast.TokenRightBrace) && !p.isAtEnd() {
Expand All @@ -139,7 +140,11 @@ func (p *Parser) classDeclaration() ast.Stmt {
fields = append(fields, field)
} else {
method := p.function("method")
methods = append(methods, method)
if method.Name.Lexeme == "init" {
initMethod = &method
} else {
methods = append(methods, method)
}
}
}

Expand All @@ -148,6 +153,7 @@ func (p *Parser) classDeclaration() ast.Stmt {
return ast.ClassStmt{
Name: name,
Fields: fields,
Init: initMethod,
Methods: methods,
Superclass: superclass,
LineStart: lineStart,
Expand Down
11 changes: 5 additions & 6 deletions resolve/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,12 @@ func (r *Resolver) VisitClassStmt(stmt ast.ClassStmt) interface{} {

r.scopes.peek().set("this")

for _, method := range stmt.Methods {
declaration := functionTypeMethod
if method.Name.Lexeme == "init" {
declaration = functionTypeInitializer
}
if stmt.Init != nil {
r.resolveFunction(*stmt.Init, functionTypeInitializer)
}

r.resolveFunction(method, declaration)
for _, method := range stmt.Methods {
r.resolveFunction(method, functionTypeMethod)
}

r.endScope()
Expand Down
1 change: 1 addition & 0 deletions typechecker/testdata/class-assign-field.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
type error on line 2: only instances have fields
3 changes: 3 additions & 0 deletions typechecker/testdata/class-assign-field.input
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class Class {}

Class.hello = "world";
36 changes: 26 additions & 10 deletions typechecker/typechecker.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ type TypeChecker struct {

func (c *TypeChecker) VisitTypeDeclStmt(stmt ast.TypeDeclStmt) interface{} {
if c.env.Has(stmt.Name.Lexeme) {
c.error(fmt.Sprintf("Type with name %s is already defined.", stmt.Name.Lexeme))
c.errorNoLine(fmt.Sprintf("Type with name %s is already defined.", stmt.Name.Lexeme))
}

baseType := c.typeFromParsed(stmt.Base)
if baseType == nil {
c.error(fmt.Sprintf("Type %v is not defined", stmt.Base))
c.errorNoLine(fmt.Sprintf("Type %v is not defined", stmt.Base))
}

alias := newAliasType(stmt.Name.Lexeme, baseType)
Expand Down Expand Up @@ -74,7 +74,6 @@ func (c *TypeChecker) VisitClassStmt(stmt ast.ClassStmt) interface{} {
classType := newClassType(stmt.Name.Lexeme, superclassType)

c.types[stmt.Name.Lexeme] = classType
c.env.Define(stmt.Name.Lexeme, classType)

previous := c.env
previousClass := c.enclosingClass
Expand All @@ -99,8 +98,21 @@ func (c *TypeChecker) VisitClassStmt(stmt ast.ClassStmt) interface{} {
classType.properties.Define(field.Name.Lexeme, fieldType)
}

// Get signature of the init method. We do this before checking the rest of the
// methods, just in case the methods also reference the class in their body
var initMethodParams []loxType
if stmt.Init != nil {
initMethodParams = make([]loxType, len(stmt.Init.Params))
for i, param := range stmt.Init.Params {
initMethodParams[i] = c.typeFromParsed(param.Type)
}
}

previous.Define(stmt.Name.Lexeme, newFunctionType("", initMethodParams, classType))

for _, method := range stmt.Methods {
c.checkMethod(method)
// TODO: shouldn't this be saved somewhere?
}

return nil
Expand Down Expand Up @@ -352,12 +364,12 @@ func (c *TypeChecker) VisitGetExpr(expr ast.GetExpr) interface{} {

objectClassType, ok := object.(classType)
if !ok {
c.error("object must be an instance of a class")
c.errorNoLine("object must be an instance of a class")
}

field, err := objectClassType.getField(expr.Name.Lexeme)
if err != nil {
c.error(err.Error())
c.errorNoLine(err.Error())
}

return field
Expand Down Expand Up @@ -389,12 +401,12 @@ func (c *TypeChecker) VisitSetExpr(expr ast.SetExpr) interface{} {

objectAsClassType, ok := object.(classType)
if !ok {
panic(typeError{message: "cannot set properties on a non-class type"})
c.error(expr.StartLine(), "only instances have fields")
}

property, err := objectAsClassType.properties.Get(expr.Name.Lexeme)
if err != nil {
c.error("property does not exist on class")
c.errorNoLine("property does not exist on class")
}

valueType := c.check(expr.Value)
Expand Down Expand Up @@ -477,7 +489,7 @@ func (c *TypeChecker) typeFromParsed(parsedType ast.Type) loxType {

func (c *TypeChecker) expect(actual loxType, expected loxType, value ast.Expr, expr ast.Expr) loxType {
if !actual.equals(expected) {
c.error(fmt.Sprintf("error on line %d: expected '%s' type, but got '%s'", value.StartLine()+1, expected.String(), actual.String()))
c.errorNoLine(fmt.Sprintf("error on line %d: expected '%s' type, but got '%s'", value.StartLine()+1, expected.String(), actual.String()))
}
return actual
}
Expand All @@ -497,13 +509,17 @@ func (c *TypeChecker) expectOperatorType(inputType loxType, allowedTypes []loxTy
return
}
}
c.error(fmt.Sprintf("unexpected type: %v in %v, allowed: %v", inputType, expr, allowedTypes))
c.errorNoLine(fmt.Sprintf("unexpected type: %v in %v, allowed: %v", inputType, expr, allowedTypes))
}

func (c *TypeChecker) error(message string) {
func (c *TypeChecker) errorNoLine(message string) {
panic(typeError{message: message})
}

func (c *TypeChecker) error(line int, message string) {
panic(typeError{line: line, message: message})
}

func (c *TypeChecker) isBooleanBinary(operator ast.Token) bool {
switch operator.TokenType {
case ast.TokenBangEqual, ast.TokenEqualEqual,
Expand Down
10 changes: 9 additions & 1 deletion typechecker/types.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
package typechecker

import "github.com/chidiwilliams/glox/env"
import (
"fmt"

"github.com/chidiwilliams/glox/env"
)

type typeError struct {
message string
line int
}

func (e typeError) Error() string {
if e.line > 0 {
return fmt.Sprintf("type error on line %d: %s", e.line, e.message)
}
return e.message
}

Expand Down

0 comments on commit 0973b12

Please sign in to comment.