Skip to content

Commit f2c8dec

Browse files
authored
Merge pull request #1 from openshieldai/run
Run
2 parents 194a03e + 08aae7b commit f2c8dec

File tree

3 files changed

+66
-69
lines changed

3 files changed

+66
-69
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
module github.com/sashabaranov/go-openai
1+
module github.com/openshieldai/go-openai
22

33
go 1.18

run.go

Lines changed: 64 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,6 @@ type Run struct {
2828
Metadata map[string]any `json:"metadata"`
2929
Usage Usage `json:"usage,omitempty"`
3030

31-
Temperature *float32 `json:"temperature,omitempty"`
32-
// The maximum number of prompt tokens that may be used over the course of the run.
33-
// If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'.
34-
MaxPromptTokens int `json:"max_prompt_tokens,omitempty"`
35-
// The maximum number of completion tokens that may be used over the course of the run.
36-
// If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'.
37-
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
38-
// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
39-
TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"`
40-
4131
httpHeader
4232
}
4333

@@ -50,7 +40,6 @@ const (
5040
RunStatusCancelling RunStatus = "cancelling"
5141
RunStatusFailed RunStatus = "failed"
5242
RunStatusCompleted RunStatus = "completed"
53-
RunStatusIncomplete RunStatus = "incomplete"
5443
RunStatusExpired RunStatus = "expired"
5544
RunStatusCancelled RunStatus = "cancelled"
5645
)
@@ -89,53 +78,7 @@ type RunRequest struct {
8978
AdditionalInstructions string `json:"additional_instructions,omitempty"`
9079
Tools []Tool `json:"tools,omitempty"`
9180
Metadata map[string]any `json:"metadata,omitempty"`
92-
93-
// Sampling temperature between 0 and 2. Higher values like 0.8 are more random.
94-
// lower values are more focused and deterministic.
95-
Temperature *float32 `json:"temperature,omitempty"`
96-
TopP *float32 `json:"top_p,omitempty"`
97-
98-
// The maximum number of prompt tokens that may be used over the course of the run.
99-
// If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'.
100-
MaxPromptTokens int `json:"max_prompt_tokens,omitempty"`
101-
102-
// The maximum number of completion tokens that may be used over the course of the run.
103-
// If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'.
104-
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
105-
106-
// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
107-
TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"`
108-
109-
// This can be either a string or a ToolChoice object.
110-
ToolChoice any `json:"tool_choice,omitempty"`
111-
// This can be either a string or a ResponseFormat object.
112-
ResponseFormat any `json:"response_format,omitempty"`
113-
}
114-
115-
// ThreadTruncationStrategy defines the truncation strategy to use for the thread.
116-
// https://platform.openai.com/docs/assistants/how-it-works/truncation-strategy.
117-
type ThreadTruncationStrategy struct {
118-
// default 'auto'.
119-
Type TruncationStrategy `json:"type,omitempty"`
120-
// this field should be set if the truncation strategy is set to LastMessages.
121-
LastMessages *int `json:"last_messages,omitempty"`
122-
}
123-
124-
// TruncationStrategy defines the existing truncation strategies existing for thread management in an assistant.
125-
type TruncationStrategy string
126-
127-
const (
128-
// TruncationStrategyAuto messages in the middle of the thread will be dropped to fit the context length of the model.
129-
TruncationStrategyAuto = TruncationStrategy("auto")
130-
// TruncationStrategyLastMessages the thread will be truncated to the n most recent messages in the thread.
131-
TruncationStrategyLastMessages = TruncationStrategy("last_messages")
132-
)
133-
134-
// ReponseFormat specifies the format the model must output.
135-
// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format.
136-
// Type can either be text or json_object.
137-
type ReponseFormat struct {
138-
Type string `json:"type"`
81+
Stream bool `json:"stream,omitempty"`
13982
}
14083

14184
type RunModifyRequest struct {
@@ -240,7 +183,8 @@ func (c *Client) CreateRun(
240183
http.MethodPost,
241184
c.fullURL(urlSuffix),
242185
withBody(request),
243-
withBetaAssistantVersion(c.config.AssistantVersion))
186+
withBetaAssistantVersion(c.config.AssistantVersion),
187+
)
244188
if err != nil {
245189
return
246190
}
@@ -249,6 +193,51 @@ func (c *Client) CreateRun(
249193
return
250194
}
251195

196+
type RunStreamResponseDelta struct {
197+
Role string `json:"role"`
198+
Content []MessageContent `json:"content"`
199+
FileIDs []string `json:"file_ids"`
200+
}
201+
202+
type RunStreamResponse struct {
203+
ID string `json:"id"`
204+
Object string `json:"object"`
205+
Delta RunStreamResponseDelta `json:"delta"`
206+
}
207+
208+
type RunStream struct {
209+
*streamReader[RunStreamResponse]
210+
}
211+
212+
// CreateRunStream creates a new run with streaming support.
213+
func (c *Client) CreateRunStream(
214+
ctx context.Context,
215+
threadID string,
216+
request RunRequest,
217+
) (stream *RunStream, err error) {
218+
urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID)
219+
request.Stream = true
220+
req, err := c.newRequest(
221+
ctx,
222+
http.MethodPost,
223+
c.fullURL(urlSuffix),
224+
withBody(request),
225+
withBetaAssistantVersion(c.config.AssistantVersion),
226+
)
227+
if err != nil {
228+
return
229+
}
230+
231+
resp, err := sendRequestStream[RunStreamResponse](c, req)
232+
if err != nil {
233+
return
234+
}
235+
stream = &RunStream{
236+
streamReader: resp,
237+
}
238+
return
239+
}
240+
252241
// RetrieveRun retrieves a run.
253242
func (c *Client) RetrieveRun(
254243
ctx context.Context,
@@ -260,7 +249,8 @@ func (c *Client) RetrieveRun(
260249
ctx,
261250
http.MethodGet,
262251
c.fullURL(urlSuffix),
263-
withBetaAssistantVersion(c.config.AssistantVersion))
252+
withBetaAssistantVersion(c.config.AssistantVersion),
253+
)
264254
if err != nil {
265255
return
266256
}
@@ -282,7 +272,8 @@ func (c *Client) ModifyRun(
282272
http.MethodPost,
283273
c.fullURL(urlSuffix),
284274
withBody(request),
285-
withBetaAssistantVersion(c.config.AssistantVersion))
275+
withBetaAssistantVersion(c.config.AssistantVersion),
276+
)
286277
if err != nil {
287278
return
288279
}
@@ -321,7 +312,8 @@ func (c *Client) ListRuns(
321312
ctx,
322313
http.MethodGet,
323314
c.fullURL(urlSuffix),
324-
withBetaAssistantVersion(c.config.AssistantVersion))
315+
withBetaAssistantVersion(c.config.AssistantVersion),
316+
)
325317
if err != nil {
326318
return
327319
}
@@ -342,7 +334,8 @@ func (c *Client) SubmitToolOutputs(
342334
http.MethodPost,
343335
c.fullURL(urlSuffix),
344336
withBody(request),
345-
withBetaAssistantVersion(c.config.AssistantVersion))
337+
withBetaAssistantVersion(c.config.AssistantVersion),
338+
)
346339
if err != nil {
347340
return
348341
}
@@ -361,7 +354,8 @@ func (c *Client) CancelRun(
361354
ctx,
362355
http.MethodPost,
363356
c.fullURL(urlSuffix),
364-
withBetaAssistantVersion(c.config.AssistantVersion))
357+
withBetaAssistantVersion(c.config.AssistantVersion),
358+
)
365359
if err != nil {
366360
return
367361
}
@@ -380,7 +374,8 @@ func (c *Client) CreateThreadAndRun(
380374
http.MethodPost,
381375
c.fullURL(urlSuffix),
382376
withBody(request),
383-
withBetaAssistantVersion(c.config.AssistantVersion))
377+
withBetaAssistantVersion(c.config.AssistantVersion),
378+
)
384379
if err != nil {
385380
return
386381
}
@@ -401,7 +396,8 @@ func (c *Client) RetrieveRunStep(
401396
ctx,
402397
http.MethodGet,
403398
c.fullURL(urlSuffix),
404-
withBetaAssistantVersion(c.config.AssistantVersion))
399+
withBetaAssistantVersion(c.config.AssistantVersion),
400+
)
405401
if err != nil {
406402
return
407403
}
@@ -441,7 +437,8 @@ func (c *Client) ListRunSteps(
441437
ctx,
442438
http.MethodGet,
443439
c.fullURL(urlSuffix),
444-
withBetaAssistantVersion(c.config.AssistantVersion))
440+
withBetaAssistantVersion(c.config.AssistantVersion),
441+
)
445442
if err != nil {
446443
return
447444
}

stream_reader.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ var (
1616
)
1717

1818
type streamable interface {
19-
ChatCompletionStreamResponse | CompletionResponse
19+
ChatCompletionStreamResponse | CompletionResponse | RunStreamResponse
2020
}
2121

2222
type streamReader[T streamable] struct {

0 commit comments

Comments
 (0)