diff --git a/packages/delisp-core/__tests__/infer.ts b/packages/delisp-core/__tests__/infer.ts index afe5b08a..1f254454 100644 --- a/packages/delisp-core/__tests__/infer.ts +++ b/packages/delisp-core/__tests__/infer.ts @@ -225,6 +225,24 @@ describe("Type inference", () => { }); }); + describe("Primitive functions", () => { + it("map with multiple values", () => { + expect(typeOf("(map (lambda (x) (values x x)) [1 2 3 4])")).toEqual( + "[number]" + ); + }); + it("filter with multiple values", () => { + expect( + typeOf("(filter (lambda (x) (values true 5)) [-2 -1 0 1 2])") + ).toEqual("[number]"); + }); + it("fold with multiple values", () => { + expect( + typeOf(`(fold (lambda (a b) (values 50 (+ a b))) [1 2 3 4] 0)`) + ).toEqual("number"); + }); + }); + describe("Type annotations", () => { it("user-specified variables can specialize inferred types", () => { expect(typeOf("(the [number] [])")).toBe("[number]"); diff --git a/packages/delisp-core/src/compiler/inline-primitives.ts b/packages/delisp-core/src/compiler/inline-primitives.ts index 1c733281..f6dacbae 100644 --- a/packages/delisp-core/src/compiler/inline-primitives.ts +++ b/packages/delisp-core/src/compiler/inline-primitives.ts @@ -174,13 +174,17 @@ defineInlinePrimitive("*", "(-> number number _ number)", args => { }; }); -defineInlinePrimitive("map", "(-> (-> a e b) [a] e [b])", ([fn, vec]) => { - return methodCall(vec, "map", [primitiveCall("bindPrimaryValue", fn)]); -}); +defineInlinePrimitive( + "map", + "(-> (-> a e (values b | _)) [a] e [b])", + ([fn, vec]) => { + return methodCall(vec, "map", [primitiveCall("bindPrimaryValue", fn)]); + } +); defineInlinePrimitive( "filter", - "(-> (-> a _ boolean) [a] _ [a])", + "(-> (-> a _ (values boolean | _)) [a] _ [a])", ([predicate, vec]) => { return methodCall(vec, "filter", [ primitiveCall("bindPrimaryValue", predicate) @@ -190,7 +194,7 @@ defineInlinePrimitive( defineInlinePrimitive( "fold", - "(-> (-> b a _ b) [a] b _ b)", + "(-> (-> b a _ (values b | _)) [a] b _ b)", ([fn, vec, init]) => { return methodCall(vec, "reduce", [ primitiveCall("bindPrimaryValue", fn),