Skip to content

Commit

Permalink
feat: new approach for fast validation - fast cold boot (#1553)
Browse files Browse the repository at this point in the history
- fast cold boot;
- new approach for fast validation (elimination of redundant heavy imports of predict/train py files into cog at runtime);
- generates stripped down (slim) versions of predict/train py files using dynamic Python code introspection;
- Pydantic model creation is based on Python function signatures in the slim versions of the predict/train py files, not on the openapi schema as in the previous attempt (feat: eliminate redundant loads of predict.py and train.py in early setup of cog predict #1503);
- now 100% compatible with existing and future cog types and existing code;
- removed schema py generation and bundling during cog build;
- supports existing models (no need to rebuild) as openapi schema is not used anymore;
- will fall back to the slow loader in cases when predict/train function signatures refer to slow imports, like dynamically calculated defaults.
  • Loading branch information
dkhokhlov authored Mar 30, 2024
1 parent 1999f43 commit fae2ac3
Show file tree
Hide file tree
Showing 20 changed files with 502 additions and 56 deletions.
6 changes: 0 additions & 6 deletions pkg/docker/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ func BuildAddLabelsAndSchemaToImage(image string, labels map[string]string, bund

dockerfile := "FROM " + image + "\n"
dockerfile += "COPY " + bundledSchemaFile + " .cog\n"
env_path := "/tmp/venv/tools/"
dockerfile += "RUN python -m venv --symlinks " + env_path + " && " +
env_path + "/bin/python -m pip install 'datamodel-code-generator>=0.25' && " +
env_path + "/bin/datamodel-codegen --version && " +
env_path + "/bin/datamodel-codegen --input-file-type openapi --input " + bundledSchemaFile +
" --output " + bundledSchemaPy + " && rm -rf " + env_path
cmd.Stdin = strings.NewReader(dockerfile)

console.Debug("$ " + strings.Join(cmd.Args, " "))
Expand Down
258 changes: 258 additions & 0 deletions python/cog/code_xforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
import ast
import re
import types
from typing import Optional, Set, Union

# Module names used by strip_model_source_code to select which import statements
# of the original model file are copied into the stripped-down ("slim") source.
COG_IMPORT_MODULES = {"cog", "typing", "sys", "os", "functools", "pydantic", "numpy"}


def load_module_from_string(
    name: str, source: Union[str, None]
) -> Union[types.ModuleType, None]:
    """
    Builds a fresh module object and executes the given source inside it.
    Args:
        name: The name given to the new module.
        source: Python source code to execute in the module's namespace.
    Returns:
        The populated module, or None when either `name` or `source` is empty/None.
        The module is only returned; this function does not register it anywhere.
    """
    if name and source:
        new_module = types.ModuleType(name)
        # The source is executed as-is; any exception it raises propagates to the caller.
        exec(source, vars(new_module))  # noqa: S102
        return new_module
    return None


def extract_class_source(source_code: str, class_name: str) -> str:
    """
    Extracts the source code of the class(es) named in `class_name` from a given source text.

    `class_name` may be a plain class name or a type expression such as
    ``List[Output]`` or ``Tuple[A, B]``: every capitalized identifier found in it
    is treated as a candidate class name.
    Args:
        source_code: The complete source code as a string.
        class_name: A class name, or a type expression mentioning one or more class names.
    Returns:
        The source of every matching class, in the order they appear in `source_code`,
        separated by a blank line; an empty string if no class matches.
    """
    # Only capitalized identifiers are considered (e.g. "Output" in "List[Output]").
    wanted_names = set(re.findall(r"\b[A-Z]\w*\b", class_name))

    class ClassExtractor(ast.NodeVisitor):
        def __init__(self) -> None:
            self.class_sources = []

        def visit_ClassDef(self, node: ast.ClassDef) -> None:
            # No generic_visit here: classes nested inside another class are
            # never extracted on their own.
            if node.name in wanted_names:
                segment = ast.get_source_segment(source_code, node)
                if segment:
                    # Collect every match instead of keeping only the last one, so
                    # expressions naming several classes yield all of them.
                    self.class_sources.append(segment)

    tree = ast.parse(source_code)
    extractor = ClassExtractor()
    extractor.visit(tree)
    return "\n\n".join(extractor.class_sources)


def extract_function_source(source_code: str, function_name: str) -> str:
    """
    Pulls the source text of a named (non-async) function out of a larger source text.

    Functions nested inside other functions are not searched. If several
    definitions share the name, the one appearing last in the walk is returned.
    Args:
        source_code: The complete source code as a string.
        function_name: The name of the function to extract.
    Returns:
        The source of the function if found; otherwise, an empty string.
    """
    matches = []

    class FunctionFinder(ast.NodeVisitor):
        def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
            # Not calling generic_visit keeps the search out of function bodies.
            if node.name == function_name:
                matches.append(ast.get_source_segment(source_code, node))

    FunctionFinder().visit(ast.parse(source_code))
    last_match = matches[-1] if matches else None
    return last_match or ""


def make_class_methods_empty(source_code: Union[str, ast.AST], class_name: str) -> str:
    """
    Rewrites a class so that each of its directly defined (non-async) methods
    has the single-statement body 'return None'.

    Signatures, decorators and non-method class members are left untouched.
    Class definitions with a different name that the transformer visits are
    removed from the output.
    Args:
        source_code: The source code as a string, or an already parsed AST
            (an AST passed in is modified in place).
        class_name: The name of the class to transform.
    Returns:
        The transformed source code, regenerated from the AST.
    """

    class MethodStubber(ast.NodeTransformer):
        def visit_ClassDef(self, node: ast.ClassDef) -> Optional[ast.AST]:
            if node.name != class_name:
                # Returning None from a NodeTransformer removes the node.
                return None
            for member in node.body:
                if isinstance(member, ast.FunctionDef):
                    member.body = [ast.Return(value=ast.Constant(value=None))]
            return node

    if isinstance(source_code, ast.AST):
        tree = source_code
    else:
        tree = ast.parse(source_code)
    return ast.unparse(MethodStubber().visit(tree))


def extract_method_return_type(
    source_code: Union[str, ast.AST], class_name: str, method_name: str
) -> Optional[str]:
    """
    Extracts the return type annotation of a specified method within a given class from the source code.

    Only (non-async) function definitions located inside a class named
    `class_name` are considered; a module-level function that happens to share
    the method's name is ignored.
    Args:
        source_code: A string containing the Python source code, or an already parsed AST.
        class_name: The name of the class containing the method of interest.
        method_name: The name of the method whose return type annotation is to be extracted.
    Returns:
        A string representing the method's return type annotation if found; otherwise, None.
    """

    class MethodReturnTypeExtractor(ast.NodeVisitor):
        def __init__(self) -> None:
            self.return_type = None
            # True only while walking the body of a class named `class_name`.
            self.inside_target_class = False

        def visit_ClassDef(self, node: ast.ClassDef) -> None:
            # Classes with another name are not descended into.
            if node.name == class_name:
                previous = self.inside_target_class
                self.inside_target_class = True
                self.generic_visit(node)
                self.inside_target_class = previous

        def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
            # Without the inside_target_class guard, free functions named
            # `method_name` would be mistaken for the method.
            if self.inside_target_class and node.name == method_name and node.returns:
                self.return_type = ast.unparse(node.returns)

    tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
    extractor = MethodReturnTypeExtractor()
    extractor.visit(tree)

    return extractor.return_type


def extract_function_return_type(
    source_code: Union[str, ast.AST], function_name: str
) -> Optional[str]:
    """
    Looks up the return type annotation of a named (non-async) function.

    Function bodies are not searched, so nested functions are ignored. If several
    annotated definitions share the name, the last one visited wins.
    Args:
        source_code: A string containing the Python source code, or an already parsed AST.
        function_name: The name of the function whose return type annotation is wanted.
    Returns:
        The annotation rendered back to source text, or None when the function
        is missing or carries no return annotation.
    """
    annotations = []

    class ReturnAnnotationFinder(ast.NodeVisitor):
        def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
            if node.name == function_name and node.returns:
                annotations.append(ast.unparse(node.returns))

    if isinstance(source_code, ast.AST):
        tree = source_code
    else:
        tree = ast.parse(source_code)
    ReturnAnnotationFinder().visit(tree)

    return annotations[-1] if annotations else None


def make_function_empty(source_code: Union[str, ast.AST], function_name: str) -> str:
    """
    Rewrites the named (non-async) function so that its body is just 'return None'.

    The signature, annotations and decorators are kept. Function definitions
    with a different name that the transformer visits are removed from the output.
    Args:
        source_code: The source code as a string, or an already parsed AST
            (an AST passed in is modified in place).
        function_name: The name of the function to transform.
    Returns:
        The transformed source code, regenerated from the AST.
    """

    class FunctionStubber(ast.NodeTransformer):
        def visit_FunctionDef(self, node: ast.FunctionDef) -> Optional[ast.AST]:
            if node.name != function_name:
                # Returning None from a NodeTransformer removes the node.
                return None
            node.body = [ast.Return(value=ast.Constant(value=None))]
            return node

    if isinstance(source_code, ast.AST):
        tree = source_code
    else:
        tree = ast.parse(source_code)
    return ast.unparse(FunctionStubber().visit(tree))


def extract_specific_imports(
    source_code: Union[str, ast.AST], module_names: Set[str]
) -> str:
    """
    Extracts import statements from the source code that match a specified set of module names.

    For ``import a, b`` statements only the aliases whose module is in
    `module_names` are kept, so an allowed module listed next to a disallowed
    one does not drag the latter along. ``from m import ...`` statements are
    kept whole when `m` is in `module_names`.
    Args:
        source_code: The Python source code as a string, or an already parsed AST.
        module_names: A set of module names for which to extract import statements.
    Returns:
        The matching import statements as a single string, one statement per line
        (an empty string when nothing matches).
    """

    class ImportExtractor(ast.NodeVisitor):
        def __init__(self) -> None:
            self.imports = []

        def visit_Import(self, node: ast.Import) -> None:
            kept_aliases = [alias for alias in node.names if alias.name in module_names]
            if kept_aliases:
                # Re-emit the statement once, restricted to the allowed modules.
                self.imports.append(ast.unparse(ast.Import(names=kept_aliases)))

        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
            # node.module is None for `from . import x`, which never matches.
            if node.module in module_names:
                self.imports.append(ast.unparse(node))

    tree = source_code if isinstance(source_code, ast.AST) else ast.parse(source_code)
    extractor = ImportExtractor()
    extractor.visit(tree)

    return "\n".join(extractor.imports)


def strip_model_source_code(
    source_code: str, class_name: str, method_name: str
) -> Optional[str]:
    """
    Produces a stripped-down version of a model source file.

    The result consists of the import statements selected via COG_IMPORT_MODULES,
    the source of the class referenced by the return annotation (if any), and
    the predictor itself with its body/bodies replaced by 'return None'.
    Args:
        source_code: The complete model source code as a string.
        class_name: The name of the predictor class. If it is empty or no such class
            is found, the same name is looked up as a standalone function instead.
        method_name: The method whose return annotation is inspected when a class is
            processed. It is not used on the standalone-function path, where
            `class_name` serves as the function name.
    Returns:
        The stripped source code, or None if neither a class nor a function with
        the given name could be extracted.
    """
    kept_imports = extract_specific_imports(source_code, COG_IMPORT_MODULES)

    stripped_body = (
        extract_class_source(source_code, class_name) if class_name else None
    )
    if stripped_body:
        stripped_body = make_class_methods_empty(stripped_body, class_name)
        return_annotation = extract_method_return_type(
            stripped_body, class_name, method_name
        )
    else:
        # The name configured as the predictor is treated as a function name here.
        function_name = class_name
        stripped_body = extract_function_source(source_code, function_name)
        if not stripped_body:
            return None
        stripped_body = make_function_empty(stripped_body, function_name)
        if not stripped_body:
            return None
        return_annotation = extract_function_return_type(stripped_body, function_name)

    # The annotation may name a class defined in the model file; pull it in so the
    # annotation can be resolved when the stripped source is executed.
    if return_annotation:
        return_class_source = extract_class_source(source_code, return_annotation)
    else:
        return_class_source = ""

    return "\n\n".join([kept_imports, return_class_source, stripped_body]) + "\n"
58 changes: 54 additions & 4 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os.path
import sys
import types
import uuid
from abc import ABC, abstractmethod
from collections.abc import Iterator
from pathlib import Path
Expand All @@ -20,6 +21,10 @@
)
from unittest.mock import patch

import structlog

import cog.code_xforms as code_xforms

try:
from typing import get_args, get_origin
except ImportError: # Python < 3.8
Expand All @@ -44,6 +49,8 @@
Path as CogPath,
)

log = structlog.get_logger("cog.server.predictor")

ALLOWED_INPUT_TYPES: List[Type[Any]] = [str, int, float, bool, CogFile, CogPath]


Expand Down Expand Up @@ -173,26 +180,69 @@ def get_predictor_ref(config: Dict[str, Any], mode: str = "predict") -> str:
return config[mode]


def load_predictor_from_ref(ref: str) -> BasePredictor:
module_path, class_name = ref.split(":", 1)
module_name = os.path.basename(module_path).split(".py", 1)[0]
def load_full_predictor_from_file(
    module_path: str, module_name: str
) -> types.ModuleType:
    """
    Imports the file at `module_path` as a module named `module_name` by
    executing it in full.
    Args:
        module_path: Path to the Python file to import.
        module_name: Name given to the resulting module.
    Returns:
        The executed module.
    """
    module_spec = importlib.util.spec_from_file_location(module_name, module_path)
    assert module_spec is not None
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_loader = module_spec.loader
    assert module_loader is not None

    # Hide command-line arguments (everything after argv[0]) during the import so
    # user code calling argparse.Parser.parse_args at import time in production
    # does not pick them up.
    argv_without_args = sys.argv[:1]
    with patch("sys.argv", argv_without_args):
        module_loader.exec_module(loaded_module)
    return loaded_module


def load_slim_predictor_from_file(
    module_path: str, class_name: str, method_name: str
) -> Union[types.ModuleType, None]:
    """
    Loads a stripped-down version of the file at `module_path` as a module.

    The file is read as text, reduced with code_xforms.strip_model_source_code,
    and the reduced source is turned into a module under a random hex name.
    Args:
        module_path: Path to the Python file containing the predictor.
        class_name: Name of the predictor passed on to the source stripper.
        method_name: Name of the method passed on to the source stripper.
    Returns:
        The module built from the stripped source, or None if no module was built.
    """
    with open(module_path, encoding="utf-8") as source_file:
        full_source = source_file.read()
    slim_source = code_xforms.strip_model_source_code(
        full_source, class_name, method_name
    )
    return code_xforms.load_module_from_string(uuid.uuid4().hex, slim_source)


def get_predictor(module: types.ModuleType, class_name: str) -> Any:
    """
    Fetches the predictor named `class_name` from `module`.
    Args:
        module: The module holding the predictor.
        class_name: Attribute name of the predictor; it may refer to a class or a function.
    Returns:
        A new instance (created with no arguments) when the attribute is a class;
        otherwise the attribute itself.
    """
    target = getattr(module, class_name)
    return target() if inspect.isclass(target) else target


def load_slim_predictor_from_ref(ref: str, method_name: str) -> BasePredictor:
    """
    Loads the predictor named by `ref`, trying the fast (slim) loader first and
    falling back to the full import when the fast loader yields no module.
    Args:
        ref: Predictor reference of the form "<module path>:<name>".
        method_name: Method name passed on to the slim loader.
    Returns:
        The predictor obtained from the loaded module via get_predictor.
    """
    module_path, class_name = ref.split(":", 1)
    module_name = os.path.basename(module_path).split(".py", 1)[0]
    module = None
    try:
        if sys.version_info < (3, 9):
            log.debug(f"[{module_name}] cannot use fast loader as current Python <3.9")
        else:
            module = load_slim_predictor_from_file(module_path, class_name, method_name)
            if not module:
                log.debug(f"[{module_name}] fast loader returned None")
    except Exception as e:
        log.debug(f"[{module_name}] fast loader failed: {e}")
    finally:
        # The fallback lives in `finally`, so it runs whenever the fast path
        # produced no module, whatever the reason.
        if not module:
            log.debug(f"[{module_name}] falling back to slow loader")
            module = load_full_predictor_from_file(module_path, module_name)
    return get_predictor(module, class_name)


def load_predictor_from_ref(ref: str) -> BasePredictor:
    """
    Loads the predictor named by `ref` using the full import of its module.
    Args:
        ref: Predictor reference of the form "<module path>:<name>".
    Returns:
        The predictor obtained from the loaded module via get_predictor.
    """
    module_path, class_name = ref.split(":", 1)
    # Module name is the file's base name up to the first ".py".
    file_name = os.path.basename(module_path)
    module_name = file_name.split(".py", 1)[0]
    full_module = load_full_predictor_from_file(module_path, module_name)
    return get_predictor(full_module, class_name)


# Base class for inputs, constructed dynamically in get_input_type().
# (This can't be a docstring or it gets passed through to the schema.)
class BaseInput(BaseModel):
Expand Down
Loading

0 comments on commit fae2ac3

Please sign in to comment.