Skip to content

Commit

Permalink
Add new connection type: OpenShift
Browse files Browse the repository at this point in the history
  • Loading branch information
luis5tb committed Oct 21, 2024
1 parent f08e576 commit ff8304b
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 2 deletions.
4 changes: 2 additions & 2 deletions scripts/json_schema/gen_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_required(self, obj):
from promptflow._sdk.schemas._connection import AzureOpenAIConnectionSchema, OpenAIConnectionSchema, \
QdrantConnectionSchema, CognitiveSearchConnectionSchema, SerpConnectionSchema, AzureContentSafetyConnectionSchema, \
FormRecognizerConnectionSchema, CustomConnectionSchema, WeaviateConnectionSchema, ServerlessConnectionSchema, \
CustomStrongTypeConnectionSchema, AzureAIServicesConnectionSchema
OpenShiftConnectionSchema, CustomStrongTypeConnectionSchema, AzureAIServicesConnectionSchema
from promptflow._sdk.schemas._run import RunSchema
from promptflow._sdk.schemas._flow import FlowSchema, FlexFlowSchema

Expand All @@ -168,7 +168,7 @@ def dump_json(file_name, dct):
args.output_file = ["Run", "Flow", "AzureOpenAIConnection", "OpenAIConnection", "QdrantConnection",
"CognitiveSearchConnection", "SerpConnection", "AzureContentSafetyConnection",
"FormRecognizerConnection", "CustomConnection", "WeaviateConnection", "ServerlessConnection",
"CustomStrongTypeConnection", "AzureAIServicesConnection"]
"OpenShiftConnection", "CustomStrongTypeConnection", "AzureAIServicesConnection"]

# Special case for Flow and EagerFlow
if "Flow" in args.output_file:
Expand Down
1 change: 1 addition & 0 deletions src/promptflow-core/promptflow/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class ConnectionType(str, Enum):
WEAVIATE = "Weaviate"
SERVERLESS = "Serverless"
CUSTOM = "Custom"
OPENSHIFT = "OpenShift"


class CustomStrongTypeConnectionConfigs:
Expand Down
2 changes: 2 additions & 0 deletions src/promptflow-core/promptflow/_core/data/tool.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
"FunctionStr",
"FormRecognizerConnection",
"ServerlessConnection",
"OpenShiftConnection",
"AzureAIServicesConnection",
"FilePath",
"Image",
Expand Down Expand Up @@ -246,6 +247,7 @@
"function_str",
"FormRecognizerConnection",
"ServerlessConnection",
"OpenShiftConnection",
"AzureAIServicesConnection",
"file_path",
"image",
Expand Down
2 changes: 2 additions & 0 deletions src/promptflow-core/promptflow/connections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
OpenAIConnection,
SerpConnection,
ServerlessConnection,
OpenShiftConnection,
_Connection,
)
from promptflow.core._connection_provider._connection_provider import ConnectionProvider
Expand All @@ -41,6 +42,7 @@ class BingConnection:
"CustomConnection",
"CustomStrongTypeConnection",
"ServerlessConnection",
"OpenShiftConnection",
"AzureAIServicesConnection",
"ConnectionProvider",
]
Expand Down
29 changes: 29 additions & 0 deletions src/promptflow-core/promptflow/core/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,35 @@ def api_base(self, value):
self.configs["api_base"] = value


class OpenShiftConnection(_StrongTypeConnection):
"""OpenShift connection.
:param token: The api token.
:type token: str
:param endpoint: The api endpoint.
:type endpoint: str
:param name: Connection name.
:type name: str
"""

TYPE = ConnectionType.OPENSHIFT.value

def __init__(self, token: str, endpoint: str, **kwargs):
secrets = {"token": token}
configs = {"endpoint": endpoint}
super().__init__(secrets=secrets, configs=configs, **kwargs)

@property
def endpoint(self):
"""Return the connection api endpoint."""
return self.configs.get("endpoint")

@endpoint.setter
def endpoint(self, value):
"""Set the connection api endpoint."""
self.configs["endpoint"] = value


class SerpConnection(_StrongTypeConnection):
"""Serp connection.
Expand Down
2 changes: 2 additions & 0 deletions src/promptflow-devkit/promptflow/_sdk/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
FormRecognizerConnection,
CustomStrongTypeConnection,
ServerlessConnection,
OpenShiftConnection,
)
from ._run import Run
from ._validation import ValidationResult
Expand All @@ -35,6 +36,7 @@
"WeaviateConnection",
"FormRecognizerConnection",
"ServerlessConnection",
"OpenShiftConnection",
# endregion
# region Run
"Run",
Expand Down
11 changes: 11 additions & 0 deletions src/promptflow-devkit/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
QdrantConnectionSchema,
SerpConnectionSchema,
ServerlessConnectionSchema,
OpenShiftConnectionSchema,
WeaviateConnectionSchema,
)
from promptflow._utils.logger_utils import LoggerFactory
Expand All @@ -58,6 +59,7 @@
from promptflow.core._connection import QdrantConnection as _CoreQdrantConnection
from promptflow.core._connection import SerpConnection as _CoreSerpConnection
from promptflow.core._connection import ServerlessConnection as _CoreServerlessConnection
from promptflow.core._connection import OpenShiftConnection as _CoreOpenShiftConnection
from promptflow.core._connection import WeaviateConnection as _CoreWeaviateConnection
from promptflow.core._connection import _Connection as _CoreConnection
from promptflow.core._connection import _StrongTypeConnection as _CoreStrongTypeConnection
Expand Down Expand Up @@ -315,6 +317,15 @@ def _get_schema_cls(cls):
return ServerlessConnectionSchema


class OpenShiftConnection(_CoreOpenShiftConnection, _StrongTypeConnection):
__doc__ = _CoreOpenShiftConnection.__doc__
DATA_CLASS = _CoreOpenShiftConnection

@classmethod
def _get_schema_cls(cls):
return OpenShiftConnectionSchema


class SerpConnection(_CoreSerpConnection, _StrongTypeConnection):
__doc__ = _CoreSerpConnection.__doc__
DATA_CLASS = _CoreSerpConnection
Expand Down
6 changes: 6 additions & 0 deletions src/promptflow-devkit/promptflow/_sdk/schemas/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class ServerlessConnectionSchema(ConnectionSchema):
api_base = fields.Str(required=True)


class OpenShiftConnectionSchema(ConnectionSchema):
type = StringTransformedEnum(allowed_values=camel_to_snake(ConnectionType.OPENSHIFT), required=True)
token = fields.Str(required=True)
endpoint = fields.Str(required=True)


class EmbeddingStoreConnectionSchema(ConnectionSchema):
module = fields.Str(dump_default="promptflow_vectordb.connections")
api_key = fields.Str(required=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
QdrantConnection,
SerpConnection,
ServerlessConnection,
OpenShiftConnection,
WeaviateConnection,
_Connection,
)
Expand Down Expand Up @@ -222,6 +223,19 @@ class TestConnection:
"type": "serverless",
},
),
(
"openshift_connection.yaml",
OpenShiftConnection,
{
"name": "my_openshift_connection",
"token": "<to-be-replaced>",
"endpoint": "https://mock.api.base",
},
{
"module": "promptflow.connections",
"type": "openshift",
},
),
(
"azure_ai_services_connection.yaml",
AzureAIServicesConnection,
Expand Down
8 changes: 8 additions & 0 deletions src/promptflow-tools/connections.json.example
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,13 @@
"api_base": "serverless-embedding-endpoint-url"
},
"module": "promptflow.connections"
},
"openshift_connection": {
"type": "OpenShiftConnection",
"value": {
"endpoint": "openshift-endpoint",
"token": "openshift-token"
},
"module": "promtflow.connections"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
$schema: https://azuremlschemas.azureedge.net/promptflow/latest/OpenShiftConnection.schema.json
name: my_openshift_connection
type: openshift
token: "<to-be-replaced>"
endpoint: "https://mock.api.base"

0 comments on commit ff8304b

Please sign in to comment.