Skip to content

Commit

Permalink
feat: eliminate redundant loads of predict.py and train.py in early s…
Browse files Browse the repository at this point in the history
…etup of cog predict (#1503)

- store openapi schema file json & py files in .cog subdir during cog build;
- conditionally create Input & Output schema types from openapi schema file;
- conditionally create TrainingInput & TrainingOutput schema types from openapi schema file;

Migration: optimization will work for new models, old models are unaffected

Signed-off-by: Dmitri Khokhlov <[email protected]>
  • Loading branch information
dkhokhlov authored Feb 16, 2024
1 parent 9f971d2 commit 064fcee
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ test-integration: cog

.PHONY: test-python
test-python:
$(PYTEST) -n auto -vv python/tests $(if $(FILTER),-k "$(FILTER)",)
$(PYTEST) -n auto -vv --cov=python/cog --cov-report term-missing python/tests $(if $(FILTER),-k "$(FILTER)",)

.PHONY: test
test: test-go test-python test-integration
Expand Down
11 changes: 9 additions & 2 deletions pkg/docker/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func Build(dir, dockerfile, imageName string, secrets []string, noCache bool, pr
return cmd.Run()
}

func BuildAddLabelsToImage(image string, labels map[string]string) error {
func BuildAddLabelsAndSchemaToImage(image string, labels map[string]string, bundledSchemaFile string, bundledSchemaPy string) error {
var args []string

args = append(args,
Expand All @@ -74,7 +74,14 @@ func BuildAddLabelsToImage(image string, labels map[string]string) error {
args = append(args, ".")
cmd := exec.Command("docker", args...)

dockerfile := "FROM " + image
dockerfile := "FROM " + image + "\n"
dockerfile += "COPY " + bundledSchemaFile + " .cog\n"
env_path := "/tmp/venv/tools/"
dockerfile += "RUN python -m venv --symlinks " + env_path + " && " +
env_path + "/bin/python -m pip install 'datamodel-code-generator>=0.25' && " +
env_path + "/bin/datamodel-codegen --version && " +
env_path + "/bin/datamodel-codegen --input-file-type openapi --input " + bundledSchemaFile +
" --output " + bundledSchemaPy + " && rm -rf " + env_path
cmd.Stdin = strings.NewReader(dockerfile)

console.Debug("$ " + strings.Join(cmd.Args, " "))
Expand Down
14 changes: 13 additions & 1 deletion pkg/image/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@ import (

const dockerignoreBackupPath = ".dockerignore.cog.bak"
const weightsManifestPath = ".cog/cache/weights_manifest.json"
const bundledSchemaFile = ".cog/openapi_schema.json"
const bundledSchemaPy = ".cog/schema.py"

// Build a Cog model from a config
//
// This is separated out from docker.Build(), so that can be as close as possible to the behavior of 'docker build'.
func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, separateWeights bool, useCudaBaseImage string, progressOutput string, schemaFile string, dockerfileFile string) error {
console.Infof("Building Docker image from environment in cog.yaml as %s...", imageName)

// remove bundled schema files that may be left from previous builds
_ = os.Remove(bundledSchemaFile)
_ = os.Remove(bundledSchemaPy)

if dockerfileFile != "" {
dockerfileContents, err := os.ReadFile(dockerfileFile)
if err != nil {
Expand Down Expand Up @@ -113,6 +119,12 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache,
schemaJSON = data
}

// save open_api schema file
err := os.WriteFile(bundledSchemaFile, schemaJSON, 0o644)
if err != nil {
return fmt.Errorf("failed to store bundled schema file %s: %w", bundledSchemaFile, err)
}

loader := openapi3.NewLoader()
loader.IsExternalRefsAllowed = true
doc, err := loader.LoadFromData(schemaJSON)
Expand Down Expand Up @@ -163,7 +175,7 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache,
}
}

if err := docker.BuildAddLabelsToImage(imageName, labels); err != nil {
if err := docker.BuildAddLabelsAndSchemaToImage(imageName, labels, bundledSchemaFile, bundledSchemaPy); err != nil {
return fmt.Errorf("Failed to add labels to image: %w", err)
}
return nil
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ optional-dependencies = { "dev" = [
"pytest-httpserver",
"pytest-rerunfailures",
"pytest-xdist",
"pytest-cov",
"responses",
"ruff",
] }
Expand Down
20 changes: 20 additions & 0 deletions python/cog/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import importlib.util
import os
import os.path
import sys
import typing as t
from datetime import datetime
from enum import Enum
from types import ModuleType

import pydantic

BUNDLED_SCHEMA_PATH = ".cog/schema.py"


class Status(str, Enum):
STARTING = "starting"
Expand Down Expand Up @@ -92,3 +99,16 @@ class TrainingRequest(PredictionRequest):

class TrainingResponse(PredictionResponse):
pass


def create_schema_module() -> t.Optional[ModuleType]:
if not os.path.exists(BUNDLED_SCHEMA_PATH):
return None
name = "cog.bundled_schema"
spec = importlib.util.spec_from_file_location(name, BUNDLED_SCHEMA_PATH)
assert spec is not None
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
sys.modules[name] = module
spec.loader.exec_module(module)
return module
28 changes: 20 additions & 8 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,16 @@ async def start_shutdown() -> Any:

try:
predictor_ref = get_predictor_ref(config, mode)
# TODO: avoid loading predictor code in this process
predictor = load_predictor_from_ref(predictor_ref)
InputType = get_input_type(predictor)
OutputType = get_output_type(predictor)
# use bundled schema if it exists
schema_module = schema.create_schema_module()
if schema_module is not None:
log.info("using bundled schema")
InputType = schema_module.Input
OutputType = schema_module.Output
else:
predictor = load_predictor_from_ref(predictor_ref)
InputType = get_input_type(predictor)
OutputType = get_output_type(predictor)
except Exception:
msg = "Error while loading predictor:\n\n" + traceback.format_exc()
add_setup_failed_routes(app, started_at, msg)
Expand Down Expand Up @@ -166,10 +172,16 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T":
if "train" in config:
try:
trainer_ref = get_predictor_ref(config, "train")
# TODO: avoid loading trainer code in this process
trainer = load_predictor_from_ref(trainer_ref)
TrainingInputType = get_training_input_type(trainer)
TrainingOutputType = get_training_output_type(trainer)
# use bundled schema if it exists
schema_module = schema.create_schema_module()
if schema_module is not None:
log.info("using bundled schema")
TrainingInputType = schema_module.TrainingInputType
TrainingOutputType = schema_module.TrainingOutputType
else:
trainer = load_predictor_from_ref(trainer_ref)
TrainingInputType = get_training_input_type(trainer)
TrainingOutputType = get_training_output_type(trainer)

class TrainingRequest(
schema.TrainingRequest.with_types(input_type=TrainingInputType)
Expand Down

0 comments on commit 064fcee

Please sign in to comment.