Skip to content

Allow Union and List input types #1311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 22, 2023

Conversation

mattt
Copy link
Contributor

@mattt mattt commented Sep 21, 2023

Alternative to #968 #1292

This PR extends Cog to support List and Union type inputs.

For example, given the following predictor:

from cog import BasePredictor, Path

from typing import List, Union

class Predictor(BasePredictor):
    def predict(self, args: Union[int, List[int]]) -> int:
        if isinstance(args, int):
            return args
        else:
            return sum(args)

The following assertions are satisfied:

resp = client.post("/predictions", json={"input": {"args": 123}})
assert resp.status_code == 200
assert resp.json()["output"] == 123

resp = client.post("/predictions", json={"input": {"args": [1, 2, 3]}})
assert resp.status_code == 200
assert resp.json()["output"] == 6

resp = client.post("/predictions", json={"input": {"args": "abc"}})
assert resp.status_code == 422

resp = client.post("/predictions", json={"input": {"args": ["a", "b", "c"]}})
assert resp.status_code == 422

Signed-off-by: Mattt Zmuda <mattt@replicate.com>
@mattt mattt requested a review from bfirsh September 21, 2023 19:53
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
Copy link
Member

@bfirsh bfirsh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marvellous. Could do with some docs, but we can fix up later.

I wonder how this interacts with Replicate?

@mattt
Copy link
Contributor Author

mattt commented Sep 22, 2023

@bfirsh Union[str, List[str]] is expressed in OpenAPI as {"anyOf": [{"type": "string"}, {"type": "array", "items": "string"}}]. Currently, Replicate renders inputs of that schema as a single text field. We can eventually support multiple inputs, but this seems sufficient for now, since users can send multiple inputs via the API.

@mattt mattt merged commit fd72b03 into main Sep 22, 2023
@mattt mattt deleted the mattt/accept-string-or-array-of-strings-input branch September 22, 2023 12:54
mattt added a commit that referenced this pull request Sep 22, 2023
Squashed commit of the following:

commit fd72b03
Author: Mattt <mattt@replicate.com>
Date:   Fri Sep 22 05:54:06 2023 -0700

    Allow `Union` and `List` input types (#1311)

    * Allow Union and List input types

    Signed-off-by: Mattt Zmuda <mattt@replicate.com>

    * Update TypeError messages to note support for Union and List types

    Signed-off-by: Mattt Zmuda <mattt@replicate.com>

    ---------

    Signed-off-by: Mattt Zmuda <mattt@replicate.com>

commit 6123056
Author: Mattt <mattt@replicate.com>
Date:   Thu Sep 21 13:21:00 2023 -0700

    Add support for text/markdown / .md files (#1310)

    Signed-off-by: Mattt Zmuda <mattt@replicate.com>

commit aa7aa77
Author: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Date:   Thu Sep 21 13:20:46 2023 -0700

    Bump gotest.tools/gotestsum from 1.10.1 to 1.11.0 (#1306)

    Bumps [gotest.tools/gotestsum](https://github.com/gotestyourself/gotestsum) from 1.10.1 to 1.11.0.
    - [Release notes](https://github.com/gotestyourself/gotestsum/releases)
    - [Commits](gotestyourself/gotestsum@v1.10.1...v1.11.0)

    ---
    updated-dependencies:
    - dependency-name: gotest.tools/gotestsum
      dependency-type: direct:production
      update-type: version-update:semver-minor
    ...

    Signed-off-by: dependabot[bot] <support@github.com>
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

commit f806bf1
Author: technillogue <technillogue@gmail.com>
Date:   Wed Sep 20 12:17:24 2023 -0400

    ||true (#1307)

    Signed-off-by: technillogue <technillogue@gmail.com>

commit 8487137
Author: Nick Stenning <nick@whiteink.com>
Date:   Thu Sep 14 14:19:09 2023 -0700

    Override sys.argv while importing predictor module (#1304)

    In production, any sys.argv should not be exposed to user code. Partly
    because this might leak information about the production environment,
    but primarily because user code often has its own argument parsing code
    which gets confused when it sees our arguments to cog/server/http.py.

    This uses mock.patch to ensure that sys.argv contains just sys.argv[0]
    when user code is executing.

    Even though we (Cog) have already done all argument parsing long before
    this executes, we also put the original value back once we're done
    running user code.

    Signed-off-by: Nick Stenning <nick@whiteink.com>

commit 0e3af7b
Author: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com>
Date:   Tue Sep 12 03:48:07 2023 -0700

    add technillogue as a contributor for code (#1256)

    * update README.md [skip ci]

    * update .all-contributorsrc [skip ci]

    ---------

    Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com>

Signed-off-by: Mattt Zmuda <mattt@replicate.com>
mattt added a commit that referenced this pull request Sep 22, 2023
Signed-off-by: Mattt Zmuda <mattt@replicate.com>

Squashed commit of the following:

commit fd72b03
Author: Mattt <mattt@replicate.com>
Date:   Fri Sep 22 05:54:06 2023 -0700

    Allow `Union` and `List` input types (#1311)

    * Allow Union and List input types

    Signed-off-by: Mattt Zmuda <mattt@replicate.com>

    * Update TypeError messages to note support for Union and List types

    Signed-off-by: Mattt Zmuda <mattt@replicate.com>

    ---------

    Signed-off-by: Mattt Zmuda <mattt@replicate.com>

commit 6123056
Author: Mattt <mattt@replicate.com>
Date:   Thu Sep 21 13:21:00 2023 -0700

    Add support for text/markdown / .md files (#1310)

    Signed-off-by: Mattt Zmuda <mattt@replicate.com>

commit aa7aa77
Author: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Date:   Thu Sep 21 13:20:46 2023 -0700

    Bump gotest.tools/gotestsum from 1.10.1 to 1.11.0 (#1306)

    Bumps [gotest.tools/gotestsum](https://github.com/gotestyourself/gotestsum) from 1.10.1 to 1.11.0.
    - [Release notes](https://github.com/gotestyourself/gotestsum/releases)
    - [Commits](gotestyourself/gotestsum@v1.10.1...v1.11.0)

    ---
    updated-dependencies:
    - dependency-name: gotest.tools/gotestsum
      dependency-type: direct:production
      update-type: version-update:semver-minor
    ...

    Signed-off-by: dependabot[bot] <support@github.com>
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

commit f806bf1
Author: technillogue <technillogue@gmail.com>
Date:   Wed Sep 20 12:17:24 2023 -0400

    ||true (#1307)

    Signed-off-by: technillogue <technillogue@gmail.com>

commit 8487137
Author: Nick Stenning <nick@whiteink.com>
Date:   Thu Sep 14 14:19:09 2023 -0700

    Override sys.argv while importing predictor module (#1304)

    In production, any sys.argv should not be exposed to user code. Partly
    because this might leak information about the production environment,
    but primarily because user code often has its own argument parsing code
    which gets confused when it sees our arguments to cog/server/http.py.

    This uses mock.patch to ensure that sys.argv contains just sys.argv[0]
    when user code is executing.

    Even though we (Cog) have already done all argument parsing long before
    this executes, we also put the original value back once we're done
    running user code.

    Signed-off-by: Nick Stenning <nick@whiteink.com>

commit 0e3af7b
Author: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com>
Date:   Tue Sep 12 03:48:07 2023 -0700

    add technillogue as a contributor for code (#1256)

    * update README.md [skip ci]

    * update .all-contributorsrc [skip ci]

    ---------

    Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com>

Signed-off-by: Mattt Zmuda <mattt@replicate.com>
@peter65374
Copy link

peter65374 commented Oct 20, 2023

only List[int] is supported so far? Does input type support List[Path] or List[str]? I am expecting to input multiple images to process as below

def predict(
        self,
        target_image: Path = Input(
            description="Faces in target image would be changed.",
        ),
        model_faces: list[Path] = Input(
            description="Model Faces would be swapped into target image.", 
            default=None,
        ),
        inference_mode: str = Input(
            default="swap",
            choices=["swap", "detect"],
            description="Face swap mode or detection mode. Default is swap.",
        ),
    ) -> Union[List[tuple], List[Path]]:
        """Run a single prediction on the model"""
        # Check if inference_mode is detect or swap
        if inference_mode == "detect":
            detect_mode = True
        else: 
            detect_mode = False

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants