Skip to content

Commit

Permalink
支持镜像站,支持下载main以外的branch,支持只下载某个文件夹
Browse files Browse the repository at this point in the history
  • Loading branch information
xieincz authored Mar 9, 2024
1 parent 6ca9820 commit 1717992
Showing 1 changed file with 24 additions and 61 deletions.
85 changes: 24 additions & 61 deletions src/huggingface-go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
"net/url"

"flag"
"os"
Expand All @@ -18,34 +17,40 @@ import (
"github.com/cheggaaa/pb/v3"
)

var huggingfaceHead string

func main() {
var url, targetParentFolder, proxyURLHead, homepage string
flag.StringVar(&url, "u", "", "huggingface url,such as: https://huggingface.co/datasets/Mizukiluke/ureader-instruction-1.0")
flag.StringVar(&url, "u", "", "huggingface url,such as: https://hf-mirror.com/Finnish-NLP/t5-large-nl36-finnish/tree/main")
flag.StringVar(&targetParentFolder, "f", "./", "path to your target folder")
flag.StringVar(&proxyURLHead, "p", "https://worker-share-proxy-01f5.xieincz.tk/", "proxy url")
flag.StringVar(&proxyURLHead, "p", "", "proxy url, leave it empty if you don't need it")
flag.StringVar(&homepage, "homepage", "https://github.com/xieincz/huggingface-go", "Homepage URL")
flag.Parse()

if url == "" {
flag.Usage()
return
}
if !strings.HasPrefix(url, "https://huggingface.co/") {
fmt.Printf("invalid url: %s\n", url)
return
}
if !strings.HasSuffix(proxyURLHead, "/") {
proxyURLHead += "/"
}

// 提取文件名和链接
// 使用 strings.TrimSuffix 函数去掉 "/tree/main"
modelURL := strings.TrimSuffix(url, "/tree/main/")
modelURL = strings.TrimSuffix(modelURL, "/tree/main")
modelURL = strings.TrimSuffix(modelURL, "/")
modelURL := strings.Split(url, "/tree/")[0]
branch := strings.Split(strings.Split(url, "/tree/")[1], "/")[0] //需要输入url必须含branch,否则后面会出问题
modelName := path.Base(modelURL)
tmp := strings.Split(url, branch+"/") //需要输入url末尾不含/,否则后面会出问题
var urlFolder string
if len(tmp) < 2 {
urlFolder = ""
} else {
urlFolder = tmp[1]
}

//提取出域名
tmp = strings.Split(url, "/")
huggingfaceHead = tmp[0] + "//" + tmp[2] //e.g. https://huggingface.co

fmt.Printf("model/datasets name: %s\n", modelName)
fmt.Printf("model/datasets url: %s\n", modelURL)
fmt.Printf("branch: %s\n", branch)

// 创建目标文件夹
targetFolder := path.Join(targetParentFolder, modelName)
Expand All @@ -59,7 +64,7 @@ func main() {
}
// 递归获取文件列表
fmt.Println("fetching file list... \nthis may take a while")
entries, err := fetchDirectoryEntriesRecursively(proxyURLHead, modelURL+"/tree/main", "")
entries, err := fetchDirectoryEntriesRecursively(proxyURLHead, modelURL+"/tree/"+branch, urlFolder)
if err != nil {
fmt.Printf("cannot fetch entries: %v\n", err)
return
Expand Down Expand Up @@ -91,9 +96,9 @@ func main() {
}
}
// 拼接文件下载链接
fileURL := modelURL + "/resolve/main/" + entry["path"].(string)
fileURL := modelURL + "/resolve/" + branch + "/" + entry["path"].(string)
//拼接文件下载代理链接
proxyFileURL := proxyURLHead + urlEncode(fileURL)
proxyFileURL := proxyURLHead + fileURL
// 下载文件并保存到目标文件夹
if err := downloadFileWithProgressBar(proxyFileURL, filePath, int(entry["size"].(float64))); err != nil {
fmt.Printf("cannot download file %s: %v\n", filePath, err)
Expand Down Expand Up @@ -128,7 +133,7 @@ func fetchDirectoryEntriesRecursively(proxyURLHead, baseURL, path string) ([]map
if path != "" {
url += "/" + path
}
proxyURL := proxyURLHead + urlEncode(url)
proxyURL := proxyURLHead + url
response, err := http.Get(proxyURL)
if err != nil {
return nil, err
Expand Down Expand Up @@ -171,48 +176,6 @@ func fetchDirectoryEntriesRecursively(proxyURLHead, baseURL, path string) ([]map
return res, nil
}

// urlEncode normalizes s with encode and then query-escapes the whole result.
// Callers append the returned value to the proxy URL prefix.
func urlEncode(s string) string {
	normalized := encode(s)
	return url.QueryEscape(normalized)
}

// replaceDic maps percent-escape sequences back to the literal characters
// that encode restores after escaping a URL path. It covers ASCII URL
// punctuation (reserved characters, unreserved marks, "#" and "%"); any other
// escaped byte, such as the UTF-8 bytes of non-ASCII text, stays encoded.
// Entries are listed in ascending order of their escape code.
var replaceDic = map[string]string{
	"%21": "!",
	"%23": "#",
	"%24": "$",
	"%25": "%",
	"%26": "&",
	"%27": "'",
	"%28": "(",
	"%29": ")",
	"%2A": "*",
	"%2B": "+",
	"%2C": ",",
	"%2D": "-",
	"%2E": ".",
	"%2F": "/",
	"%3A": ":",
	"%3B": ";",
	"%3D": "=",
	"%3F": "?",
	"%40": "@",
	"%5F": "_",
	"%7E": "~",
}

// encode percent-encodes the path of the URL s while keeping URL punctuation
// readable: the path is fully escaped, then every sequence listed in
// replaceDic is turned back into its literal character, so only the remaining
// bytes (e.g. non-ASCII text and spaces) stay percent-encoded.
//
// If s cannot be parsed, the error is printed and an empty string is returned.
func encode(s string) string {
	u, err := url.Parse(s)
	if err != nil {
		fmt.Println("Error parsing URL: ", err)
		return ""
	}

	// Apply all substitutions in a single, non-overlapping pass. Ranging over
	// the map and calling ReplaceAll per entry made the result depend on map
	// iteration order: once "%25" had been turned into "%", that "%" could
	// combine with the following characters and match a later entry.
	pairs := make([]string, 0, 2*len(replaceDic))
	for escaped, literal := range replaceDic {
		pairs = append(pairs, escaped, literal)
	}
	u.Path = strings.NewReplacer(pairs...).Replace(url.PathEscape(u.Path))

	// u.String() normally re-escapes the path, which turns the "%" of each
	// existing escape sequence into "%25"; undo that second round of escaping.
	return strings.ReplaceAll(u.String(), "%25", "%")
}

func extractEntries(dataProps, proxyURLHead string) ([]map[string]interface{}, error) {
var props map[string]interface{}
err := json.Unmarshal([]byte(dataProps), &props)
Expand All @@ -223,7 +186,7 @@ func extractEntries(dataProps, proxyURLHead string) ([]map[string]interface{}, e
nextURL := props["nextURL"]
fmt.Println("nextURL:", nextURL)
if nextURL != nil {
proxyURL := proxyURLHead + urlEncode("https://huggingface.co"+nextURL.(string))
proxyURL := proxyURLHead + huggingfaceHead + nextURL.(string)
response, err := http.Get(proxyURL)
if err != nil {
fmt.Println("Error:", err)
Expand Down

0 comments on commit 1717992

Please sign in to comment.