Skip to content

fix: Support platform requirements in generated requirements files #2302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ require (
github.com/vincent-petithory/dataurl v1.0.0
github.com/xeipuuv/gojsonschema v1.2.0
github.com/xeonx/timeago v1.0.0-rc5
golang.org/x/crypto v0.37.0
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
golang.org/x/sync v0.13.0
golang.org/x/sys v0.32.0
Expand Down Expand Up @@ -271,6 +270,7 @@ require (
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.37.0 // indirect
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/net v0.39.0 // indirect
Expand Down
68 changes: 47 additions & 21 deletions pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,29 +273,37 @@ func CUDABaseImageFor(cuda string, cuDNN string) (string, error) {
return images[0].ImageTag(), nil
}

func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err error) {
func tfGPUPackage(ver string, cuda string) (PythonRequirement, error) {
for _, compat := range TFCompatibilityMatrix {
if compat.TF == ver && version.Equal(compat.CUDA, cuda) {
name, cpuVersion, _, _, err = SplitPinnedPythonRequirement(compat.TFGPUPackage)
return name, cpuVersion, err
if req := SplitPinnedPythonRequirement(compat.TFGPUPackage); !req.ParsedFieldsValid {
return PythonRequirement{}, fmt.Errorf("Invalid Python requirement for %s version %s", ver, cuda)
} else {
return req, nil
}
}
}
// We've already warned user if they're doing something stupid in validateAndCompleteCUDA(), so fail silently
return "", "", nil
return PythonRequirement{}, nil
}

func torchCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
func torchCPUPackage(ver, goos, goarch string) (req PythonRequirement, err error) {
req.Name = "torch"
req.Version = ver
req.ParsedFieldsValid = true

// The default is to just install the default version. For older pytorch versions, they don't have any CPU versions.
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchVersion() == ver && compat.CUDA == nil {
return "torch", torchStripCPUSuffixForM1(compat.Torch, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
req.Version = torchStripCPUSuffixForM1(compat.Torch, goos, goarch)
req.FindLinks = []string{compat.FindLinks}
req.ExtraIndexURLs = []string{compat.ExtraIndexURL}
}
}

// Fall back to just installing default version. For older pytorch versions, they don't have any CPU versions.
return "torch", ver, "", "", nil
return
}

func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
func torchGPUPackage(ver string, cuda string) (req PythonRequirement, err error) {
// find the torch package that has the requested torch version and the latest cuda version
// that is at most as high as the requested cuda version
var latest *TorchCompatibility
Expand Down Expand Up @@ -324,25 +332,36 @@ func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extr
}
}
}
if latest == nil {
// We've already warned user if they're doing something stupid in validateAndCompleteCUDA()
return "torch", ver, "", "", nil

req.Name = "torch"
req.ParsedFieldsValid = true
// We've already warned user if they're doing something stupid in validateAndCompleteCUDA()
if latest != nil {
req.Version = version.StripModifier(latest.Torch)
req.FindLinks = []string{latest.FindLinks}
req.ExtraIndexURLs = []string{latest.ExtraIndexURL}
}

return "torch", version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil
return
}

func torchvisionCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
func torchvisionCPUPackage(ver, goos, goarch string) (req PythonRequirement, err error) {
req.Name = "torchvision"
req.ParsedFieldsValid = true

// Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions.
req.Version = ver
for _, compat := range TorchCompatibilityMatrix {
if compat.TorchvisionVersion() == ver && compat.CUDA == nil {
return "torchvision", torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch), compat.FindLinks, compat.ExtraIndexURL, nil
req.Version = torchStripCPUSuffixForM1(compat.Torchvision, goos, goarch)
req.FindLinks = []string{compat.FindLinks}
req.ExtraIndexURLs = []string{compat.ExtraIndexURL}
}
}
// Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions.
return "torchvision", ver, "", "", nil
return
}

func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
func torchvisionGPUPackage(ver, cuda string) (req PythonRequirement, err error) {
// find the torchvision package that has the requested
// torchvision version and the latest cuda version that is at
// most as high as the requested cuda version
Expand Down Expand Up @@ -371,13 +390,20 @@ func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extra
}
}
}

req.Name = "torchvision"
req.ParsedFieldsValid = true
if latest == nil {
// TODO: can we suggest a CUDA version known to be compatible?
console.Warnf("Cog doesn't know if CUDA %s is compatible with torchvision %s. This might cause CUDA problems.", cuda, ver)
return "torchvision", ver, "", "", nil
req.Version = ver
} else {
req.Version = version.StripModifier(latest.Torchvision)
req.FindLinks = []string{latest.FindLinks}
req.ExtraIndexURLs = []string{latest.ExtraIndexURL}
}

return "torchvision", version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil
return
}

// aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html
Expand Down
185 changes: 82 additions & 103 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,12 @@ func (c *Config) cudaFromTF() (tfVersion string, tfCUDA string, tfCuDNN string,
return "", "", "", nil
}

func (c *Config) pythonPackageVersion(name string) (version string, ok bool) {
func (c *Config) pythonPackageVersion(name string) (string, bool) {
for _, pkg := range c.Build.pythonRequirementsContent {
pkgName, version, _, _, err := SplitPinnedPythonRequirement(pkg)
if err != nil {
// package is not in package==version format
continue
}
if pkgName == name {
return version, true
if req := SplitPinnedPythonRequirement(pkg); !req.ParsedFieldsValid {
return "", false
} else if req.Name == name {
return req.Version, true
}
}
return "", false
Expand Down Expand Up @@ -327,132 +324,114 @@ func (c *Config) ValidateAndComplete(projectDir string) error {
}

// PythonRequirementsForArch returns a requirements.txt file with all the GPU packages resolved for given OS and architecture.
// The packages listed in c.Build.pythonRequirementsContent are user-supplied requirements. Packages listed in the
// `includePackages` parameter are defaults. The two sets are union'd together, with the user's own requirements
// taking precedence (version, find-links, etc) if there is a duplicate.
//
// The method will return the string content of the requirements file, or an error.
func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePackages []string) (string, error) {
packages := []string{}
findLinksSet := map[string]bool{}
extraIndexURLSet := map[string]bool{}

includePackageNames := []string{}
for _, pkg := range includePackages {
packageName, err := PackageName(pkg)
// First, parse all the incoming requirements into PythonRequirements
userRequirements := ParseRequirements(c.Build.pythonRequirementsContent, 0)

// Do the same for the packages we've been asked to include by default, but set their ordering keys using a
// sequence number later than the user requirements. This will ensure that our default requirements come at the
// end of the list, and the order is maintained.
includeRequirements := ParseRequirements(includePackages, len(userRequirements))

// For the user requirements, update them for the given OS and architecture
var err error
for i, req := range userRequirements {
// We're only interested in requirements that we were actually able to parse
if !req.ParsedFieldsValid {
continue
}
userRequirements[i], err = c.pythonPackageForArch(req, goos, goarch)
if err != nil {
return "", err
}
includePackageNames = append(includePackageNames, packageName)
}

// Include all the requirements and remove our include packages if they exist
for _, pkg := range c.Build.pythonRequirementsContent {
archPkg, findLinksList, extraIndexURLs, err := c.pythonPackageForArch(pkg, goos, goarch)
if err != nil {
return "", err
}
packages = append(packages, archPkg)
if len(findLinksList) > 0 {
for _, fl := range findLinksList {
findLinksSet[fl] = true
}
}
if len(extraIndexURLs) > 0 {
for _, u := range extraIndexURLs {
extraIndexURLSet[u] = true
}
}
// We're about to perform deduplication between the user requirements and the provided defaults. There may
// be user requirements that we weren't able to parse though - we will keep a note of those so that we can
// add them back in later.
unparsed := make([]PythonRequirement, 0)

packageName, _ := PackageName(archPkg)
if packageName != "" {
foundIdx := -1
for i, includePkg := range includePackageNames {
if includePkg == packageName {
foundIdx = i
break
}
}
if foundIdx != -1 {
includePackageNames = append(includePackageNames[:foundIdx], includePackageNames[foundIdx+1:]...)
includePackages = append(includePackages[:foundIdx], includePackages[foundIdx+1:]...)
}
}
// Next, build a map of requirements keyed on the requirement name. We'll init this with the requirements
// from `includePackages`, and update it with the user's requirements (which may therefore overwrite the defaults).
finalRequirementsMap := make(map[string]PythonRequirement)
for _, req := range includeRequirements {
finalRequirementsMap[req.Name] = req
}

// If we still have some include packages add them in
packages = append(packages, includePackages...)

// Create final requirements.txt output
// Put index URLs first
lines := []string{}
for findLinks := range findLinksSet {
lines = append(lines, "--find-links "+findLinks)
for _, req := range userRequirements {
if req.ParsedFieldsValid {
finalRequirementsMap[req.Name] = req
} else {
unparsed = append(unparsed, req)
}
}
for extraIndexURL := range extraIndexURLSet {
lines = append(lines, "--extra-index-url "+extraIndexURL)

// Now we can build a real PythonRequirements from the values of the finalRequirementsMap
finalRequirements := make(PythonRequirements, 0, len(finalRequirementsMap)+len(unparsed))
for _, req := range finalRequirementsMap {
finalRequirements = append(finalRequirements, req)
}

// Then, everything else
lines = append(lines, packages...)
// Add the unparsed requirements back in
finalRequirements = append(finalRequirements, unparsed...)

return strings.Join(lines, "\n"), nil
return finalRequirements.RequirementsFileContent(), nil
}

// pythonPackageForArch takes a package==version line and
// returns a package==version and index URL resolved to the correct GPU package for the given OS and architecture
func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage string, findLinksList []string, extraIndexURLs []string, err error) {
name, version, findLinksList, extraIndexURLs, err := SplitPinnedPythonRequirement(pkg)
if err != nil {
// It's not pinned, so just return the line verbatim
return pkg, []string{}, []string{}, nil
}
if len(extraIndexURLs) > 0 {
return name + "==" + version, findLinksList, extraIndexURLs, nil
}

extraIndexURL := ""
findLinks := ""
switch name {
// pythonPackageForArch takes a PythonRequirement and
// returns a package==version and index URL resolved to the correct GPU package for the given OS and architecture. If
// the package is not one of the ones whose version we manage, we return the original requirement.
func (c *Config) pythonPackageForArch(req PythonRequirement, goos, goarch string) (out PythonRequirement, err error) {
switch req.Name {
case "tensorflow":
if c.Build.GPU {
name, version, err = tfGPUPackage(version, c.Build.CUDA)
if err != nil {
return "", nil, nil, err
}
out, err = tfGPUPackage(req.Version, c.Build.CUDA)
}
// There is no CPU case for tensorflow because the default package is just the CPU package, so no transformation of version is needed
case "torch":
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchGPUPackage(version, c.Build.CUDA)
if err != nil {
return "", nil, nil, err
}
out, err = torchGPUPackage(req.Version, c.Build.CUDA)
} else {
name, version, findLinks, extraIndexURL, err = torchCPUPackage(version, goos, goarch)
if err != nil {
return "", nil, nil, err
}
out, err = torchCPUPackage(req.Version, goos, goarch)
}
case "torchvision":
if c.Build.GPU {
name, version, findLinks, extraIndexURL, err = torchvisionGPUPackage(version, c.Build.CUDA)
if err != nil {
return "", nil, nil, err
}
out, err = torchvisionGPUPackage(req.Version, c.Build.CUDA)
} else {
name, version, findLinks, extraIndexURL, err = torchvisionCPUPackage(version, goos, goarch)
if err != nil {
return "", nil, nil, err
}
out, err = torchvisionCPUPackage(req.Version, goos, goarch)
}
default:
out = req
}

if err != nil {
return PythonRequirement{}, err
}

// Regardless of whether we're using the original or generated requirement, we bring across some user-supplied
// attributes if provided.
out.order = req.order

// We treat version slightly differently, because we may have rewritten the field to include the cpu specifier.
// Therefore, we will only overwrite the output version if the output version is currently empty.
if req.Version != "" && out.Version == "" {
out.Version = req.Version
}
pkgWithVersion := name
if version != "" {
pkgWithVersion += "==" + version
if req.EnvironmentMarkers != "" {
out.EnvironmentMarkers = req.EnvironmentMarkers
}
if extraIndexURL != "" {
extraIndexURLs = []string{extraIndexURL}
if len(req.FindLinks) > 0 {
out.FindLinks = req.FindLinks
}
if findLinks != "" {
findLinksList = []string{findLinks}
if len(req.ExtraIndexURLs) > 0 {
out.ExtraIndexURLs = req.ExtraIndexURLs
}
return pkgWithVersion, findLinksList, extraIndexURLs, nil
return
}

func ValidateCudaVersion(cudaVersion string) error {
Expand Down
20 changes: 20 additions & 0 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,26 @@ func TestBlankBuild(t *testing.T) {
require.Equal(t, false, config.Build.GPU)
}

// TestPythonRequirementsForArchWithPlatform checks that generated requirements don't lose any metadata that
// was supplied in the original requirements.txt, such as platform restrictions. We do expect hashes to be dropped.
func TestPythonRequirementsForArchWithPlatform(t *testing.T) {
tmpDir := t.TempDir()
err := os.WriteFile(path.Join(tmpDir, "requirements.txt"), []byte(`pywin32==310 ; sys_platform == 'win32' \
--hash=sha256:126298077a9d7c95c53823934f000599f66ec9296b09167810eb24875f32689c`), 0o644)
require.NoError(t, err)
config := &Config{
Build: &Build{
PythonVersion: "3.8",
PythonRequirements: "requirements.txt",
},
}
require.NoError(t, config.ValidateAndComplete(tmpDir))
requirements, err := config.PythonRequirementsForArch("", "", []string{})
require.NoError(t, err)
expected := "pywin32==310 ; sys_platform == 'win32'"
require.Equal(t, expected, requirements)
}

func TestPythonRequirementsForArchWithAddedPackage(t *testing.T) {
config := &Config{
Build: &Build{
Expand Down
Loading