From 1289b6fe8c2ec01b7d47c6ee970838e5d4be4da7 Mon Sep 17 00:00:00 2001 From: Dmitri Khokhlov Date: Mon, 25 Mar 2024 13:34:42 -0700 Subject: [PATCH] feat: support for list in inputs (#1561) - added support for list[File] and list[Path]; - only in cog, still need to add list support in web UI; - updated docs/python.md; --- docs/python.md | 26 ++++++ pkg/cli/predict.go | 5 +- pkg/predict/input.go | 79 +++++++++++++------ pkg/predict/predictor.go | 2 +- python/cog/server/runner.py | 22 ++++-- .../fixtures/file-input-project/predict.py | 11 ++- .../fixtures/file-input-project/test.txt | 1 + .../fixtures/file-list-input-project/cog.yaml | 3 + .../file-list-input-project/predict.py | 14 ++++ .../cog.yaml | 0 .../fixtures/path-input-project/predict.py | 7 ++ .../fixtures/path-list-input-project/cog.yaml | 3 + .../path-list-input-project/predict.py | 9 +++ .../path-list-output-project/cog.yaml | 3 + .../predict.py | 0 .../cog.yaml | 0 .../predict.py | 0 .../{file-project => path-project}/cog.yaml | 0 .../{file-project => path-project}/predict.py | 0 .../test_integration/test_build.py | 2 +- .../test_integration/test_predict.py | 29 ++++++- 21 files changed, 174 insertions(+), 42 deletions(-) create mode 100644 test-integration/test_integration/fixtures/file-input-project/test.txt create mode 100644 test-integration/test_integration/fixtures/file-list-input-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/file-list-input-project/predict.py rename test-integration/test_integration/fixtures/{file-list-output-project => path-input-project}/cog.yaml (100%) create mode 100644 test-integration/test_integration/fixtures/path-input-project/predict.py create mode 100644 test-integration/test_integration/fixtures/path-list-input-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/path-list-input-project/predict.py create mode 100644 test-integration/test_integration/fixtures/path-list-output-project/cog.yaml rename test-integration/test_integration/fixtures/{file-list-output-project => path-list-output-project}/predict.py (100%) rename test-integration/test_integration/fixtures/{file-output-project => path-output-project}/cog.yaml (100%) rename test-integration/test_integration/fixtures/{file-output-project => path-output-project}/predict.py (100%) rename test-integration/test_integration/fixtures/{file-project => path-project}/cog.yaml (100%) rename test-integration/test_integration/fixtures/{file-project => path-project}/predict.py (100%) diff --git a/docs/python.md b/docs/python.md index 00772a180e..a83d1140ab 100644 --- a/docs/python.md +++ b/docs/python.md @@ -269,3 +269,29 @@ class Predictor(BasePredictor): upscaled_image.save(output_path) return Path(output_path) ``` + +## `List` + +The List type is also supported in inputs. It can hold any supported type. + +Example for **List[Path]**: +```py +class Predictor(BasePredictor): + def predict(self, paths: list[Path]) -> str: + output_parts = [] # Use a list to collect file contents + for path in paths: + with open(path) as f: + output_parts.append(f.read()) + return "".join(output_parts) +``` +The corresponding cog command: +```bash +$ echo test1 > 1.txt +$ echo test2 > 2.txt +$ cog predict -i paths=@1.txt -i paths=@2.txt +Running prediction... +test1 + +test2 +``` +- Note the repeated inputs with the same name "paths" which constitute the list diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 63afa35349..80a88edea1 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -288,7 +288,7 @@ func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openap } func parseInputFlags(inputs []string) (predict.Inputs, error) { - keyVals := map[string]string{} + keyVals := map[string][]string{} for _, input := range inputs { var name, value string @@ -305,7 +305,8 @@ func parseInputFlags(inputs []string) (predict.Inputs, error) { value = value[1 : len(value)-1] } - keyVals[name] = value + // Append new values to the slice associated with the key + keyVals[name] = append(keyVals[name], value) } return predict.NewInputs(keyVals), nil diff --git a/pkg/predict/input.go b/pkg/predict/input.go index 81f88bd536..0f42e3dc66 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -1,41 +1,41 @@ package predict import ( + "fmt" + "mime" "os" "path/filepath" "strings" "github.com/mitchellh/go-homedir" "github.com/vincent-petithory/dataurl" - - "github.com/replicate/cog/pkg/util/console" - "github.com/replicate/cog/pkg/util/mime" ) type Input struct { String *string File *string + Array *[]any } type Inputs map[string]Input -func NewInputs(keyVals map[string]string) Inputs { +func NewInputs(keyVals map[string][]string) Inputs { input := Inputs{} - for key, val := range keyVals { - val := val - if strings.HasPrefix(val, "@") { - val = val[1:] - expandedVal, err := homedir.Expand(val) - if err != nil { - // FIXME: handle this better? - console.Warnf("Error expanding homedir: %s", err) + for key, vals := range keyVals { + if len(vals) == 1 { + val := vals[0] + if strings.HasPrefix(val, "@") { + val = val[1:] + input[key] = Input{File: &val} } else { - val = expandedVal + input[key] = Input{String: &val} } - - input[key] = Input{File: &val} - } else { - input[key] = Input{String: &val} + } else if len(vals) > 1 { + var anyVals = make([]any, len(vals)) + for i, v := range vals { + anyVals[i] = v + } + input[key] = Input{Array: &anyVals} } } return input @@ -55,19 +55,54 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { return input } -func (inputs *Inputs) toMap() (map[string]string, error) { - keyVals := map[string]string{} +func (inputs *Inputs) toMap() (map[string]any, error) { + keyVals := map[string]any{} for key, input := range *inputs { if input.String != nil { + // Directly assign the string value keyVals[key] = *input.String } else if input.File != nil { - content, err := os.ReadFile(*input.File) + // Single file handling: read content and convert to a data URL + dataURL, err := fileToDataURL(*input.File) if err != nil { return keyVals, err } - mimeType := mime.TypeByExtension(filepath.Ext(*input.File)) - keyVals[key] = dataurl.New(content, mimeType).String() + keyVals[key] = dataURL + } else if input.Array != nil { + // Handle array, potentially containing file paths + dataURLs := make([]string, len(*input.Array)) + for i, elem := range *input.Array { + if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") { + filePath := str[1:] // Remove '@' prefix + dataURL, err := fileToDataURL(filePath) + if err != nil { + return keyVals, err + } + dataURLs[i] = dataURL + } else if ok { + // Directly use the string if it's not a file path + dataURLs[i] = str + } + } + keyVals[key] = dataURLs } } return keyVals, nil } + +// Helper function to read file content and convert to a data URL +func fileToDataURL(filePath string) (string, error) { + // Expand home directory if necessary + expandedVal, err := homedir.Expand(filePath) + if err != nil { + return "", fmt.Errorf("error expanding homedir for '%s': %w", filePath, err) + } + + content, err := os.ReadFile(expandedVal) + if err != nil { + return "", err + } + mimeType := mime.TypeByExtension(filepath.Ext(expandedVal)) + dataURL := dataurl.New(content, mimeType).String() + return dataURL, nil +} diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index 9f175087c5..bf422a4e7d 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -24,7 +24,7 @@ type HealthcheckResponse struct { type Request struct { // TODO: could this be Inputs? - Input map[string]string `json:"input"` + Input map[string]interface{} `json:"input"` } type Response struct { diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index e1ee0257b9..54e1360901 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -382,15 +382,21 @@ def _predict( input_dict = initial_prediction["input"] for k, v in input_dict.items(): - if isinstance(v, types.URLPath): - try: + try: + # Check if v is an instance of URLPath + if isinstance(v, types.URLPath): input_dict[k] = v.convert() - except requests.exceptions.RequestException as e: - tb = traceback.format_exc() - event_handler.append_logs(tb) - event_handler.failed(error=str(e)) - log.warn("failed to download url path from input", exc_info=True) - return event_handler.response + # Check if v is a list of URLPath instances + elif isinstance(v, list) and all( + isinstance(item, types.URLPath) for item in v + ): + input_dict[k] = [item.convert() for item in v] + except requests.exceptions.RequestException as e: + tb = traceback.format_exc() + event_handler.append_logs(tb) + event_handler.failed(error=str(e)) + log.warn("Failed to download url path from input", exc_info=True) + return event_handler.response for event in worker.predict(input_dict, poll=0.1): if should_cancel.is_set(): diff --git a/test-integration/test_integration/fixtures/file-input-project/predict.py b/test-integration/test_integration/fixtures/file-input-project/predict.py index 84d08050a4..1df8fbfca6 100644 --- a/test-integration/test_integration/fixtures/file-input-project/predict.py +++ b/test-integration/test_integration/fixtures/file-input-project/predict.py @@ -1,7 +1,10 @@ -from cog import BasePredictor, Path +from cog import BasePredictor, File class Predictor(BasePredictor): - def predict(self, path: Path) -> str: - with open(path) as f: - return f.read() + def predict(self, file: File) -> str: + content = file.read() + if isinstance(content, bytes): + # Decode bytes to str assuming UTF-8 encoding; adjust if needed + content = content.decode('utf-8') + return content diff --git a/test-integration/test_integration/fixtures/file-input-project/test.txt b/test-integration/test_integration/fixtures/file-input-project/test.txt new file mode 100644 index 0000000000..8e27be7d61 --- /dev/null +++ b/test-integration/test_integration/fixtures/file-input-project/test.txt @@ -0,0 +1 @@ +text diff --git a/test-integration/test_integration/fixtures/file-list-input-project/cog.yaml b/test-integration/test_integration/fixtures/file-list-input-project/cog.yaml new file mode 100644 index 0000000000..7b6d5d4dce --- /dev/null +++ b/test-integration/test_integration/fixtures/file-list-input-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.11" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/file-list-input-project/predict.py b/test-integration/test_integration/fixtures/file-list-input-project/predict.py new file mode 100644 index 0000000000..c8f0f86c8c --- /dev/null +++ b/test-integration/test_integration/fixtures/file-list-input-project/predict.py @@ -0,0 +1,14 @@ +from cog import BasePredictor, File + + +class Predictor(BasePredictor): + def predict(self, files: list[File]) -> str: + output_parts = [] # Use a list to collect file contents + for f in files: + # Assuming file content is in bytes, decode to str before appending + content = f.read() + if isinstance(content, bytes): + # Decode bytes to str assuming UTF-8 encoding; adjust if needed + content = content.decode('utf-8') + output_parts.append(content) + return "\n\n".join(output_parts) diff --git a/test-integration/test_integration/fixtures/file-list-output-project/cog.yaml b/test-integration/test_integration/fixtures/path-input-project/cog.yaml similarity index 100% rename from test-integration/test_integration/fixtures/file-list-output-project/cog.yaml rename to test-integration/test_integration/fixtures/path-input-project/cog.yaml diff --git a/test-integration/test_integration/fixtures/path-input-project/predict.py b/test-integration/test_integration/fixtures/path-input-project/predict.py new file mode 100644 index 0000000000..84d08050a4 --- /dev/null +++ b/test-integration/test_integration/fixtures/path-input-project/predict.py @@ -0,0 +1,7 @@ +from cog import BasePredictor, Path + + +class Predictor(BasePredictor): + def predict(self, path: Path) -> str: + with open(path) as f: + return f.read() diff --git a/test-integration/test_integration/fixtures/path-list-input-project/cog.yaml b/test-integration/test_integration/fixtures/path-list-input-project/cog.yaml new file mode 100644 index 0000000000..7b6d5d4dce --- /dev/null +++ b/test-integration/test_integration/fixtures/path-list-input-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.11" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/path-list-input-project/predict.py b/test-integration/test_integration/fixtures/path-list-input-project/predict.py new file mode 100644 index 0000000000..dd2b8c8e77 --- /dev/null +++ b/test-integration/test_integration/fixtures/path-list-input-project/predict.py @@ -0,0 +1,9 @@ +from cog import BasePredictor, Path + +class Predictor(BasePredictor): + def predict(self, paths: list[Path]) -> str: + output_parts = [] # Use a list to collect file contents + for path in paths: + with open(path) as f: + output_parts.append(f.read()) + return "".join(output_parts) diff --git a/test-integration/test_integration/fixtures/path-list-output-project/cog.yaml b/test-integration/test_integration/fixtures/path-list-output-project/cog.yaml new file mode 100644 index 0000000000..ce622845eb --- /dev/null +++ b/test-integration/test_integration/fixtures/path-list-output-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.8" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/file-list-output-project/predict.py b/test-integration/test_integration/fixtures/path-list-output-project/predict.py similarity index 100% rename from test-integration/test_integration/fixtures/file-list-output-project/predict.py rename to test-integration/test_integration/fixtures/path-list-output-project/predict.py diff --git a/test-integration/test_integration/fixtures/file-output-project/cog.yaml b/test-integration/test_integration/fixtures/path-output-project/cog.yaml similarity index 100% rename from test-integration/test_integration/fixtures/file-output-project/cog.yaml rename to test-integration/test_integration/fixtures/path-output-project/cog.yaml diff --git a/test-integration/test_integration/fixtures/file-output-project/predict.py b/test-integration/test_integration/fixtures/path-output-project/predict.py similarity index 100% rename from test-integration/test_integration/fixtures/file-output-project/predict.py rename to test-integration/test_integration/fixtures/path-output-project/predict.py diff --git a/test-integration/test_integration/fixtures/file-project/cog.yaml b/test-integration/test_integration/fixtures/path-project/cog.yaml similarity index 100% rename from test-integration/test_integration/fixtures/file-project/cog.yaml rename to test-integration/test_integration/fixtures/path-project/cog.yaml diff --git a/test-integration/test_integration/fixtures/file-project/predict.py b/test-integration/test_integration/fixtures/path-project/predict.py similarity index 100% rename from test-integration/test_integration/fixtures/file-project/predict.py rename to test-integration/test_integration/fixtures/path-project/predict.py diff --git a/test-integration/test_integration/test_build.py b/test-integration/test_integration/test_build.py index 25c85572e8..a04b5c3418 100644 --- a/test-integration/test_integration/test_build.py +++ b/test-integration/test_integration/test_build.py @@ -52,7 +52,7 @@ def predict(self, text: str) -> str: def test_build_with_model(docker_image): - project_dir = Path(__file__).parent / "fixtures/file-project" + project_dir = Path(__file__).parent / "fixtures/path-project" subprocess.run( ["cog", "build", "-t", docker_image], cwd=project_dir, diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 311be4ad5a..91d3f4d3ca 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -31,7 +31,7 @@ def test_predict_takes_int_inputs_and_returns_ints_to_stdout(): def test_predict_takes_file_inputs(tmpdir_factory): - project_dir = Path(__file__).parent / "fixtures/file-input-project" + project_dir = Path(__file__).parent / "fixtures/path-input-project" out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) with open(out_dir / "input.txt", "w") as fh: @@ -46,7 +46,7 @@ def test_predict_takes_file_inputs(tmpdir_factory): def test_predict_writes_files_to_files(tmpdir_factory): - project_dir = Path(__file__).parent / "fixtures/file-output-project" + project_dir = Path(__file__).parent / "fixtures/path-output-project" out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) result = subprocess.run( @@ -61,7 +61,7 @@ def test_predict_writes_files_to_files(tmpdir_factory): def test_predict_writes_files_to_files_with_custom_name(tmpdir_factory): - project_dir = Path(__file__).parent / "fixtures/file-output-project" + project_dir = Path(__file__).parent / "fixtures/path-output-project" out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) result = subprocess.run( @@ -76,7 +76,7 @@ def test_predict_writes_files_to_files_with_custom_name(tmpdir_factory): def test_predict_writes_multiple_files_to_files(tmpdir_factory): - project_dir = Path(__file__).parent / "fixtures/file-list-output-project" + project_dir = Path(__file__).parent / "fixtures/path-list-output-project" out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) result = subprocess.run( @@ -229,3 +229,24 @@ def test_predict_many_inputs_with_existing_image(docker_image, tmpdir_factory): capture_output=True, ) assert result.stdout.decode() == "hello default 20 world jpg foo 6\n" + + +def test_predict_path_list_input(tmpdir_factory): + project_dir = Path(__file__).parent / "fixtures/path-list-input-project" + out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) + shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) + with open(out_dir / "1.txt", "w") as fh: + fh.write("test1") + with open(out_dir / "2.txt", "w") as fh: + fh.write("test2") + cmd = ["cog", "predict", "-i", "paths=@1.txt", "-i", "paths=@2.txt"] + + result = subprocess.run( + cmd, + cwd=out_dir, + check=True, + capture_output=True, + ) + stdout = result.stdout.decode() + assert "test1" in stdout + assert "test2" in stdout