Skip to content

Commit cd41c33

Browse files
Implement type checker compatible Input() function (#196)
Convert Input from dataclass to function with overloads following Pydantic's pattern. Maintains exact same syntax for developers while providing full type checker compatibility.
1 parent 5a780b2 commit cd41c33

File tree

9 files changed

+1639
-8
lines changed

9 files changed

+1639
-8
lines changed
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
package tests
2+
3+
import (
4+
"encoding/json"
5+
"io"
6+
"net/http"
7+
"strings"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/replicate/cog-runtime/internal/server"
14+
)
15+
16+
func TestInputFunctionSchemaGeneration(t *testing.T) {
17+
t.Parallel()
18+
if *legacyCog {
19+
t.Skip("Input generation tests coglet specific implementations.")
20+
}
21+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
22+
procedureMode: false,
23+
explicitShutdown: false,
24+
uploadURL: "",
25+
module: "input_function",
26+
predictorClass: "Predictor",
27+
})
28+
29+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
30+
31+
resp, err := http.Get(runtimeServer.URL + "/openapi.json")
32+
require.NoError(t, err)
33+
defer resp.Body.Close()
34+
35+
body, err := io.ReadAll(resp.Body)
36+
require.NoError(t, err)
37+
38+
var schema map[string]any
39+
err = json.Unmarshal(body, &schema)
40+
require.NoError(t, err)
41+
42+
assert.Contains(t, schema, "components")
43+
44+
components := schema["components"].(map[string]any)
45+
assert.Contains(t, components, "schemas")
46+
47+
schemas := components["schemas"].(map[string]any)
48+
assert.Contains(t, schemas, "Input")
49+
50+
inputSchema := schemas["Input"].(map[string]any)
51+
assert.Equal(t, "object", inputSchema["type"])
52+
assert.Contains(t, inputSchema, "properties")
53+
assert.Contains(t, inputSchema, "required")
54+
55+
properties := inputSchema["properties"].(map[string]any)
56+
required := inputSchema["required"].([]any)
57+
58+
assert.Contains(t, properties, "message")
59+
assert.Contains(t, required, "message")
60+
messageField := properties["message"].(map[string]any)
61+
assert.Equal(t, "string", messageField["type"])
62+
assert.Equal(t, "Message to process", messageField["description"])
63+
64+
assert.Contains(t, properties, "repeat_count")
65+
assert.NotContains(t, required, "repeat_count")
66+
repeatField := properties["repeat_count"].(map[string]any)
67+
assert.Equal(t, "integer", repeatField["type"])
68+
assert.Equal(t, float64(1), repeatField["default"]) //nolint:testifylint // Checking absolute value not delta
69+
assert.Equal(t, float64(1), repeatField["minimum"]) //nolint:testifylint // Checking absolute value not delta
70+
assert.Equal(t, float64(10), repeatField["maximum"]) //nolint:testifylint // Checking absolute value not delta
71+
72+
assert.Contains(t, properties, "prefix")
73+
prefixField := properties["prefix"].(map[string]any)
74+
assert.Equal(t, "string", prefixField["type"])
75+
assert.Equal(t, "Result: ", prefixField["default"])
76+
assert.Equal(t, float64(1), prefixField["minLength"]) //nolint:testifylint // Checking absolute value not delta
77+
assert.Equal(t, float64(20), prefixField["maxLength"]) //nolint:testifylint // Checking absolute value not delta
78+
79+
assert.Contains(t, properties, "deprecated_option")
80+
deprecatedField := properties["deprecated_option"].(map[string]any)
81+
assert.Equal(t, true, deprecatedField["deprecated"])
82+
}
83+
84+
func TestInputFunctionBasicPrediction(t *testing.T) {
85+
t.Parallel()
86+
if *legacyCog {
87+
t.Skip("Input generation tests coglet specific implementations.")
88+
}
89+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
90+
procedureMode: false,
91+
explicitShutdown: false,
92+
uploadURL: "",
93+
module: "input_function",
94+
predictorClass: "Predictor",
95+
})
96+
97+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
98+
99+
input := map[string]any{"message": "hello world"}
100+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
101+
102+
resp, err := http.DefaultClient.Do(req)
103+
require.NoError(t, err)
104+
defer resp.Body.Close()
105+
assert.Equal(t, http.StatusOK, resp.StatusCode)
106+
107+
body, err := io.ReadAll(resp.Body)
108+
require.NoError(t, err)
109+
110+
var prediction server.PredictionResponse
111+
err = json.Unmarshal(body, &prediction)
112+
require.NoError(t, err)
113+
114+
assert.Equal(t, server.PredictionSucceeded, prediction.Status)
115+
assert.Equal(t, "Result: hello world", prediction.Output)
116+
}
117+
118+
func TestInputFunctionComplexPrediction(t *testing.T) {
119+
t.Parallel()
120+
if *legacyCog {
121+
t.Skip("Input generation tests coglet specific implementations.")
122+
}
123+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
124+
procedureMode: false,
125+
explicitShutdown: false,
126+
uploadURL: "",
127+
module: "input_function",
128+
predictorClass: "Predictor",
129+
})
130+
131+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
132+
133+
input := map[string]any{
134+
"message": "test message",
135+
"repeat_count": 2,
136+
"format_type": "uppercase",
137+
"prefix": "Output: ",
138+
"suffix": " [END]",
139+
"deprecated_option": "custom",
140+
}
141+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
142+
143+
resp, err := http.DefaultClient.Do(req)
144+
require.NoError(t, err)
145+
defer resp.Body.Close()
146+
assert.Equal(t, http.StatusOK, resp.StatusCode)
147+
148+
body, err := io.ReadAll(resp.Body)
149+
require.NoError(t, err)
150+
151+
var prediction server.PredictionResponse
152+
err = json.Unmarshal(body, &prediction)
153+
require.NoError(t, err)
154+
155+
assert.Equal(t, server.PredictionSucceeded, prediction.Status)
156+
assert.Equal(t, "Output: TEST MESSAGE TEST MESSAGE [END]", prediction.Output)
157+
}
158+
159+
func TestInputFunctionConstraintViolations(t *testing.T) {
160+
t.Parallel()
161+
if *legacyCog {
162+
t.Skip("Input generation tests coglet specific implementations.")
163+
}
164+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
165+
procedureMode: false,
166+
explicitShutdown: false,
167+
uploadURL: "",
168+
module: "input_function",
169+
predictorClass: "Predictor",
170+
})
171+
172+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
173+
174+
testCases := []struct {
175+
name string
176+
input map[string]any
177+
errorMsg string
178+
}{
179+
{
180+
name: "repeat_count too low",
181+
input: map[string]any{"message": "test", "repeat_count": 0},
182+
errorMsg: "fails constraint >= 1",
183+
},
184+
{
185+
name: "repeat_count too high",
186+
input: map[string]any{"message": "test", "repeat_count": 11},
187+
errorMsg: "fails constraint <= 10",
188+
},
189+
{
190+
name: "invalid format_type choice",
191+
input: map[string]any{"message": "test", "format_type": "invalid"},
192+
errorMsg: "does not match choices",
193+
},
194+
{
195+
name: "prefix too short",
196+
input: map[string]any{"message": "test", "prefix": ""},
197+
errorMsg: "fails constraint len() >= 1",
198+
},
199+
{
200+
name: "prefix too long",
201+
input: map[string]any{"message": "test", "prefix": strings.Repeat("x", 21)},
202+
errorMsg: "fails constraint len() <= 20",
203+
},
204+
}
205+
206+
for _, tc := range testCases {
207+
t.Run(tc.name, func(t *testing.T) {
208+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: tc.input})
209+
210+
resp, err := http.DefaultClient.Do(req)
211+
require.NoError(t, err)
212+
defer resp.Body.Close()
213+
214+
body, err := io.ReadAll(resp.Body)
215+
require.NoError(t, err)
216+
217+
var errorResp server.PredictionResponse
218+
err = json.Unmarshal(body, &errorResp)
219+
require.NoError(t, err)
220+
221+
assert.Equal(t, server.PredictionFailed, errorResp.Status)
222+
assert.Contains(t, errorResp.Error, tc.errorMsg)
223+
})
224+
}
225+
}
226+
227+
func TestInputFunctionMissingRequired(t *testing.T) {
228+
t.Parallel()
229+
if *legacyCog {
230+
t.Skip("Input generation tests coglet specific implementations.")
231+
}
232+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
233+
procedureMode: false,
234+
explicitShutdown: false,
235+
uploadURL: "",
236+
module: "input_function",
237+
predictorClass: "Predictor",
238+
})
239+
240+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
241+
242+
input := map[string]any{"repeat_count": 2}
243+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
244+
245+
resp, err := http.DefaultClient.Do(req)
246+
require.NoError(t, err)
247+
defer resp.Body.Close()
248+
249+
body, err := io.ReadAll(resp.Body)
250+
require.NoError(t, err)
251+
252+
var errorResp server.PredictionResponse
253+
err = json.Unmarshal(body, &errorResp)
254+
require.NoError(t, err)
255+
256+
assert.Equal(t, server.PredictionFailed, errorResp.Status)
257+
assert.Contains(t, errorResp.Error, "missing required input field: message")
258+
}
259+
260+
func TestInputFunctionSimple(t *testing.T) {
261+
t.Parallel()
262+
if *legacyCog {
263+
t.Skip("Input generation tests coglet specific implementations.")
264+
}
265+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
266+
procedureMode: false,
267+
explicitShutdown: false,
268+
uploadURL: "",
269+
module: "input_simple",
270+
predictorClass: "Predictor",
271+
})
272+
273+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
274+
275+
input := map[string]any{"message": "hello", "count": 3}
276+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
277+
278+
resp, err := http.DefaultClient.Do(req)
279+
require.NoError(t, err)
280+
defer resp.Body.Close()
281+
assert.Equal(t, http.StatusOK, resp.StatusCode)
282+
283+
body, err := io.ReadAll(resp.Body)
284+
require.NoError(t, err)
285+
286+
var prediction server.PredictionResponse
287+
err = json.Unmarshal(body, &prediction)
288+
require.NoError(t, err)
289+
290+
assert.Equal(t, server.PredictionSucceeded, prediction.Status)
291+
assert.Equal(t, "hellohellohello", prediction.Output)
292+
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ classifiers = [
1111
'Programming Language :: Python :: 3.12',
1212
'Programming Language :: Python :: 3.13',
1313
]
14-
dependencies = []
14+
dependencies = ["typing_extensions>=4.15"]
1515

1616
[project.optional-dependencies]
1717
dev = [

0 commit comments

Comments
 (0)