Skip to content

chore: add location option to loading scripts #629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,20 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
}
opt := complete(opts...)

var locationPath, locationName string
if opt.Location != "" {
locationPath = path.Dir(opt.Location)
locationName = path.Base(opt.Location)
}

prg := types.Program{
ToolSet: types.ToolSet{},
}
tools, err := readTool(ctx, opt.Cache, &prg, &source{
Content: []byte(content),
Location: "inline",
Path: locationPath,
Name: locationName,
Location: opt.Location,
}, subToolName)
if err != nil {
return types.Program{}, err
Expand All @@ -388,12 +396,18 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
}

type Options struct {
Cache *cache.Client
Cache *cache.Client
Location string
}

func complete(opts ...Options) (result Options) {
for _, opt := range opts {
result.Cache = types.FirstSet(opt.Cache, result.Cache)
result.Location = types.FirstSet(opt.Location, result.Location)
}

if result.Location == "" {
result.Location = "inline"
}

return
Expand Down
38 changes: 25 additions & 13 deletions pkg/loader/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,20 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string
req.Header.Set("Authorization", "Bearer "+bearerToken)
}

data, err := getWithDefaults(req)
data, defaulted, err := getWithDefaults(req)
if err != nil {
return nil, false, fmt.Errorf("error loading %s: %v", url, err)
}

if defaulted != "" {
pathString = url
name = defaulted
if repo != nil {
repo.Path = path.Join(repo.Path, repo.Name)
repo.Name = defaulted
}
}

log.Debugf("opened %s", url)

result := &source{
Expand All @@ -137,31 +146,32 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string
return result, true, nil
}

func getWithDefaults(req *http.Request) ([]byte, error) {
func getWithDefaults(req *http.Request) ([]byte, string, error) {
originalPath := req.URL.Path

// First, try to get the original path as is. It might be an OpenAPI definition.
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
return nil, "", err
}
defer resp.Body.Close()

if resp.StatusCode == http.StatusOK {
if toolBytes, err := io.ReadAll(resp.Body); err == nil && isOpenAPI(toolBytes) != 0 {
return toolBytes, nil
}
toolBytes, err := io.ReadAll(resp.Body)
return toolBytes, "", err
}

base := path.Base(originalPath)
if strings.Contains(base, ".") {
return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
}

for i, def := range types.DefaultFiles {
base := path.Base(originalPath)
if !strings.Contains(base, ".") {
req.URL.Path = path.Join(originalPath, def)
}
req.URL.Path = path.Join(originalPath, def)

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
return nil, "", err
}
defer resp.Body.Close()

Expand All @@ -170,11 +180,13 @@ func getWithDefaults(req *http.Request) ([]byte, error) {
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
}

return io.ReadAll(resp.Body)
data, err := io.ReadAll(resp.Body)
return data, def, err
}

panic("unreachable")
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/sdkserver/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) {
logger.Debugf("executing tool: %+v", reqObject)
var (
def fmt.Stringer = &reqObject.ToolDefs
programLoader loaderFunc = loader.ProgramFromSource
programLoader = loaderWithLocation(loader.ProgramFromSource, reqObject.Location)
)
if reqObject.Content != "" {
def = &reqObject.content
Expand Down
8 changes: 8 additions & 0 deletions pkg/sdkserver/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ import (

type loaderFunc func(context.Context, string, string, ...loader.Options) (types.Program, error)

func loaderWithLocation(f loaderFunc, loc string) loaderFunc {
return func(ctx context.Context, s string, s2 string, options ...loader.Options) (types.Program, error) {
return f(ctx, s, s2, append(options, loader.Options{
Location: loc,
})...)
}
}

func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, logger mvl.Logger, w http.ResponseWriter, opts gptscript.Options, chatState, input, subTool string, toolDef fmt.Stringer) {
g, err := gptscript.New(ctx, s.gptscriptOpts, opts)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions pkg/sdkserver/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type toolOrFileRequest struct {
CredentialContext string `json:"credentialContext"`
CredentialOverrides []string `json:"credentialOverrides"`
Confirm bool `json:"confirm"`
Location string `json:"location,omitempty"`
}

type content struct {
Expand Down