Skip to content

Commit

Permalink
Add fast generator for cog build (#2108)
Browse files Browse the repository at this point in the history
* Add fast generator

* Add copy weights step

* Add package installation step

* Add user cache to the monobase

* Resolve monobase caching into the local
userspace to allow for monobase caching between
builds

* Add rsync copy

* Check if weight files have changed before checksum

* Split install into separate steps

* We can get the tarballs for each of these layers
instead of creating 1 big tarball.

* Create requirements.txt in the build tmp directory

* Fix unit tests

* Fix lint

* Add basic unit tests

* Use UV_CACHE_DIR and mount uv cache

* Remove --skip-cuda from monobase build

* Monobase now handles empty CUDA env vars

* Fix file not found when evaluating weights

* Add UV_LINK_MODE=copy to the uv install commands

* Add UV_COMPILE_BYTECODE env var

* Remove verbosity from monobase exec

* Fix integration test

* Switch tini and exec

---------

Signed-off-by: Will Sackfield <[email protected]>
  • Loading branch information
8W9aG authored Jan 15, 2025
1 parent 6eb2d2e commit 9efb306
Show file tree
Hide file tree
Showing 26 changed files with 913 additions and 83 deletions.
2 changes: 1 addition & 1 deletion pkg/cli/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func cmdDockerfile(cmd *cobra.Command, args []string) error {
return err
}

generator, err := dockerfile.NewGenerator(cfg, projectDir)
generator, err := dockerfile.NewGenerator(cfg, projectDir, false)
if err != nil {
return fmt.Errorf("Error creating Dockerfile generator: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func CUDABaseImageFor(cuda string, cuDNN string) (string, error) {
func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err error) {
for _, compat := range TFCompatibilityMatrix {
if compat.TF == ver && version.Equal(compat.CUDA, cuda) {
name, cpuVersion, _, _, err = splitPinnedPythonRequirement(compat.TFGPUPackage)
name, cpuVersion, _, _, err = SplitPinnedPythonRequirement(compat.TFGPUPackage)
return name, cpuVersion, err
}
}
Expand Down
64 changes: 8 additions & 56 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func (c *Config) cudaFromTF() (tfVersion string, tfCUDA string, tfCuDNN string,

func (c *Config) pythonPackageVersion(name string) (version string, ok bool) {
for _, pkg := range c.Build.pythonRequirementsContent {
pkgName, version, _, _, err := splitPinnedPythonRequirement(pkg)
pkgName, version, _, _, err := SplitPinnedPythonRequirement(pkg)
if err != nil {
// package is not in package==version format
continue
Expand Down Expand Up @@ -331,7 +331,11 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa

includePackageNames := []string{}
for _, pkg := range includePackages {
includePackageNames = append(includePackageNames, packageName(pkg))
packageName, err := PackageName(pkg)
if err != nil {
return "", err
}
includePackageNames = append(includePackageNames, packageName)
}

// Include all the requirements and remove our include packages if they exist
Expand All @@ -352,7 +356,7 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
}
}

packageName := packageName(archPkg)
packageName, _ := PackageName(archPkg)
if packageName != "" {
foundIdx := -1
for i, includePkg := range includePackageNames {
Expand Down Expand Up @@ -390,7 +394,7 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
// pythonPackageForArch takes a package==version line and
// returns a package==version and index URL resolved to the correct GPU package for the given OS and architecture
func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage string, findLinksList []string, extraIndexURLs []string, err error) {
name, version, findLinksList, extraIndexURLs, err := splitPinnedPythonRequirement(pkg)
name, version, findLinksList, extraIndexURLs, err := SplitPinnedPythonRequirement(pkg)
if err != nil {
// It's not pinned, so just return the line verbatim
return pkg, []string{}, []string{}, nil
Expand Down Expand Up @@ -562,50 +566,6 @@ Compatible cuDNN version is: %s`, c.Build.CuDNN, tfVersion, tfCuDNN)
return nil
}

// splitPinnedPythonRequirement parses one requirements.txt line of the form
//
//	name==version [--find-links=<findLink>] [-f <findLink>] [--extra-index-url=<extraIndexURL>]
//
// It returns the package name, the pinned version, every find-links location
// (long and short flag alike, in order of appearance) and every extra index
// URL. If several name==version pins appear on the line, the last one wins.
// An error is returned when nothing on the line matches the pattern or when
// the line carries flags but no name==version pin.
func splitPinnedPythonRequirement(requirement string) (name string, version string, findLinks []string, extraIndexURLs []string, err error) {
	// Capture groups: 1 = name, 2 = version, 3 = --find-links value,
	// 4 = -f value, 5 = --extra-index-url value.
	re := regexp.MustCompile(`(?:([a-zA-Z0-9\-_]+)==([^ ]+)|--find-links=([^\s]+)|-f\s+([^\s]+)|--extra-index-url=([^\s]+))`)

	allMatches := re.FindAllStringSubmatch(requirement, -1)
	if allMatches == nil {
		return "", "", nil, nil, fmt.Errorf("Package %s is not in the expected format", requirement)
	}

	var haveName, haveVersion bool
	for _, groups := range allMatches {
		if pkg := groups[1]; pkg != "" {
			name, haveName = pkg, true
		}
		if pin := groups[2]; pin != "" {
			version, haveVersion = pin, true
		}
		// Both spellings of the find-links flag feed the same list.
		for _, link := range groups[3:5] {
			if link != "" {
				findLinks = append(findLinks, link)
			}
		}
		if index := groups[5]; index != "" {
			extraIndexURLs = append(extraIndexURLs, index)
		}
	}

	if haveName && haveVersion {
		return name, version, findLinks, extraIndexURLs, nil
	}
	return "", "", nil, nil, fmt.Errorf("Package name or version is missing in %s", requirement)
}

func sliceContains(slice []string, s string) bool {
for _, el := range slice {
if el == s {
Expand All @@ -614,11 +574,3 @@ func sliceContains(slice []string, s string) bool {
}
return false
}

// packageName extracts the package name from a pip requirement line using
// the first capture group of PipPackageNameRegex. It returns the empty
// string when the pattern does not match or yields no capture group.
func packageName(pipRequirement string) string {
	if groups := PipPackageNameRegex.FindStringSubmatch(pipRequirement); len(groups) > 1 {
		return groups[1]
	}
	return ""
}
2 changes: 1 addition & 1 deletion pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ func TestSplitPinnedPythonRequirement(t *testing.T) {
}

for _, tc := range testCases {
name, version, findLinks, extraIndexURLs, err := splitPinnedPythonRequirement(tc.input)
name, version, findLinks, extraIndexURLs, err := SplitPinnedPythonRequirement(tc.input)

if tc.expectedError {
require.Error(t, err)
Expand Down
130 changes: 130 additions & 0 deletions pkg/config/requirements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package config

import (
"bufio"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
)

// blankOrCommentLineRe matches requirements.txt lines that carry no
// requirement: empty lines, whitespace-only lines, and lines whose first
// non-blank character is '#'.
var blankOrCommentLineRe = regexp.MustCompile(`^\s*(?:#.*)?$`)

// GenerateRequirements merges the python_packages directive and the
// python_requirements file of config into a single requirements.txt inside
// tmpDir and returns the path of that file.
//
// Requirements are deduplicated by package name (a line from the requirements
// file replaces a python_packages entry with the same name) and written one
// per line, sorted by package name. When the file already exists with exactly
// the content that would be written it is left untouched, so its modification
// time is preserved between builds.
//
// It returns "" and a nil error when there are no requirements at all. An
// error is returned when a requirement cannot be parsed by PackageName, or
// when the requirements file or the output file cannot be read or written.
// tmpDir must already exist.
func GenerateRequirements(tmpDir string, config *Config) (string, error) {
	// Requirement line keyed by package name; later entries replace earlier ones.
	requirements := make(map[string]string)

	// Read the python packages configuration.
	for _, requirement := range config.Build.PythonPackages {
		name, err := PackageName(requirement)
		if err != nil {
			return "", err
		}
		requirements[name] = requirement
	}

	// Read the python requirements file on top of it.
	if config.Build.PythonRequirements != "" {
		if err := mergeRequirementsFile(config.Build.PythonRequirements, requirements); err != nil {
			return "", err
		}
	}

	// If we don't have any packages skip further processing.
	if len(requirements) == 0 {
		return "", nil
	}

	// Map iteration order is random; sort the names so the output is deterministic.
	names := make([]string, 0, len(requirements))
	for name := range requirements {
		names = append(names, name)
	}
	sort.Strings(names)

	// Render the expected contents.
	var content []byte
	for _, name := range names {
		content = append(content, requirements[name]...)
		content = append(content, '\n')
	}

	// Compare against the previous contents (not the path) and skip the write
	// when nothing changed. A missing file simply means there is nothing to
	// compare with; any other read error is reported.
	requirementsFile := filepath.Join(tmpDir, "requirements.txt")
	existing, err := os.ReadFile(requirementsFile)
	switch {
	case err == nil:
		if string(existing) == string(content) {
			return requirementsFile, nil
		}
	case !errors.Is(err, os.ErrNotExist):
		return "", err
	}

	// Write out a new requirements file.
	if err := os.WriteFile(requirementsFile, content, 0o644); err != nil {
		return "", err
	}
	return requirementsFile, nil
}

// mergeRequirementsFile reads the requirements file line by line and stores
// each requirement in requirements under its package name, skipping blank
// and comment lines. The file is closed before returning.
func mergeRequirementsFile(file string, requirements map[string]string) error {
	fh, err := os.Open(file)
	if err != nil {
		return err
	}
	defer fh.Close()

	scanner := bufio.NewScanner(fh)
	for scanner.Scan() {
		requirement := scanner.Text()
		if blankOrCommentLineRe.MatchString(requirement) {
			continue
		}
		name, err := PackageName(requirement)
		if err != nil {
			return err
		}
		requirements[name] = requirement
	}
	// Scan stops silently on read errors and over-long lines; surface them.
	return scanner.Err()
}

// pinnedRequirementRe recognises the pieces of a pinned requirements.txt line.
// Capture groups: 1 = name, 2 = version, 3 = --find-links value,
// 4 = -f value, 5 = --extra-index-url value.
// It is compiled once at package scope because the parser below is called for
// every requirement line; a compiled *regexp.Regexp is safe for concurrent use.
var pinnedRequirementRe = regexp.MustCompile(`(?:([a-zA-Z0-9\-_]+)==([^ ]+)|--find-links=([^\s]+)|-f\s+([^\s]+)|--extra-index-url=([^\s]+))`)

// SplitPinnedPythonRequirement returns the name, version, findLinks, and extraIndexURLs from a requirements.txt line
// in the form name==version [--find-links=<findLink>] [-f <findLink>] [--extra-index-url=<extraIndexURL>]
//
// findLinks collects the values of both --find-links= and -f in order of
// appearance. If the line contains more than one name==version pin, the last
// one wins. An error is returned (with all other results zeroed) when nothing
// on the line matches the expected format, or when the line has flags but no
// name==version pin.
func SplitPinnedPythonRequirement(requirement string) (name string, version string, findLinks []string, extraIndexURLs []string, err error) {
	matches := pinnedRequirementRe.FindAllStringSubmatch(requirement, -1)
	if matches == nil {
		return "", "", nil, nil, fmt.Errorf("Package %s is not in the expected format", requirement)
	}

	nameFound := false
	versionFound := false

	// Each match is one alternative of the pattern, so only the groups of
	// that alternative are non-empty.
	for _, match := range matches {
		if match[1] != "" {
			name = match[1]
			nameFound = true
		}

		if match[2] != "" {
			version = match[2]
			versionFound = true
		}

		if match[3] != "" {
			findLinks = append(findLinks, match[3])
		}

		if match[4] != "" {
			findLinks = append(findLinks, match[4])
		}

		if match[5] != "" {
			extraIndexURLs = append(extraIndexURLs, match[5])
		}
	}

	if !nameFound || !versionFound {
		return "", "", nil, nil, fmt.Errorf("Package name or version is missing in %s", requirement)
	}

	return name, version, findLinks, extraIndexURLs, nil
}

// PackageName returns the package name of a pinned pip requirement line.
// It delegates to SplitPinnedPythonRequirement and passes that function's
// error through unchanged, so a line that is not in name==version form
// yields an error.
func PackageName(pipRequirement string) (name string, err error) {
	name, _, _, _, err = SplitPinnedPythonRequirement(pipRequirement)
	return name, err
}
21 changes: 21 additions & 0 deletions pkg/config/requirements_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package config

import (
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

func TestGenerateRequirements(t *testing.T) {
tmpDir := t.TempDir()
build := Build{
PythonPackages: []string{"torch==2.5.1"},
}
config := Config{
Build: &build,
}
requirementsFile, err := GenerateRequirements(tmpDir, &config)
require.NoError(t, err)
require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile)
}
12 changes: 9 additions & 3 deletions pkg/docker/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,22 @@ import (
"strings"

"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/dockerfile"

"github.com/replicate/cog/pkg/util"
"github.com/replicate/cog/pkg/util/console"
)

func Build(dir, dockerfile, imageName string, secrets []string, noCache bool, progressOutput string, epoch int64) error {
func Build(dir, dockerfileContents, imageName string, secrets []string, noCache bool, progressOutput string, epoch int64) error {
var args []string

userCache, err := dockerfile.UserCache()
if err != nil {
return err
}

args = append(args,
"buildx", "build",
"buildx", "build", "--build-context", "usercache="+userCache,
)

if util.IsAppleSiliconMac(runtime.GOOS, runtime.GOARCH) {
Expand Down Expand Up @@ -65,7 +71,7 @@ func Build(dir, dockerfile, imageName string, secrets []string, noCache bool, pr
cmd.Dir = dir
cmd.Stdout = os.Stderr // redirect stdout to stderr - build output is all messaging
cmd.Stderr = os.Stderr
cmd.Stdin = strings.NewReader(dockerfile)
cmd.Stdin = strings.NewReader(dockerfileContents)

console.Debug("$ " + strings.Join(cmd.Args, " "))
return cmd.Run()
Expand Down
2 changes: 1 addition & 1 deletion pkg/dockerfile/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (g *BaseImageGenerator) GenerateDockerfile() (string, error) {
return "", err
}

generator, err := NewGenerator(conf, "")
generator, err := NewGenerator(conf, "", false)
if err != nil {
return "", err
}
Expand Down
33 changes: 33 additions & 0 deletions pkg/dockerfile/build_tempdir.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package dockerfile

import (
"os"
"path"
"time"
)

// BuildCogTempDir ensures that the .cog/tmp directory exists beneath dir,
// creating it and any missing parents with mode 0o755, and returns its path.
// On failure it returns an empty path together with the error from
// os.MkdirAll.
func BuildCogTempDir(dir string) (string, error) {
	cogTmp := path.Join(dir, ".cog/tmp")

	err := os.MkdirAll(cogTmp, 0o755)
	if err != nil {
		return "", err
	}

	return cogTmp, nil
}

// BuildTempDir creates a fresh, uniquely named build directory inside
// dir/.cog/tmp and returns its path. On failure it returns an empty path and
// the underlying error.
func BuildTempDir(dir string) (string, error) {
	// BuildCogTempDir already creates dir/.cog/tmp (and any missing parents),
	// so no further MkdirAll is needed here.
	rootTmp, err := BuildCogTempDir(dir)
	if err != nil {
		return "", err
	}

	// The directory name starts with "build" plus a timestamp; os.MkdirTemp
	// appends a random suffix to that prefix, so tmpDir ends up being
	// something like dir/.cog/tmp/build20240620123456.0000001234567890
	now := time.Now().Format("20060102150405.000000")
	tmpDir, err := os.MkdirTemp(rootTmp, "build"+now)
	if err != nil {
		return "", err
	}
	return tmpDir, nil
}
15 changes: 15 additions & 0 deletions pkg/dockerfile/build_tempdir_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dockerfile

import (
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

// TestBuildCogTempDir checks that BuildCogTempDir succeeds on a fresh
// directory and reports the .cog/tmp path beneath it.
func TestBuildCogTempDir(t *testing.T) {
	projectDir := t.TempDir()
	want := filepath.Join(projectDir, ".cog/tmp")

	got, err := BuildCogTempDir(projectDir)

	require.NoError(t, err)
	require.Equal(t, want, got)
}
6 changes: 6 additions & 0 deletions pkg/dockerfile/cog_embed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package dockerfile

import "embed"

// CogEmbed is a read-only file system holding every file that matches the
// embed/*.whl pattern, embedded into the binary at build time.
// NOTE(review): presumably these are the cog Python wheel(s) installed into
// built images — confirm against the code that reads CogEmbed.
//
//go:embed embed/*.whl
var CogEmbed embed.FS
Loading

0 comments on commit 9efb306

Please sign in to comment.