-
Notifications
You must be signed in to change notification settings - Fork 18
/
downloader.go
157 lines (139 loc) · 4.26 KB
/
downloader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
//go:build !NODOWNLOAD
package hugot
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"path"
"path/filepath"
"strings"
"time"
hfd "github.com/bodaay/HuggingFaceModelDownloader/hfdownloader"
)
// DownloadOptions is a struct of options that can be passed to DownloadModel
type DownloadOptions struct {
AuthToken string
SkipSha bool
Branch string
MaxRetries int
RetryInterval int
ConcurrentConnections int
Verbose bool
}
// NewDownloadOptions creates new DownloadOptions struct with default values.
// Override the values to specify different download options.
func NewDownloadOptions() DownloadOptions {
d := DownloadOptions{}
d.Branch = "main"
d.MaxRetries = 5
d.RetryInterval = 5
d.ConcurrentConnections = 5
return d
}
// DownloadModel can be used to download a model directly from huggingface. Before the model is downloaded,
// validation occurs to ensure there is an .onnx and tokenizers.json file. Hugot only works with onnx models.
func DownloadModel(modelName string, destination string, options DownloadOptions) (string, error) {
// make sure it's an onnx model with tokenizer
err := validateDownloadHfModel(modelName, options.Branch, options.AuthToken)
if err != nil {
return "", err
}
// replicates code in hf downloader
modelP := modelName
if strings.Contains(modelP, ":") {
modelP = strings.Split(modelName, ":")[0]
}
modelPath := path.Join(destination, strings.Replace(modelP, "/", "_", -1))
for i := 0; i < options.MaxRetries; i++ {
if err := hfd.DownloadModel(modelName, false, options.SkipSha, false, destination, options.Branch, options.ConcurrentConnections, options.AuthToken, !options.Verbose); err != nil {
if options.Verbose {
fmt.Printf("Warning: attempt %d / %d failed, error: %s\n", i+1, options.MaxRetries, err)
}
time.Sleep(time.Duration(options.RetryInterval) * time.Second)
continue
}
if options.Verbose {
fmt.Printf("\nDownload of %s completed successfully\n", modelName)
}
return modelPath, nil
}
return "", fmt.Errorf("failed to download %s after %d attempts", modelName, options.MaxRetries)
}
type hfFile struct {
Type string `json:"type"`
Path string `json:"path"`
IsDirectory bool
}
func validateDownloadHfModel(modelPath string, branch string, authToken string) error {
if strings.Contains(modelPath, ":") {
return errors.New("model filters are not supported")
}
client := &http.Client{}
hasTokenizer, hasOnxx, err := checkURL(client, fmt.Sprintf("https://huggingface.co/api/models/%s/tree/%s", modelPath, branch), authToken)
if err != nil {
return err
}
var errs []error
if !hasOnxx {
errs = append(errs, fmt.Errorf("model does not have a model.onnx file, Hugot only works with onnx models"))
}
if !hasTokenizer {
errs = append(errs, fmt.Errorf("model does not have a tokenizer.json file"))
}
return errors.Join(errs...)
}
func checkURL(client *http.Client, url string, authToken string) (bool, bool, error) {
var tokenizerFound bool
var onnxFound bool
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
if err != nil {
return false, false, err
}
if authToken != "" {
req.Header.Add("Authorization", "Bearer "+authToken)
}
resp, err := client.Do(req)
if err != nil {
return false, false, err
}
defer func(resp *http.Response) {
err = errors.Join(err, resp.Body.Close())
}(resp)
var filesList []hfFile
e := json.NewDecoder(resp.Body).Decode(&filesList)
if e != nil {
return false, false, e
}
var dirs []hfFile
for _, f := range filesList {
if filepath.Base(f.Path) == "tokenizer.json" {
tokenizerFound = true
}
if filepath.Ext(f.Path) == ".onnx" {
onnxFound = true
}
if f.Type == "directory" {
// Do dirs later if files not found at this level
dirs = append(dirs, f)
}
if onnxFound && tokenizerFound {
break
}
}
if !(onnxFound && tokenizerFound) {
for _, dir := range dirs {
tokenizerFoundRec, onnxFoundRec, dirErr := checkURL(client, url+"/"+dir.Path, authToken)
if dirErr != nil {
return false, false, dirErr
}
tokenizerFound = tokenizerFound || tokenizerFoundRec
onnxFound = onnxFound || onnxFoundRec
if onnxFound && tokenizerFound {
break
}
}
}
return tokenizerFound, onnxFound, nil
}