@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 4
"modification": 5
}
129 changes: 128 additions & 1 deletion sdks/python/apache_beam/ml/rag/ingestion/alloydb.py
@@ -17,6 +17,7 @@
import json
import logging
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import Dict
@@ -37,6 +38,73 @@
_LOGGER = logging.getLogger(__name__)


@dataclass
class AlloyDBLanguageConnectorConfig:
"""Configuration options for AlloyDB Java language connector.

Contains all parameters needed to configure a connection using the AlloyDB
Java connector via JDBC. For details see
https://github.com/GoogleCloudPlatform/alloydb-java-connector/blob/main/docs/jdbc.md

Attributes:
database_name: Name of the database to connect to.
instance_name: Fully qualified instance name. Format:
'projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances
/<INSTANCE>'
ip_type: IP type to use for the connection. One of 'PRIVATE' (default),
'PUBLIC' or 'PSC'.
enable_iam_auth: Whether to enable IAM authentication. Default is False.
target_principal: Optional service account to impersonate for
connection.
delegates: Optional list of service accounts for delegated
impersonation (joined into a comma-separated list in the JDBC URL).
admin_service_endpoint: Optional custom API service endpoint.
quota_project: Optional project ID for quota and billing.
"""
database_name: str
instance_name: str
ip_type: str = "PRIVATE"
enable_iam_auth: bool = False
target_principal: Optional[str] = None
delegates: Optional[List[str]] = None
admin_service_endpoint: Optional[str] = None
quota_project: Optional[str] = None

def to_jdbc_url(self) -> str:
"""Convert options to a properly formatted JDBC URL.

Returns:
JDBC URL string configured with all options.
"""
# Base URL with database name
url = f"jdbc:postgresql:///{self.database_name}?"

# Add required properties
properties = {
"socketFactory": "com.google.cloud.alloydb.SocketFactory",
"alloydbInstanceName": self.instance_name,
"alloydbIpType": self.ip_type
}

if self.enable_iam_auth:
properties["alloydbEnableIAMAuth"] = "true"

if self.target_principal:
properties["alloydbTargetPrincipal"] = self.target_principal

if self.delegates:
properties["alloydbDelegates"] = ",".join(self.delegates)

if self.admin_service_endpoint:
properties["alloydbAdminServiceEndpoint"] = self.admin_service_endpoint

if self.quota_project:
properties["alloydbQuotaProject"] = self.quota_project

property_string = "&".join(f"{k}={v}" for k, v in properties.items())
return url + property_string
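
For illustration (not part of the diff), a minimal sketch of the URL this method produces, assuming placeholder project, cluster and instance names:

config = AlloyDBLanguageConnectorConfig(
    database_name="mydb",
    instance_name="projects/p/locations/us-central1/clusters/c/instances/i",
    ip_type="PUBLIC")
# Returns a single line, wrapped here for readability:
# jdbc:postgresql:///mydb?socketFactory=com.google.cloud.alloydb.SocketFactory
# &alloydbInstanceName=projects/p/locations/us-central1/clusters/c/instances/i
# &alloydbIpType=PUBLIC
print(config.to_jdbc_url())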


@dataclass
class AlloyDBConnectionConfig:
"""Configuration for AlloyDB database connection.
@@ -58,6 +126,10 @@ class AlloyDBConnectionConfig:
max_connections: Optional number of connections in the pool.
Use negative for no limit.
write_batch_size: Optional write batch size for bulk operations.
additional_jdbc_args: Additional arguments that will be passed to
WriteToJdbc. These may include 'driver_jars', 'expansion_service',
'classpath', etc. See the full set of supported arguments in
:class:`~apache_beam.io.jdbc.WriteToJdbc`.

Example:
>>> config = AlloyDBConnectionConfig(
@@ -76,6 +148,60 @@ class AlloyDBConnectionConfig:
autosharding: Optional[bool] = None
max_connections: Optional[int] = None
write_batch_size: Optional[int] = None
additional_jdbc_args: Dict[str, Any] = field(default_factory=dict)

@classmethod
def with_language_connector(
cls,
connector_options: AlloyDBLanguageConnectorConfig,
username: str,
password: str,
connection_properties: Optional[Dict[str, str]] = None,
connection_init_sqls: Optional[List[str]] = None,
autosharding: Optional[bool] = None,
max_connections: Optional[int] = None,
write_batch_size: Optional[int] = None) -> 'AlloyDBConnectionConfig':
"""Create AlloyDBConnectionConfig using the AlloyDB language connector.

Args:
connector_options: AlloyDB language connector configuration options.
username: Database username. For IAM auth, this should be the IAM
user email.
password: Database password. Can be an empty string when using IAM
authentication.
connection_properties: Additional JDBC connection properties.
connection_init_sqls: SQL statements to execute on connection.
autosharding: Enable autosharding.
max_connections: Max connections in pool.
write_batch_size: Write batch size.

Returns:
Configured AlloyDBConnectionConfig instance.

Example:
>>> options = AlloyDBLanguageConnectorConfig(
... database_name="mydb",
... instance_name="projects/my-project/locations/us-central1\
.... /clusters/my-cluster/instances/my-instance",
... ip_type="PUBLIC",
... enable_iam_auth=True
... )
"""
return cls(
jdbc_url=connector_options.to_jdbc_url(),
username=username,
password=password,
connection_properties=connection_properties,
connection_init_sqls=connection_init_sqls,
autosharding=autosharding,
max_connections=max_connections,
write_batch_size=write_batch_size,
additional_jdbc_args={
'classpath': [
"org.postgresql:postgresql:42.2.16",
"com.google.cloud:alloydb-jdbc-connector:1.2.0"
]
})


@dataclass
@@ -713,4 +839,5 @@ def expand(self, pcoll: beam.PCollection[Chunk]):
connection_init_sqls,
autosharding=self.config.connection_config.autosharding,
max_connections=self.config.connection_config.max_connections,
write_batch_size=self.config.connection_config.write_batch_size))
write_batch_size=self.config.connection_config.write_batch_size,
**self.config.connection_config.additional_jdbc_args))
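
Taken together, the changes above enable a write path like the following sketch (a hedged illustration, not code from this PR: the project, instance, table, credentials and the upstream chunks collection are placeholders; the class and method names match those exercised by the integration test below):

import apache_beam as beam
from apache_beam.ml.rag.ingestion.alloydb import AlloyDBConnectionConfig
from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig
from apache_beam.ml.rag.ingestion.alloydb import AlloyDBVectorWriterConfig

# Describe the instance once; with_language_connector builds the JDBC URL via
# to_jdbc_url() and injects the PostgreSQL driver and AlloyDB connector jars
# through additional_jdbc_args={'classpath': [...]}.
connector_options = AlloyDBLanguageConnectorConfig(
    database_name="mydb",
    instance_name="projects/my-project/locations/us-central1/"
    "clusters/my-cluster/instances/my-instance",
    ip_type="PUBLIC")
connection_config = AlloyDBConnectionConfig.with_language_connector(
    connector_options=connector_options,
    username="postgres",
    password="<password>")
writer_config = AlloyDBVectorWriterConfig(
    connection_config=connection_config, table_name="my_embeddings")

with beam.Pipeline() as p:
    # chunks: an in-memory list of Chunk objects from an upstream embedding step.
    _ = (p | beam.Create(chunks) | writer_config.create_write_transform())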
88 changes: 88 additions & 0 deletions sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py
@@ -33,6 +33,7 @@
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import ReadFromJdbc
from apache_beam.ml.rag.ingestion.alloydb import AlloyDBConnectionConfig
from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig
from apache_beam.ml.rag.ingestion.alloydb import AlloyDBVectorWriterConfig
from apache_beam.ml.rag.ingestion.alloydb import ColumnSpec
from apache_beam.ml.rag.ingestion.alloydb import ColumnSpecsBuilder
@@ -328,6 +329,93 @@ def test_default_schema(self):
equal_to([expected_last_n]),
label=f"last_{sample_size}_check")

def test_language_connector(self):
"""Test language connector."""
self.skip_if_dataflow_runner()

connector_options = AlloyDBLanguageConnectorConfig(
database_name=self.database,
instance_name="projects/apache-beam-testing/locations/us-central1/\
clusters/testing-psc/instances/testing-psc-1",
ip_type="PSC")
connection_config = AlloyDBConnectionConfig.with_language_connector(
connector_options=connector_options,
username=self.username,
password=self.password)
config = AlloyDBVectorWriterConfig(
connection_config=connection_config, table_name=self.default_table_name)

# Create test chunks
num_records = 150
sample_size = min(500, num_records // 2)
chunks = ChunkTestUtils.get_expected_values(0, num_records)

self.write_test_pipeline.not_use_test_runner_api = True

with self.write_test_pipeline as p:
_ = (p | beam.Create(chunks) | config.create_write_transform())

self.read_test_pipeline.not_use_test_runner_api = True
read_query = f"""
SELECT
CAST(id AS VARCHAR(255)),
CAST(content AS VARCHAR(255)),
CAST(embedding AS text),
CAST(metadata AS text)
FROM {self.default_table_name}
"""

with self.read_test_pipeline as p:
rows = (
p
| ReadFromJdbc(
table_name=self.default_table_name,
driver_class_name="org.postgresql.Driver",
jdbc_url=connector_options.to_jdbc_url(),
username=self.username,
password=self.password,
query=read_query,
classpath=[
"org.postgresql:postgresql:42.2.16",
"com.google.cloud:alloydb-jdbc-connector:1.2.0"
]))

count_result = rows | "Count All" >> beam.combiners.Count.Globally()
assert_that(count_result, equal_to([num_records]), label='count_check')

chunks = (rows | "To Chunks" >> beam.Map(row_to_chunk))
chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally(HashingFn())
assert_that(
chunk_hashes,
equal_to([generate_expected_hash(num_records)]),
label='hash_check')

# Sample validation
first_n = (
chunks
| "Key on Index" >> beam.Map(key_on_id)
| f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of(
sample_size, key=lambda x: x[0], reverse=True)
| "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs]))
expected_first_n = ChunkTestUtils.get_expected_values(0, sample_size)
assert_that(
first_n,
equal_to([expected_first_n]),
label=f"first_{sample_size}_check")

last_n = (
chunks
| "Key on Index 2" >> beam.Map(key_on_id)
| f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of(
sample_size, key=lambda x: x[0])
| "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs]))
expected_last_n = ChunkTestUtils.get_expected_values(
num_records - sample_size, num_records)[::-1]
assert_that(
last_n,
equal_to([expected_last_n]),
label=f"last_{sample_size}_check")

def test_custom_specs(self):
"""Test custom specifications for ID, embedding, and content."""
self.skip_if_dataflow_runner()