From 099f8510176d77da84d3c48421bfd1111681fc53 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Mon, 9 Sep 2024 09:48:56 -0700 Subject: [PATCH] Initial support for Python UDFs (#736) --- .github/workflows/ci.yml | 6 +- Cargo.lock | 106 +++++++ Cargo.toml | 3 +- crates/arroyo-api/Cargo.toml | 1 + .../migrations/V24__add_udf_language.sql | 3 + crates/arroyo-api/queries/api_queries.sql | 17 +- .../V2__add_udf_language.sql | 27 ++ crates/arroyo-api/src/lib.rs | 1 + crates/arroyo-api/src/pipelines.rs | 93 ++++--- crates/arroyo-api/src/udfs.rs | 137 +++++---- crates/arroyo-datastream/src/logical.rs | 77 ++++- crates/arroyo-operator/Cargo.toml | 1 + crates/arroyo-operator/src/operator.rs | 11 +- crates/arroyo-planner/Cargo.toml | 1 + crates/arroyo-planner/src/lib.rs | 39 ++- crates/arroyo-rpc/proto/api.proto | 14 +- crates/arroyo-rpc/src/api_types/udfs.rs | 30 +- .../arroyo-udf/arroyo-udf-common/src/parse.rs | 7 + .../arroyo-udf/arroyo-udf-python/Cargo.toml | 13 + .../arroyo-udf-python/python/arroyo_udf.py | 13 + .../arroyo-udf-python/src/interpreter.rs | 136 +++++++++ .../arroyo-udf/arroyo-udf-python/src/lib.rs | 262 ++++++++++++++++++ .../arroyo-udf-python/src/pyarrow.rs | 258 +++++++++++++++++ .../arroyo-udf-python/src/threaded.rs | 191 +++++++++++++ crates/arroyo-worker/src/lib.rs | 10 + crates/arroyo/src/main.rs | 10 + webui/src/gen/api-types.ts | 6 + webui/src/lib/data_fetching.ts | 22 +- webui/src/routes/pipelines/CreatePipeline.tsx | 6 + webui/src/routes/udfs/GlobalizeModal.tsx | 2 +- webui/src/routes/udfs/UdfEditTab.tsx | 1 + webui/src/routes/udfs/UdfEditor.tsx | 4 +- webui/src/routes/udfs/UdfLabel.tsx | 4 +- webui/src/routes/udfs/UdfsResourceTab.tsx | 17 +- webui/src/udf_state.ts | 45 ++- 35 files changed, 1433 insertions(+), 141 deletions(-) create mode 100644 crates/arroyo-api/migrations/V24__add_udf_language.sql create mode 100644 crates/arroyo-api/sqlite_migrations/V2__add_udf_language.sql create mode 100644 crates/arroyo-udf/arroyo-udf-python/Cargo.toml create mode 100644 crates/arroyo-udf/arroyo-udf-python/python/arroyo_udf.py create mode 100644 crates/arroyo-udf/arroyo-udf-python/src/interpreter.rs create mode 100644 crates/arroyo-udf/arroyo-udf-python/src/lib.rs create mode 100644 crates/arroyo-udf/arroyo-udf-python/src/pyarrow.rs create mode 100644 crates/arroyo-udf/arroyo-udf-python/src/threaded.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ff987c86a..2068f4051 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,9 +35,13 @@ jobs: uses: actions-rs/toolchain@v1 with: toolchain: stable - override: true + override: true - name: Check Formatting run: cargo fmt -- --check + - uses: actions/setup-python@v5 + name: Setup Python + with: + python-version: '3.12' - name: Setup pnpm uses: pnpm/action-setup@v4 with: diff --git a/Cargo.lock b/Cargo.lock index 128befeee..dcfcc2183 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -471,6 +471,7 @@ dependencies = [ "arroyo-state", "arroyo-types", "arroyo-udf-host", + "arroyo-udf-python", "async-trait", "axum", "axum-extra", @@ -708,6 +709,7 @@ dependencies = [ "arroyo-storage", "arroyo-types", "arroyo-udf-host", + "arroyo-udf-python", "async-trait", "bincode", "datafusion", @@ -831,6 +833,7 @@ dependencies = [ "arroyo-storage", "arroyo-types", "arroyo-udf-host", + "arroyo-udf-python", "async-ffi", "async-stream", "async-trait", @@ -1081,6 +1084,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "arroyo-udf-python" +version = "0.2.0" +dependencies = [ + "anyhow", + "arrow", + "arroyo-udf-common", + "datafusion", + "itertools 0.13.0", + "pyo3", + "tokio", +] + [[package]] name = "arroyo-worker" version = "0.12.0-dev" @@ -5043,6 +5059,12 @@ dependencies = [ "serde", ] +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + [[package]] name = "inlinable_string" version = "0.1.15" @@ -5695,6 +5717,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "miette" version = "5.10.0" @@ -7017,6 +7048,69 @@ dependencies = [ "cc", ] +[[package]] +name = "pyo3" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot 0.12.3", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.72", +] + [[package]] name = "quad-rand" version = "0.2.1" @@ -8618,6 +8712,12 @@ dependencies = [ "libc", ] +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + [[package]] name = "tempfile" version = "3.11.0" @@ -9503,6 +9603,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + [[package]] name = "unsafe-libyaml" version = "0.2.11" diff --git a/Cargo.toml b/Cargo.toml index 5fc93d2e0..da67916c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ members = [ "crates/arroyo-udf/arroyo-udf-plugin", "crates/arroyo-udf/arroyo-udf-host", "crates/arroyo-udf/arroyo-udf-macros", + "crates/arroyo-udf/arroyo-udf-python", "crates/arroyo-worker", "crates/copy-artifacts", "crates/integ", @@ -76,4 +77,4 @@ datafusion-physical-expr = {git = 'https://github.com/ArroyoSystems/arrow-datafu datafusion-physical-plan = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = '40.0.0/arroyo'} datafusion-proto = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = '40.0.0/arroyo'} cornucopia_async = { git = "https://github.com/ArroyoSystems/cornucopia", branch = "sqlite" } -cornucopia = { git = "https://github.com/ArroyoSystems/cornucopia", branch = "sqlite" } \ No newline at end of file +cornucopia = { git = "https://github.com/ArroyoSystems/cornucopia", branch = "sqlite" } diff --git a/crates/arroyo-api/Cargo.toml b/crates/arroyo-api/Cargo.toml index 9efe86693..edd471038 100644 --- a/crates/arroyo-api/Cargo.toml +++ b/crates/arroyo-api/Cargo.toml @@ -18,6 +18,7 @@ arroyo-datastream = { path = "../arroyo-datastream" } arroyo-state = { path = "../arroyo-state" } arroyo-formats = { path = "../arroyo-formats" } arroyo-udf-host = { path = "../arroyo-udf/arroyo-udf-host" } +arroyo-udf-python = { path = "../arroyo-udf/arroyo-udf-python" } tonic = { workspace = true } tonic-reflection = { workspace = true } diff --git a/crates/arroyo-api/migrations/V24__add_udf_language.sql b/crates/arroyo-api/migrations/V24__add_udf_language.sql new file mode 100644 index 000000000..e916f0113 --- /dev/null +++ b/crates/arroyo-api/migrations/V24__add_udf_language.sql @@ -0,0 +1,3 @@ +ALTER TABLE udfs ADD COLUMN language VARCHAR(15) NOT NULL DEFAULT 'rust'; +ALTER TABLE udfs ALTER COLUMN dylib_url DROP NOT NULL; +ALTER TABLE udfs ALTER COLUMN dylib_url DROP DEFAULT; diff --git a/crates/arroyo-api/queries/api_queries.sql b/crates/arroyo-api/queries/api_queries.sql index eb36fc5d1..8f1ed3a75 100644 --- a/crates/arroyo-api/queries/api_queries.sql +++ b/crates/arroyo-api/queries/api_queries.sql @@ -281,26 +281,27 @@ WHERE job_configs.organization_id = :organization_id AND job_configs.id = :job_i ORDER BY jlm.created_at DESC LIMIT cast(:limit as integer); ------------ udfs ----------------------- ---: DbUdf (description?) +----------- udfs ----------------------- ---! create_udf -INSERT INTO udfs (pub_id, organization_id, created_by, prefix, name, definition, description, dylib_url) -VALUES (:pub_id, :organization_id, :created_by, :prefix, :name, :definition, :description, :dylib_url); +--: DbUdf (description?, dylib_url?) +--! create_udf (dylib_url?) +INSERT INTO udfs (pub_id, organization_id, created_by, prefix, name, language, definition, description, dylib_url) +VALUES (:pub_id, :organization_id, :created_by, :prefix, :name, :language, :definition, :description, :dylib_url); + --! get_udf: DbUdf -SELECT pub_id, prefix, name, definition, created_at, updated_at, description, dylib_url +SELECT pub_id, prefix, name, language, definition, created_at, updated_at, description, dylib_url FROM udfs WHERE organization_id = :organization_id AND pub_id = :pub_id; --! get_udf_by_name: DbUdf -SELECT pub_id, prefix, name, definition, created_at, updated_at, description, dylib_url +SELECT pub_id, prefix, name, language, definition, created_at, updated_at, description, dylib_url FROM udfs WHERE organization_id = :organization_id AND name = :name; --! get_udfs: DbUdf -SELECT pub_id, prefix, name, definition, created_at, updated_at, description, dylib_url +SELECT pub_id, prefix, name, language, definition, created_at, updated_at, description, dylib_url FROM udfs WHERE organization_id = :organization_id; diff --git a/crates/arroyo-api/sqlite_migrations/V2__add_udf_language.sql b/crates/arroyo-api/sqlite_migrations/V2__add_udf_language.sql new file mode 100644 index 000000000..6abdd9f2f --- /dev/null +++ b/crates/arroyo-api/sqlite_migrations/V2__add_udf_language.sql @@ -0,0 +1,27 @@ +CREATE TABLE udfs_new ( + pub_id TEXT PRIMARY KEY, + organization_id TEXT NOT NULL, + created_by TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + prefix TEXT, + name TEXT NOT NULL, + definition TEXT NOT NULL, + description TEXT, + dylib_url TEXT, + language VARCHAR(15) NOT NULL DEFAULT 'rust', + UNIQUE (organization_id, name) +); + +-- Copy data from the old table to the nfew table +INSERT INTO udfs_new (pub_id, organization_id, created_by, created_at, updated_at, + prefix, name, definition, description, dylib_url, language) +SELECT pub_id, organization_id, created_by, created_at, updated_at, + prefix, name, definition, description, dylib_url, 'rust' +from udfs; + +-- Drop the old table +DROP TABLE udfs; + +-- Rename the new table to the old table name +ALTER TABLE udfs_new RENAME TO udfs; diff --git a/crates/arroyo-api/src/lib.rs b/crates/arroyo-api/src/lib.rs index cb675a48c..6fecde975 100644 --- a/crates/arroyo-api/src/lib.rs +++ b/crates/arroyo-api/src/lib.rs @@ -310,6 +310,7 @@ impl IntoResponse for HttpError { ValidateUdfPost, UdfValidationResult, Udf, + UdfLanguage, UdfPost, GlobalUdf, GlobalUdfCollection, diff --git a/crates/arroyo-api/src/pipelines.rs b/crates/arroyo-api/src/pipelines.rs index 09188f6a0..2d7b011e7 100644 --- a/crates/arroyo-api/src/pipelines.rs +++ b/crates/arroyo-api/src/pipelines.rs @@ -18,13 +18,13 @@ use arroyo_rpc::api_types::pipelines::{ Job, Pipeline, PipelinePatch, PipelinePost, PipelineRestart, PreviewPost, QueryValidationResult, StopType, ValidateQueryPost, }; -use arroyo_rpc::api_types::udfs::{GlobalUdf, Udf}; +use arroyo_rpc::api_types::udfs::{GlobalUdf, Udf, UdfLanguage}; use arroyo_rpc::api_types::{JobCollection, PaginationQueryParams, PipelineCollection}; use arroyo_rpc::grpc::api::{ArrowProgram, ConnectorOp}; use arroyo_connectors::kafka::{KafkaConfig, KafkaTable, SchemaRegistry}; use arroyo_datastream::logical::{LogicalNode, LogicalProgram, OperatorName}; -use arroyo_df::{has_duplicate_udf_names, ArroyoSchemaProvider, CompiledSql, SqlConfig}; +use arroyo_df::{ArroyoSchemaProvider, CompiledSql, SqlConfig}; use arroyo_formats::ser::ArrowSerializer; use arroyo_rpc::formats::Format; use arroyo_rpc::grpc::rpc::compiler_grpc_client::CompilerGrpcClient; @@ -71,20 +71,23 @@ async fn compile_sql<'a>( .map(|u| u.into()) .collect::>(); - // error if there are duplicate local or duplicate global UDF names, - // but allow global UDFs to override local ones - - if has_duplicate_udf_names(global_udfs.iter().map(|u| &u.definition)) { - return Err(bad_request("Global UDFs have duplicate function names")); - } - - if has_duplicate_udf_names(local_udfs.iter().map(|u| &u.definition)) { - return Err(bad_request("Local UDFs have duplicate function names")); - } - for udf in global_udfs { - if let Err(e) = schema_provider.add_rust_udf(&udf.definition, &udf.dylib_url) { - warn!("Invalid global UDF {}: {}", udf.name, e); + match udf.language { + UdfLanguage::Python => { + if let Err(e) = schema_provider.add_python_udf(&udf.definition).await { + warn!("Invalid global python UDF '{}': {}", udf.name, e); + } + } + UdfLanguage::Rust => { + let Some(dylib_url) = &udf.dylib_url else { + warn!("Rust global UDF {} is not compiled", udf.name); + continue; + }; + + if let Err(e) = schema_provider.add_rust_udf(&udf.definition, dylib_url) { + warn!("Invalid global UDF {}: {}", udf.name, e); + } + } } } @@ -92,27 +95,45 @@ async fn compile_sql<'a>( let mut compiler_service: CompilerGrpcClient<_> = compiler_service().await?; for udf in local_udfs { - let parsed = ParsedUdfFile::try_parse(&udf.definition) - .map_err(|e| bad_request(format!("invalid UDF: {e}")))?; - - let url = if !validate_only { - let res = build_udf(&mut compiler_service, &udf.definition, true).await?; - - if !res.errors.is_empty() { - return Err(bad_request(format!( - "Failed to build UDF: {}", - res.errors.join("\n") - ))); + match udf.language { + UdfLanguage::Python => { + schema_provider + .add_python_udf(&udf.definition) + .await + .map_err(|e| bad_request(format!("invalid Python UDF: {:?}", e)))?; } - - res.url.expect("valid UDF does not have a URL in response") - } else { - "".to_string() - }; - - schema_provider - .add_rust_udf(&parsed.definition, &url) - .map_err(|e| bad_request(format!("Invalid UDF {}: {}", parsed.udf.name, e)))?; + UdfLanguage::Rust => { + let parsed = ParsedUdfFile::try_parse(&udf.definition) + .map_err(|e| bad_request(format!("invalid UDF: {e}")))?; + + let url = if !validate_only { + let res = build_udf( + &mut compiler_service, + &udf.definition, + UdfLanguage::Rust, + true, + ) + .await?; + + if !res.errors.is_empty() { + return Err(bad_request(format!( + "Failed to build UDF: {}", + res.errors.join("\n") + ))); + } + + res.url.expect("valid UDF does not have a URL in response") + } else { + "".to_string() + }; + + schema_provider + .add_rust_udf(&parsed.definition, &url) + .map_err(|e| { + bad_request(format!("Invalid UDF {}: {}", parsed.udf.name, e)) + })?; + } + } } } @@ -401,6 +422,8 @@ pub(crate) async fn create_pipeline_int<'a>( "job_id": job_id, "parallelism": parallelism, "has_udfs": udfs.first().map(|e| !e.definition.trim().is_empty()).unwrap_or(false), + "rust_udfs": udfs.iter().find(|e| e.language == UdfLanguage::Rust), + "python_udfs": udfs.iter().find(|e| e.language == UdfLanguage::Python), // TODO: program features "features": compiled.program.features(), }), diff --git a/crates/arroyo-api/src/udfs.rs b/crates/arroyo-api/src/udfs.rs index 2c416dd46..a3020bfcd 100644 --- a/crates/arroyo-api/src/udfs.rs +++ b/crates/arroyo-api/src/udfs.rs @@ -6,16 +6,21 @@ use crate::rest_utils::{ BearerAuth, ErrorResp, }; use crate::{compiler_service, to_micros}; -use arroyo_rpc::api_types::udfs::{GlobalUdf, UdfPost, UdfValidationResult, ValidateUdfPost}; +use arroyo_rpc::api_types::udfs::{ + GlobalUdf, UdfLanguage, UdfPost, UdfValidationResult, ValidateUdfPost, +}; use arroyo_rpc::api_types::GlobalUdfCollection; use arroyo_rpc::config::config; use arroyo_rpc::grpc::rpc::compiler_grpc_client::CompilerGrpcClient; use arroyo_rpc::grpc::rpc::{BuildUdfReq, UdfCrate}; use arroyo_rpc::public_ids::{generate_id, IdTypes}; use arroyo_udf_host::ParsedUdfFile; +use arroyo_udf_python::PythonUDF; use axum::extract::{Path, State}; use axum::Json; use axum_extra::extract::WithRejection; +use std::str::FromStr; +use std::sync::Arc; use tonic::transport::Channel; use tracing::error; @@ -37,6 +42,7 @@ impl From for GlobalUdf { updated_at: to_micros(val.updated_at), description: val.description, dylib_url: val.dylib_url, + language: UdfLanguage::from_str(&val.language).unwrap_or_default(), } } } @@ -65,7 +71,13 @@ pub async fn create_udf( // .map_err(log_and_map)?; // build udf - let build_udf_resp = build_udf(&mut compiler_service().await?, &req.definition, true).await?; + let build_udf_resp = build_udf( + &mut compiler_service().await?, + &req.definition, + req.language, + true, + ) + .await?; if !build_udf_resp.errors.is_empty() { return Err(bad_request("UDF is invalid")); @@ -74,7 +86,6 @@ pub async fn create_udf( let client = state.database.client().await?; let udf_name = build_udf_resp.name.expect("udf name not set for valid UDF"); - let udf_url = build_udf_resp.url.expect("udf URL not set for valid UDF"); // check for duplicates let pub_id = generate_id(IdTypes::Udf); @@ -85,9 +96,10 @@ pub async fn create_udf( &auth_data.user_id, &req.prefix, &udf_name, + &req.language.to_string(), &req.definition, &req.description.unwrap_or_default(), - &udf_url, + &build_udf_resp.url, ) .await .map_err(|e| map_insert_err("udf", e))?; @@ -180,56 +192,73 @@ impl From for UdfResp { pub async fn build_udf( compiler_service: &mut CompilerGrpcClient, udf_definition: &str, + language: UdfLanguage, save: bool, ) -> Result { - // use the ArroyoSchemaProvider to do some validation and to get the function name - let file = match ParsedUdfFile::try_parse(udf_definition) { - Ok(p) => p, - Err(e) => return Ok(e.into()), - }; - - let mut dependencies = file.dependencies; - let plugin_dep = if config().compiler.use_local_udf_crate { - toml::Value::Table( - [( - "path".to_string(), - toml::Value::String(LOCAL_UDF_LIB_CRATE.to_string()), - )] - .into_iter() - .collect(), - ) - } else { - toml::Value::String(PLUGIN_VERSION.to_string()) - }; - - dependencies.insert("arroyo-udf-plugin".to_string(), plugin_dep); - - let check_udfs_resp = match compiler_service - .build_udf(BuildUdfReq { - udf_crate: Some(UdfCrate { - name: file.udf.name.clone(), - definition: udf_definition.to_string(), - dependencies: dependencies.to_string(), + match language { + UdfLanguage::Python => match PythonUDF::parse(udf_definition).await { + Ok(udf) => Ok(UdfResp { + errors: vec![], + name: Some(Arc::unwrap_or_clone(udf.name)), + url: None, }), - save, - }) - .await - { - Ok(resp) => resp.into_inner(), - Err(e) => { - error!("compiler service failed to validate UDF: {}", e.message()); - return Err(internal_server_error(format!( - "Failed to validate UDF: {}", - e.message() - ))); - } - }; + Err(e) => Ok(UdfResp { + errors: vec![e.to_string()], + name: None, + url: None, + }), + }, + UdfLanguage::Rust => { + // use the arroyo-udf lib to do some validation and to get the function name + let file = match ParsedUdfFile::try_parse(udf_definition) { + Ok(p) => p, + Err(e) => return Ok(e.into()), + }; - Ok(UdfResp { - errors: check_udfs_resp.errors, - name: Some(file.udf.name), - url: check_udfs_resp.udf_path, - }) + let mut dependencies = file.dependencies; + let plugin_dep = if config().compiler.use_local_udf_crate { + toml::Value::Table( + [( + "path".to_string(), + toml::Value::String(LOCAL_UDF_LIB_CRATE.to_string()), + )] + .into_iter() + .collect(), + ) + } else { + toml::Value::String(PLUGIN_VERSION.to_string()) + }; + + dependencies.insert("arroyo-udf-plugin".to_string(), plugin_dep); + + let check_udfs_resp = match compiler_service + .build_udf(BuildUdfReq { + udf_crate: Some(UdfCrate { + name: file.udf.name.clone(), + definition: udf_definition.to_string(), + dependencies: dependencies.to_string(), + }), + save, + }) + .await + { + Ok(resp) => resp.into_inner(), + Err(e) => { + error!("compiler service failed to validate UDF: {}", e.message()); + return Err(internal_server_error(format!( + "Failed to validate UDF: {}", + e.message() + ))); + } + }; + + Ok(UdfResp { + errors: check_udfs_resp.errors, + name: Some(file.udf.name), + url: check_udfs_resp.udf_path, + }) + } + } } /// Validate UDFs @@ -245,7 +274,13 @@ pub async fn build_udf( pub async fn validate_udf( WithRejection(Json(req), _): WithRejection, ApiError>, ) -> Result, ErrorResp> { - let check_udfs_resp = build_udf(&mut compiler_service().await?, &req.definition, false).await?; + let check_udfs_resp = build_udf( + &mut compiler_service().await?, + &req.definition, + req.language, + false, + ) + .await?; Ok(Json(UdfValidationResult { udf_name: check_udfs_resp.name, diff --git a/crates/arroyo-datastream/src/logical.rs b/crates/arroyo-datastream/src/logical.rs index 555860e0e..cd3f96df8 100644 --- a/crates/arroyo-datastream/src/logical.rs +++ b/crates/arroyo-datastream/src/logical.rs @@ -5,9 +5,7 @@ use arrow_schema::DataType; use arroyo_rpc::api_types::pipelines::{PipelineEdge, PipelineGraph, PipelineNode}; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::grpc::api; -use arroyo_rpc::grpc::api::{ - ArrowDylibUdfConfig, ArrowProgram, ArrowProgramConfig, ConnectorOp, EdgeType, -}; +use arroyo_rpc::grpc::api::{ArrowProgram, ArrowProgramConfig, ConnectorOp, EdgeType}; use petgraph::dot::Dot; use petgraph::graph::DiGraph; use petgraph::prelude::EdgeRef; @@ -20,6 +18,7 @@ use std::collections::hash_map::DefaultHasher; use std::collections::{HashMap, HashSet}; use std::fmt::{Debug, Display, Formatter}; use std::hash::Hasher; +use std::sync::Arc; use strum::{Display, EnumString}; #[derive(Clone, Copy, Debug, Eq, PartialEq, EnumString, Display)] @@ -186,9 +185,18 @@ pub struct DylibUdfConfig { pub is_async: bool, } +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct PythonUdfConfig { + pub arg_types: Vec, + pub return_type: DataType, + pub name: Arc, + pub definition: Arc, +} + #[derive(Clone, Debug, Default)] pub struct ProgramConfig { pub udf_dylibs: HashMap, + pub python_udfs: HashMap, } #[derive(Clone, Debug, Default)] @@ -348,6 +356,7 @@ impl TryFrom for LogicalProgram { .program_config .unwrap_or_else(|| ArrowProgramConfig { udf_dylibs: HashMap::new(), + python_udfs: HashMap::new(), }) .into(); @@ -355,9 +364,9 @@ impl TryFrom for LogicalProgram { } } -impl From for ArrowDylibUdfConfig { +impl From for api::DylibUdfConfig { fn from(from: DylibUdfConfig) -> Self { - ArrowDylibUdfConfig { + api::DylibUdfConfig { dylib_path: from.dylib_path, arg_types: from .arg_types @@ -377,8 +386,8 @@ impl From for ArrowDylibUdfConfig { } } -impl From for DylibUdfConfig { - fn from(from: ArrowDylibUdfConfig) -> Self { +impl From for DylibUdfConfig { + fn from(from: api::DylibUdfConfig) -> Self { DylibUdfConfig { dylib_path: from.dylib_path, arg_types: from @@ -401,6 +410,50 @@ impl From for DylibUdfConfig { } } +impl From for PythonUdfConfig { + fn from(value: api::PythonUdfConfig) -> Self { + PythonUdfConfig { + arg_types: value + .arg_types + .iter() + .map(|t| { + DataType::try_from( + &ArrowType::decode(&mut t.as_slice()).expect("invalid arrow type"), + ) + .expect("invalid arrow type") + }) + .collect(), + return_type: DataType::try_from( + &ArrowType::decode(&mut value.return_type.as_slice()).unwrap(), + ) + .expect("invalid arrow type"), + name: Arc::new(value.name), + definition: Arc::new(value.definition), + } + } +} + +impl From for api::PythonUdfConfig { + fn from(from: PythonUdfConfig) -> Self { + api::PythonUdfConfig { + arg_types: from + .arg_types + .iter() + .map(|t| { + ArrowType::try_from(t) + .expect("unsupported data type") + .encode_to_vec() + }) + .collect(), + return_type: ArrowType::try_from(&from.return_type) + .expect("unsupported data type") + .encode_to_vec(), + name: (*from.name).clone(), + definition: (*from.definition).clone(), + } + } +} + impl From for ArrowProgramConfig { fn from(from: ProgramConfig) -> Self { ArrowProgramConfig { @@ -409,6 +462,11 @@ impl From for ArrowProgramConfig { .into_iter() .map(|(k, v)| (k, v.into())) .collect(), + python_udfs: from + .python_udfs + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(), } } } @@ -421,6 +479,11 @@ impl From for ProgramConfig { .into_iter() .map(|(k, v)| (k, v.into())) .collect(), + python_udfs: from + .python_udfs + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(), } } } diff --git a/crates/arroyo-operator/Cargo.toml b/crates/arroyo-operator/Cargo.toml index 293603f19..9ec5f43b1 100644 --- a/crates/arroyo-operator/Cargo.toml +++ b/crates/arroyo-operator/Cargo.toml @@ -14,6 +14,7 @@ arroyo-types = { path = "../arroyo-types" } arroyo-datastream = { path = "../arroyo-datastream" } arroyo-storage = { path = "../arroyo-storage" } arroyo-udf-host = { path = "../arroyo-udf/arroyo-udf-host" } +arroyo-udf-python = { path = "../arroyo-udf/arroyo-udf-python" } anyhow = "1.0.71" arrow = { workspace = true, features = ["ffi"] } diff --git a/crates/arroyo-operator/src/operator.rs b/crates/arroyo-operator/src/operator.rs index 179041abf..9fa432fdb 100644 --- a/crates/arroyo-operator/src/operator.rs +++ b/crates/arroyo-operator/src/operator.rs @@ -5,7 +5,7 @@ use crate::{CheckpointCounter, ControlOutcome, SourceFinishType}; use anyhow::anyhow; use arrow::array::RecordBatch; use arrow::datatypes::DataType; -use arroyo_datastream::logical::DylibUdfConfig; +use arroyo_datastream::logical::{DylibUdfConfig, PythonUdfConfig}; use arroyo_metrics::TaskCounters; use arroyo_rpc::grpc::rpc::{TableConfig, TaskCheckpointEventType}; use arroyo_rpc::{ControlMessage, ControlResp}; @@ -13,6 +13,7 @@ use arroyo_storage::StorageProvider; use arroyo_types::{ArrowMessage, CheckpointBarrier, SignalMessage, Watermark}; use arroyo_udf_host::parse::inner_type; use arroyo_udf_host::{ContainerOrLocal, LocalUdf, SyncUdfDylib, UdfDylib, UdfInterface}; +use arroyo_udf_python::PythonUDF; use async_trait::async_trait; use datafusion::common::{DataFusionError, Result as DFResult}; use datafusion::execution::FunctionRegistry; @@ -669,6 +670,14 @@ impl Registry { } } + pub async fn add_python_udf(&mut self, udf: &PythonUdfConfig) -> anyhow::Result<()> { + let udf = PythonUDF::parse(&*udf.definition).await?; + + self.udfs.insert((*udf.name).clone(), Arc::new(udf.into())); + + Ok(()) + } + pub fn get_dylib(&self, path: &str) -> Option> { self.dylibs.lock().unwrap().get(path).cloned() } diff --git a/crates/arroyo-planner/Cargo.toml b/crates/arroyo-planner/Cargo.toml index b7cf9464b..0c01debd8 100644 --- a/crates/arroyo-planner/Cargo.toml +++ b/crates/arroyo-planner/Cargo.toml @@ -13,6 +13,7 @@ arroyo-formats = { path = "../arroyo-formats" } arroyo-operator = { path = "../arroyo-operator" } arroyo-storage = { path = "../arroyo-storage" } arroyo-udf-host = { path = "../arroyo-udf/arroyo-udf-host" } +arroyo-udf-python = { path = "../arroyo-udf/arroyo-udf-python" } datafusion = { workspace = true } datafusion-proto = { workspace = true } diff --git a/crates/arroyo-planner/src/lib.rs b/crates/arroyo-planner/src/lib.rs index 442c199fc..53bf953b5 100644 --- a/crates/arroyo-planner/src/lib.rs +++ b/crates/arroyo-planner/src/lib.rs @@ -48,7 +48,7 @@ use tables::{Insert, Table}; use crate::builder::PlanToGraphVisitor; use crate::extension::sink::SinkExtension; use crate::plan::ArroyoRewriter; -use arroyo_datastream::logical::{DylibUdfConfig, ProgramConfig}; +use arroyo_datastream::logical::{DylibUdfConfig, ProgramConfig, PythonUdfConfig}; use arroyo_rpc::api_types::connections::ConnectionProfile; use datafusion::common::DataFusionError; use std::collections::HashSet; @@ -64,6 +64,7 @@ use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::TIMESTAMP_FIELD; use arroyo_udf_host::parse::{inner_type, UdfDef}; use arroyo_udf_host::ParsedUdfFile; +use arroyo_udf_python::PythonUDF; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; @@ -96,6 +97,7 @@ pub struct ArroyoSchemaProvider { pub udf_defs: HashMap, config_options: datafusion::config::ConfigOptions, pub dylib_udfs: HashMap, + pub python_udfs: HashMap, pub function_rewriters: Vec>, pub expr_planners: Vec>, } @@ -299,6 +301,38 @@ impl ArroyoSchemaProvider { Ok(parsed.udf.name) } + + pub async fn add_python_udf(&mut self, body: &str) -> anyhow::Result { + let parsed = PythonUDF::parse(body) + .await + .map_err(|e| e.context("parsing Python UDF"))?; + + let name = parsed.name.clone(); + + self.python_udfs.insert( + (*name).clone(), + PythonUdfConfig { + arg_types: parsed + .arg_types + .iter() + .map(|t| t.data_type.clone()) + .collect(), + return_type: parsed.return_type.data_type.clone(), + name: name.clone(), + definition: parsed.definition.clone(), + }, + ); + + let replaced = self + .functions + .insert((*parsed.name).clone(), Arc::new(parsed.into())); + + if replaced.is_some() { + warn!("Existing UDF '{}' is being overwritten", name); + } + + Ok((*name).clone()) + } } fn create_table(table_name: String, schema: Arc) -> Arc { @@ -590,7 +624,7 @@ pub async fn parse_and_get_arrow_program( let plan_rewrite = rewrite_plan(plan, &schema_provider)?; - debug!("Plan = {:?}", plan_rewrite); + debug!("Plan = {}", plan_rewrite.display_graphviz()); let mut metadata = SourceMetadataVisitor::new(&schema_provider); plan_rewrite.visit_with_subqueries(&mut metadata)?; @@ -646,6 +680,7 @@ pub async fn parse_and_get_arrow_program( graph, ProgramConfig { udf_dylibs: schema_provider.dylib_udfs.clone(), + python_udfs: schema_provider.python_udfs.clone(), }, ); diff --git a/crates/arroyo-rpc/proto/api.proto b/crates/arroyo-rpc/proto/api.proto index 6b03b4fd1..3f657c95c 100644 --- a/crates/arroyo-rpc/proto/api.proto +++ b/crates/arroyo-rpc/proto/api.proto @@ -83,7 +83,7 @@ enum AsyncUdfOrdering { message AsyncUdfOperator { string name = 1; - ArrowDylibUdfConfig udf = 2; + DylibUdfConfig udf = 2; repeated bytes arg_exprs = 3; repeated bytes final_exprs = 4; AsyncUdfOrdering ordering = 5; @@ -241,7 +241,7 @@ message OperatorCheckpointDetail { -message ArrowDylibUdfConfig { +message DylibUdfConfig { string dylib_path = 1; repeated bytes arg_types = 2; bytes return_type = 3; @@ -249,8 +249,16 @@ message ArrowDylibUdfConfig { bool is_async = 5; } +message PythonUdfConfig { + string name = 1; + repeated bytes arg_types = 2; + bytes return_type = 3; + string definition = 4; +} + message ArrowProgramConfig { - map udf_dylibs = 1; + map udf_dylibs = 1; + map python_udfs = 2; } // Arrow diff --git a/crates/arroyo-rpc/src/api_types/udfs.rs b/crates/arroyo-rpc/src/api_types/udfs.rs index 15960ca30..f683b432a 100644 --- a/crates/arroyo-rpc/src/api_types/udfs.rs +++ b/crates/arroyo-rpc/src/api_types/udfs.rs @@ -1,16 +1,21 @@ use serde::{Deserialize, Serialize}; +use strum_macros::{Display, EnumString}; use utoipa::ToSchema; #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] #[serde(rename_all = "camelCase")] pub struct Udf { pub definition: String, + #[serde(default)] + pub language: UdfLanguage, } #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] #[serde(rename_all = "camelCase")] pub struct ValidateUdfPost { pub definition: String, + #[serde(default)] + pub language: UdfLanguage, } #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] @@ -20,10 +25,32 @@ pub struct UdfValidationResult { pub errors: Vec, } +#[derive( + Serialize, + Deserialize, + Copy, + Clone, + Debug, + ToSchema, + Default, + Display, + EnumString, + Eq, + PartialEq, +)] +#[serde(rename_all = "camelCase")] +pub enum UdfLanguage { + Python, + #[default] + Rust, +} + #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] #[serde(rename_all = "camelCase")] pub struct UdfPost { pub prefix: String, + #[serde(default)] + pub language: UdfLanguage, pub definition: String, pub description: Option, } @@ -34,9 +61,10 @@ pub struct GlobalUdf { pub id: String, pub prefix: String, pub name: String, + pub language: UdfLanguage, pub created_at: u64, pub updated_at: u64, pub definition: String, pub description: Option, - pub dylib_url: String, + pub dylib_url: Option, } diff --git a/crates/arroyo-udf/arroyo-udf-common/src/parse.rs b/crates/arroyo-udf/arroyo-udf-common/src/parse.rs index fbbdd2755..ec9264dd6 100644 --- a/crates/arroyo-udf/arroyo-udf-common/src/parse.rs +++ b/crates/arroyo-udf/arroyo-udf-common/src/parse.rs @@ -15,6 +15,13 @@ pub struct NullableType { } impl NullableType { + pub fn new(data_type: DataType, nullable: bool) -> Self { + Self { + data_type, + nullable, + } + } + pub fn null(data_type: DataType) -> Self { Self { data_type, diff --git a/crates/arroyo-udf/arroyo-udf-python/Cargo.toml b/crates/arroyo-udf/arroyo-udf-python/Cargo.toml new file mode 100644 index 000000000..3acf4282c --- /dev/null +++ b/crates/arroyo-udf/arroyo-udf-python/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "arroyo-udf-python" +version = "0.2.0" +edition = "2021" + +[dependencies] +arroyo-udf-common = { path = "../arroyo-udf-common" } +arrow = { workspace = true, features = ["ffi"] } +datafusion = { workspace = true } +pyo3 = { version = "0.21"} +anyhow = "1" +tokio = { version = "1", features = ["full"] } +itertools = "0.13.0" \ No newline at end of file diff --git a/crates/arroyo-udf/arroyo-udf-python/python/arroyo_udf.py b/crates/arroyo-udf/arroyo-udf-python/python/arroyo_udf.py new file mode 100644 index 000000000..d997e7900 --- /dev/null +++ b/crates/arroyo-udf/arroyo-udf-python/python/arroyo_udf.py @@ -0,0 +1,13 @@ +udf_functions = [] +arrow_udf_functions = [] + +def udf(func): + udf_functions.append(func) + return func + +def arrow_udf(func): + arrow_udf_functions.append(func) + return func + +def get_udfs(): + return udf_functions \ No newline at end of file diff --git a/crates/arroyo-udf/arroyo-udf-python/src/interpreter.rs b/crates/arroyo-udf/arroyo-udf-python/src/interpreter.rs new file mode 100644 index 000000000..2f90544ab --- /dev/null +++ b/crates/arroyo-udf/arroyo-udf-python/src/interpreter.rs @@ -0,0 +1,136 @@ +// Adapted from https://github.com/risingwavelabs/arrow-udf/tree/main/arrow-udf-python +// Copyright 2024 RisingWave Labs +// Modified in 2024 by Arroyo Systems +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! High-level API for Python sub-interpreters. + +#[allow(deprecated)] +use pyo3::GILPool; +use pyo3::{ffi::*, prepare_freethreaded_python, PyErr, Python}; +use std::ffi::CStr; + +/// A Python sub-interpreter with its own GIL. +#[derive(Debug)] +pub struct SubInterpreter { + state: *mut PyThreadState, +} + +impl SubInterpreter { + /// Create a new sub-interpreter. + pub fn new() -> Result { + prepare_freethreaded_python(); + // XXX: import the `decimal` module in the main interpreter before creating sub-interpreters. + // otherwise it will cause `SIGABRT: pointer being freed was not allocated` + // when importing decimal in the second sub-interpreter. + Python::with_gil(|py| { + py.import_bound("decimal").unwrap(); + }); + + // reference: https://github.com/PyO3/pyo3/blob/9a36b5078989a7c07a5e880aea3c6da205585ee3/examples/sequential/tests/test.rs + let config = PyInterpreterConfig { + use_main_obmalloc: 0, + allow_fork: 0, + allow_exec: 0, + allow_threads: 0, + allow_daemon_threads: 0, + check_multi_interp_extensions: 1, + gil: PyInterpreterConfig_OWN_GIL, + }; + let mut state: *mut PyThreadState = std::ptr::null_mut(); + // FIXME: according to the documentation: + // - "the GIL must be held before calling this function" + // - "a current thread state must be set on entry" + // but we don't acquire the GIL here. + let status: PyStatus = unsafe { Py_NewInterpreterFromConfig(&mut state, &config) }; + if unsafe { PyStatus_IsError(status) } == 1 { + let msg = unsafe { CStr::from_ptr(status.err_msg) }; + return Err(anyhow::anyhow!( + "failed to create sub-interpreter: {}", + msg.to_string_lossy() + ) + .into()); + } + // release the GIL + unsafe { PyEval_SaveThread() }; + Ok(Self { state }) + } + + /// Run a closure in the sub-interpreter. + /// + /// Please note that if the return value contains any `Py` object (e.g. `PyErr`), + /// this object must be dropped in this sub-interpreter, otherwise it will cause + /// `SIGABRT: pointer being freed was not allocated`. + pub fn with_gil(&self, f: F) -> Result + where + F: for<'py> FnOnce(Python<'py>) -> Result, + { + // switch to the sub-interpreter and acquire GIL + unsafe { PyEval_RestoreThread(self.state) }; + + // Safety: the GIL is already held + // this pool is used to increment the internal GIL count of pyo3. + #[allow(deprecated)] + let pool = unsafe { GILPool::new() }; + let ret = f(pool.python()); + drop(pool); + + // release the GIL + unsafe { PyEval_SaveThread() }; + ret + } +} + +impl Drop for SubInterpreter { + fn drop(&mut self) { + unsafe { + // switch to the sub-interpreter + PyEval_RestoreThread(self.state); + // destroy the sub-interpreter + Py_EndInterpreter(self.state); + } + } +} + +/// The error type for Python sub-interpreters. +/// +/// This type is a wrapper around `anyhow::Error`. The special thing is that +/// when it comes from `PyErr`, only the error message is retained, and the +/// original type is discarded. This is to avoid the problem of `PyErr` being +/// dropped outside the sub-interpreter. +#[derive(Debug)] +pub struct PyError { + anyhow: anyhow::Error, +} + +/// Converting from `PyErr` only keeps the error message. +impl From for PyError { + fn from(err: PyErr) -> Self { + Self { + anyhow: anyhow::anyhow!(err.to_string()), + } + } +} + +impl From for PyError { + fn from(err: anyhow::Error) -> Self { + Self { anyhow: err } + } +} + +impl From for anyhow::Error { + fn from(err: PyError) -> Self { + err.anyhow + } +} diff --git a/crates/arroyo-udf/arroyo-udf-python/src/lib.rs b/crates/arroyo-udf/arroyo-udf-python/src/lib.rs new file mode 100644 index 000000000..486388928 --- /dev/null +++ b/crates/arroyo-udf/arroyo-udf-python/src/lib.rs @@ -0,0 +1,262 @@ +mod interpreter; +mod pyarrow; +mod threaded; + +use crate::threaded::ThreadedUdfInterpreter; +use anyhow::{anyhow, bail}; +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::DataType; +use arroyo_udf_common::parse::NullableType; +use datafusion::common::Result as DFResult; +use datafusion::error::DataFusionError; +use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature}; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyString, PyTuple}; +use pyo3::{Bound, PyAny}; +use std::any::Any; +use std::fmt::Debug; +use std::sync::mpsc::{Receiver, SyncSender}; +use std::sync::{Arc, Mutex}; + +const UDF_PY_LIB: &str = include_str!("../python/arroyo_udf.py"); + +#[derive(Debug)] +pub struct PythonUDF { + pub name: Arc, + pub task_tx: SyncSender>, + pub result_rx: Arc>>>, + pub definition: Arc, + pub signature: Arc, + pub arg_types: Arc>, + pub return_type: Arc, +} + +impl ScalarUDFImpl for PythonUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> DFResult { + Ok(self.return_type.data_type.clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> DFResult { + let size = args + .iter() + .map(|e| match e { + ColumnarValue::Array(a) => a.len(), + ColumnarValue::Scalar(_) => 1, + }) + .max() + .unwrap_or(0); + + let args = args + .iter() + .map(|e| match e { + ColumnarValue::Array(a) => a.clone(), + ColumnarValue::Scalar(s) => Arc::new(s.to_array_of_size(size).unwrap()), + }) + .collect(); + + self.task_tx.send(args).map_err(|_| { + DataFusionError::Execution("Python UDF interpreter shut down unexpectedly".to_string()) + })?; + + let result = self + .result_rx + .lock() + .unwrap() + .recv() + .map_err(|_| { + DataFusionError::Execution( + "Python UDF interpreter shut down unexpectedly".to_string(), + ) + })? + .map_err(|e| { + DataFusionError::Execution(format!("Error in Python UDF {}: {}", self.name, e)) + })?; + + Ok(ColumnarValue::Array(result)) + } +} + +fn extract_type_info(udf: &Bound) -> anyhow::Result<(Vec, NullableType)> { + let attr = udf.getattr("__annotations__")?; + let annotations: &Bound = attr.downcast().map_err(|e| { + anyhow!( + "__annotations__ object is not a dictionary: {}", + e.to_string() + ) + })?; + + // Iterate over annotations dictionary + let (ok, err): (Vec<_>, Vec<_>) = annotations + .iter() + .map(|(k, v)| { + python_type_to_arrow( + k.downcast::().unwrap().to_str().unwrap(), + &v, + false, + ) + }) + .partition(|e| e.is_ok()); + + if !err.is_empty() { + bail!( + "Could not register Python UDF: {}", + err.into_iter() + .map(|t| t.unwrap_err().to_string()) + .collect::>() + .join(", ") + ); + } + + let mut result: Vec<_> = ok.into_iter().map(|t| t.unwrap()).collect(); + + let ret = result + .pop() + .ok_or_else(|| anyhow!("No return type defined for function"))?; + + Ok((result, ret)) +} + +impl PythonUDF { + pub async fn parse(body: impl Into) -> anyhow::Result { + ThreadedUdfInterpreter::new(Arc::new(body.into())).await + } +} + +fn python_type_to_arrow( + var_name: &str, + py_type: &Bound, + nullable: bool, +) -> anyhow::Result { + let name = py_type + .getattr("__name__") + .map_err(|e| anyhow!("Could not get name of type for argument {var_name}: {e}"))? + .downcast::() + .map_err(|_| anyhow!("Argument type was not a string"))? + .to_string(); + + if name == "Optional" { + return python_type_to_arrow( + var_name, + &py_type + .getattr("__args__") + .map_err(|_| anyhow!("Optional type does not have arguments"))? + .downcast::() + .map_err(|e| anyhow!("__args__ is not a tuple: {e}"))? + .get_item(0)?, + true, + ); + } + + let data_type = match name.as_str() { + "int" => DataType::Int64, + "float" => DataType::Float64, + "str" => DataType::Utf8, + "bool" => DataType::Boolean, + "list" => bail!("lists are not yet supported"), + other => bail!("Unsupported Python type: {}", other), + }; + + Ok(NullableType::new(data_type, nullable)) +} + +#[cfg(test)] +mod test { + use crate::PythonUDF; + use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, TypeSignature}; + use std::sync::Arc; + + #[tokio::test] + async fn test() { + let udf = r#" +from arroyo_udf import udf + +@udf +def my_add(x: int, y: float) -> float: + return float(x) + y +"#; + + let udf = PythonUDF::parse(udf).await.unwrap(); + assert_eq!(udf.name.as_str(), "my_add"); + if let datafusion::logical_expr::TypeSignature::OneOf(args) = &udf.signature.type_signature + { + let ts: Vec<_> = args + .iter() + .map(|e| { + if let TypeSignature::Exact(v) = e { + v + } else { + panic!( + "expected inner typesignature sto be exact, but found {:?}", + e + ) + } + }) + .collect(); + + use arrow::datatypes::DataType::*; + + assert_eq!( + ts, + vec![ + &vec![Int8, Float32], + &vec![Int8, Float64], + &vec![Int16, Float32], + &vec![Int16, Float64], + &vec![Int32, Float32], + &vec![Int32, Float64], + &vec![Int64, Float32], + &vec![Int64, Float64], + &vec![UInt8, Float32], + &vec![UInt8, Float64], + &vec![UInt16, Float32], + &vec![UInt16, Float64], + &vec![UInt32, Float32], + &vec![UInt32, Float64], + &vec![UInt64, Float32], + &vec![UInt64, Float64] + ] + ); + } else { + panic!("Expected oneof type signature"); + } + + assert_eq!( + udf.return_type.data_type, + arrow::datatypes::DataType::Float64 + ); + assert_eq!(udf.return_type.nullable, false); + + let data = vec![ + ColumnarValue::Array(Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3]))), + ColumnarValue::Array(Arc::new(arrow::array::Float64Array::from(vec![ + 1.0, 2.0, 3.0, + ]))), + ]; + + let result = udf.invoke(&data).unwrap(); + if let ColumnarValue::Array(a) = result { + let a = a + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a.len(), 3); + assert_eq!(a.value(0), 2.0); + assert_eq!(a.value(1), 4.0); + assert_eq!(a.value(2), 6.0); + } else { + panic!("Expected array result"); + } + } +} diff --git a/crates/arroyo-udf/arroyo-udf-python/src/pyarrow.rs b/crates/arroyo-udf/arroyo-udf-python/src/pyarrow.rs new file mode 100644 index 000000000..f9b770993 --- /dev/null +++ b/crates/arroyo-udf/arroyo-udf-python/src/pyarrow.rs @@ -0,0 +1,258 @@ +// Original source from https://github.com/risingwavelabs/arrow-udf/tree/main/arrow-udf-python +// Copyright 2024 RisingWave Labs +// +// Modified in 2024 by Arroyo Systems +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Convert arrow array from/to python objects. + +use arrow::array::{array::*, builder::*}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::DataType; +use pyo3::{exceptions::PyTypeError, types::PyAnyMethods, IntoPy, PyObject, PyResult, Python}; +use std::sync::Arc; + +macro_rules! get_pyobject { + ($array_type: ty, $py:expr, $array:expr, $i:expr) => {{ + let array = $array.as_any().downcast_ref::<$array_type>().unwrap(); + array.value($i).into_py($py) + }}; +} + +macro_rules! build_array { + (NullBuilder, $py:expr, $pyobjects:expr) => {{ + let mut builder = NullBuilder::new(); + for pyobj in $pyobjects { + if pyobj.is_none($py) { + builder.append_null(); + } else { + builder.append_empty_value(); + } + } + Ok(Arc::new(builder.finish())) + }}; + // primitive types + ($builder_type: ty, $py:expr, $pyobjects:expr) => {{ + let mut builder = <$builder_type>::with_capacity($pyobjects.len()); + for pyobj in $pyobjects { + if pyobj.is_none($py) { + builder.append_null(); + } else { + builder.append_value(pyobj.extract($py)?); + } + } + Ok(Arc::new(builder.finish())) + }}; + // string and bytea + ($builder_type: ty, $elem_type: ty, $py:expr, $pyobjects:expr) => {{ + let mut builder = <$builder_type>::with_capacity($pyobjects.len(), 1024); + for pyobj in $pyobjects { + if pyobj.is_none($py) { + builder.append_null(); + } else { + builder.append_value(pyobj.extract::<$elem_type>($py)?); + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +#[allow(unused_macros)] +macro_rules! build_json_array { + ($py:expr, $pyobjects:expr) => {{ + let json_dumps = $py.eval_bound("json.dumps", None, None)?; + let mut builder = StringBuilder::with_capacity($pyobjects.len(), 1024); + for pyobj in $pyobjects { + if pyobj.is_none($py) { + builder.append_null(); + continue; + }; + let json_str = json_dumps.call1((pyobj,))?; + builder.append_value(json_str.extract::<&str>()?); + } + Ok(Arc::new(builder.finish())) + }}; +} + +pub struct Converter {} + +impl Converter { + /// Get array element as a python object. + pub fn get_pyobject(py: Python<'_>, array: &dyn Array, i: usize) -> PyResult { + if array.is_null(i) { + return Ok(py.None()); + } + Ok(match array.data_type() { + DataType::Null => py.None(), + DataType::Boolean => get_pyobject!(BooleanArray, py, array, i), + DataType::Int8 => get_pyobject!(Int8Array, py, array, i), + DataType::Int16 => get_pyobject!(Int16Array, py, array, i), + DataType::Int32 => get_pyobject!(Int32Array, py, array, i), + DataType::Int64 => get_pyobject!(Int64Array, py, array, i), + DataType::UInt8 => get_pyobject!(UInt8Array, py, array, i), + DataType::UInt16 => get_pyobject!(UInt16Array, py, array, i), + DataType::UInt32 => get_pyobject!(UInt32Array, py, array, i), + DataType::UInt64 => get_pyobject!(UInt64Array, py, array, i), + DataType::Float32 => get_pyobject!(Float32Array, py, array, i), + DataType::Float64 => get_pyobject!(Float64Array, py, array, i), + // DataType::Utf8 => match field.metadata().get("ARROW:extension:name") { + // Some(x) if x == "arroyo.json" => { + // let array = array.as_any().downcast_ref::().unwrap(); + // let string = array.value(i); + // // XXX: it is slow to call eval every time + // let json_loads = py.eval_bound("json.loads", None, None)?; + // json_loads.call1((string,))?.into() + // } + // _ => get_pyobject!(StringArray, py, array, i), + // }, + DataType::Utf8 => get_pyobject!(StringArray, py, array, i), + DataType::LargeUtf8 => get_pyobject!(LargeStringArray, py, array, i), + DataType::Binary => get_pyobject!(BinaryArray, py, array, i), + DataType::LargeBinary => get_pyobject!(LargeBinaryArray, py, array, i), + DataType::List(_) => { + let array = array.as_any().downcast_ref::().unwrap(); + let list = array.value(i); + let mut values = Vec::with_capacity(list.len()); + for j in 0..list.len() { + values.push(Self::get_pyobject(py, list.as_ref(), j)?); + } + values.into_py(py) + } + DataType::Struct(fields) => { + let array = array.as_any().downcast_ref::().unwrap(); + let object = py.eval_bound("Struct()", None, None)?; + for (j, field) in fields.iter().enumerate() { + let value = Self::get_pyobject(py, array.column(j).as_ref(), i)?; + object.setattr(field.name().as_str(), value)?; + } + object.into() + } + other => { + return Err(PyTypeError::new_err(format!( + "Unimplemented datatype {}", + other + ))) + } + }) + } + + /// Build arrow array from python objects. + pub fn build_array( + data_type: &DataType, + py: Python<'_>, + values: &[PyObject], + ) -> PyResult { + match data_type { + DataType::Null => build_array!(NullBuilder, py, values), + DataType::Boolean => build_array!(BooleanBuilder, py, values), + DataType::Int8 => build_array!(Int8Builder, py, values), + DataType::Int16 => build_array!(Int16Builder, py, values), + DataType::Int32 => build_array!(Int32Builder, py, values), + DataType::Int64 => build_array!(Int64Builder, py, values), + DataType::UInt8 => build_array!(UInt8Builder, py, values), + DataType::UInt16 => build_array!(UInt16Builder, py, values), + DataType::UInt32 => build_array!(UInt32Builder, py, values), + DataType::UInt64 => build_array!(UInt64Builder, py, values), + DataType::Float32 => build_array!(Float32Builder, py, values), + DataType::Float64 => build_array!(Float64Builder, py, values), + DataType::Utf8 => build_array!(StringBuilder, &str, py, values), + // DataType::Utf8 => match field.metadata().get("ARROW:extension:name") { + // Some(x) if x == "arroyo.json" => { + // build_json_array!(py, values) + // } + // _ => build_array!(StringBuilder, &str, py, values), + // }, + DataType::LargeUtf8 => build_array!(LargeStringBuilder, &str, py, values), + DataType::Binary => build_array!(BinaryBuilder, &[u8], py, values), + DataType::LargeBinary => build_array!(LargeBinaryBuilder, &[u8], py, values), + // list + DataType::List(inner) => { + // flatten lists + let mut flatten_values = vec![]; + let mut offsets = Vec::::with_capacity(values.len() + 1); + offsets.push(0); + for val in values { + if !val.is_none(py) { + let array = val.bind(py); + flatten_values.reserve(array.len()?); + for elem in array.iter()? { + flatten_values.push(elem?.into()); + } + } + offsets.push(flatten_values.len() as i32); + } + let values_array = Self::build_array(inner.data_type(), py, &flatten_values)?; + let nulls = values.iter().map(|v| !v.is_none(py)).collect(); + Ok(Arc::new(ListArray::new( + inner.clone(), + OffsetBuffer::new(offsets.into()), + values_array, + Some(nulls), + ))) + } + // large list + DataType::LargeList(inner) => { + // flatten lists + let mut flatten_values = vec![]; + let mut offsets = Vec::::with_capacity(values.len() + 1); + offsets.push(0); + for val in values { + if !val.is_none(py) { + let array = val.bind(py); + flatten_values.reserve(array.len()?); + for elem in array.iter()? { + flatten_values.push(elem?.into()); + } + } + offsets.push(flatten_values.len() as i64); + } + let values_array = Self::build_array(inner.data_type(), py, &flatten_values)?; + let nulls = values.iter().map(|v| !v.is_none(py)).collect(); + Ok(Arc::new(LargeListArray::new( + inner.clone(), + OffsetBuffer::new(offsets.into()), + values_array, + Some(nulls), + ))) + } + DataType::Struct(fields) => { + let mut arrays = Vec::with_capacity(fields.len()); + for field in fields { + let mut field_values = Vec::with_capacity(values.len()); + for val in values { + let v = if val.is_none(py) { + py.None() + } else if let Ok(value) = val.getattr(py, field.name().as_str()) { + value + } else { + val.bind(py).get_item(field.name().as_str())?.into() + }; + field_values.push(v); + } + arrays.push(Self::build_array(field.data_type(), py, &field_values)?); + } + let nulls = values.iter().map(|v| !v.is_none(py)).collect(); + Ok(Arc::new(StructArray::new( + fields.clone(), + arrays, + Some(nulls), + ))) + } + other => Err(PyTypeError::new_err(format!( + "Unimplemented datatype {}", + other + ))), + } + } +} diff --git a/crates/arroyo-udf/arroyo-udf-python/src/threaded.rs b/crates/arroyo-udf/arroyo-udf-python/src/threaded.rs new file mode 100644 index 000000000..70a8a2c9c --- /dev/null +++ b/crates/arroyo-udf/arroyo-udf-python/src/threaded.rs @@ -0,0 +1,191 @@ +use crate::interpreter::SubInterpreter; +use crate::pyarrow::Converter; +use crate::{extract_type_info, PythonUDF, UDF_PY_LIB}; +use anyhow::anyhow; +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::DataType; +use arroyo_udf_common::parse::NullableType; +use datafusion::logical_expr::{Signature, TypeSignature, Volatility}; +use itertools::Itertools; +use pyo3::prelude::*; +use pyo3::types::{PyFunction, PyList, PyString, PyTuple}; +use std::sync::{Arc, Mutex}; +use std::thread; + +pub struct ThreadedUdfInterpreter {} + +impl ThreadedUdfInterpreter { + pub async fn new(body: Arc) -> anyhow::Result { + let (task_tx, task_rx) = std::sync::mpsc::sync_channel(0); + let (result_tx, result_rx) = std::sync::mpsc::sync_channel(0); + let (parse_tx, parse_rx) = std::sync::mpsc::sync_channel(0); + + thread::spawn({ + let body = body.clone(); + move || { + let interpreter = SubInterpreter::new().unwrap(); + let (name, arg_types, ret) = match Self::parse(&interpreter, &*body) { + Ok(p) => p, + Err(e) => { + parse_tx.send(Err(anyhow!("{}", e.to_string()))).unwrap(); + return; + } + }; + + parse_tx + .send(Ok((name.clone(), arg_types.clone(), ret.clone()))) + .unwrap(); + + loop { + match task_rx.recv() { + Ok(args) => { + result_tx + .send(Self::execute( + &interpreter, + &name, + &arg_types, + args, + &ret.data_type, + )) + .expect("python result queue closed"); + } + Err(_) => { + break; + } + } + } + } + }); + + let (name, arg_types, return_type) = parse_rx.recv()??; + + let type_signature = Self::get_typesignature(&arg_types); + + Ok(PythonUDF { + name, + task_tx, + result_rx: Arc::new(Mutex::new(result_rx)), + definition: body, + signature: Arc::new(Signature { + type_signature, + volatility: Volatility::Volatile, + }), + arg_types, + return_type, + }) + } + + fn get_typesignature(args: &[NullableType]) -> TypeSignature { + let input = args + .iter() + .map(|arg| Self::get_alternatives(&arg.data_type)) + .multi_cartesian_product() + .map(TypeSignature::Exact) + .collect(); + + TypeSignature::OneOf(input) + } + + fn get_alternatives(dt: &DataType) -> Vec { + match dt { + DataType::Int64 => vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + ], + DataType::Float64 => vec![DataType::Float32, DataType::Float64], + _ => vec![dt.clone()], + } + } + + fn execute( + interpreter: &SubInterpreter, + name: &str, + arg_types: &[NullableType], + args: Vec, + ret_type: &DataType, + ) -> anyhow::Result { + interpreter + .with_gil(|py| { + let function = py.eval_bound(name, None, None).unwrap(); + let function = function.downcast::().unwrap(); + + let size = args.get(0).map(|e| e.len()).unwrap_or(0); + + let results: anyhow::Result> = (0..size) + .map(|i| { + let args: anyhow::Result> = args + .iter() + .map(|a| { + Converter::get_pyobject(py, &*a, i).map_err(|e| { + anyhow!("Could not convert datatype to python: {}", e) + }) + }) + .collect(); + + let mut args = args?; + + // we don't call the UDF on null arguments unless it's declared with an + // optional type + if args + .iter() + .zip(arg_types) + .any(|(arg, t)| arg.is_none(py) && !t.nullable) + { + Ok(py.None()) + } else { + let args = PyTuple::new_bound(py, args.drain(..)); + + function + .call1(args) + .map_err(|e| { + anyhow!("failed while calling Python UDF '{}': {}", &name, e) + }) + .map(|r| r.into()) + } + }) + .collect(); + + Converter::build_array(&ret_type, py, &results?).map_err(|e| { + anyhow!( + "could not convert results from Python UDF '{}' to arrow: {}", + name, + e + ) + .into() + }) + }) + .map_err(|e| e.into()) + } + + fn parse( + interpreter: &SubInterpreter, + body: &str, + ) -> anyhow::Result<(Arc, Arc>, Arc)> { + interpreter.with_gil(|py| { + let lib = PyModule::from_code_bound(py, UDF_PY_LIB, "arroyo_udf", "arroyo_udf")?; + + py.run_bound(&body, None, None)?; + + let udfs = lib.call_method0( "get_udfs")?; + let udfs: &Bound = udfs.downcast().unwrap(); + + match udfs.len() { + 0 => Err(anyhow!("The supplied code does not contain a UDF (UDF functions must be annotated with @udf)").into()), + 1 => { + let udf = udfs.get_item(0)?; + let name = udf.getattr("__name__")?.downcast::().unwrap() + .to_string(); + let (args, ret) = extract_type_info(&udfs.get_item(0).unwrap())?; + Ok((Arc::new(name), Arc::new(args), Arc::new(ret))) + } + _ => Err(anyhow!("More than one function was annotated with @udf, which is not supported").into()), + } + }).map_err(|e| e.into()) + } +} diff --git a/crates/arroyo-worker/src/lib.rs b/crates/arroyo-worker/src/lib.rs index 56efe7073..f3bc01b9f 100644 --- a/crates/arroyo-worker/src/lib.rs +++ b/crates/arroyo-worker/src/lib.rs @@ -425,6 +425,16 @@ impl WorkerGrpc for WorkerServer { } } + for (udf_name, python_udf) in &logical.program_config.python_udfs { + info!("Loading Python UDF {}", udf_name); + registry.add_python_udf(python_udf).await.map_err(|e| { + Status::failed_precondition( + e.context(format!("loading Python UDF {udf_name}")) + .to_string(), + ) + })?; + } + let (engine, control_rx) = { let network = { self.network.lock().unwrap().take().unwrap() }; diff --git a/crates/arroyo/src/main.rs b/crates/arroyo/src/main.rs index 6486393b4..ed35ef558 100644 --- a/crates/arroyo/src/main.rs +++ b/crates/arroyo/src/main.rs @@ -240,6 +240,16 @@ fn sqlite_connection() -> rusqlite::Connection { [&uuid], ) .expect("Unable to write to sqlite database"); + } else { + // migrate database + if let Err(e) = sqlite_migrations::migrations::runner().run(&mut conn) { + error!("Unable to migrate database to latest schema: {e}"); + error!( + "To continue, delete or move the existing database at '{}'", + path.to_string_lossy() + ); + exit(1); + } } let mut statement = conn.prepare("select id from cluster_info").unwrap(); diff --git a/webui/src/gen/api-types.ts b/webui/src/gen/api-types.ts index 1795ccc8e..114cc2622 100644 --- a/webui/src/gen/api-types.ts +++ b/webui/src/gen/api-types.ts @@ -274,6 +274,7 @@ export interface components { description?: string | null; dylibUrl: string; id: string; + language: components["schemas"]["UdfLanguage"]; name: string; prefix: string; /** Format: int64 */ @@ -501,10 +502,14 @@ export interface components { TimestampFormat: "rfc3339" | "unix_millis"; Udf: { definition: string; + language?: components["schemas"]["UdfLanguage"]; }; + /** @enum {string} */ + UdfLanguage: "python" | "rust"; UdfPost: { definition: string; description?: string | null; + language?: components["schemas"]["UdfLanguage"]; prefix: string; }; UdfValidationResult: { @@ -517,6 +522,7 @@ export interface components { }; ValidateUdfPost: { definition: string; + language?: components["schemas"]["UdfLanguage"]; }; }; responses: never; diff --git a/webui/src/lib/data_fetching.ts b/webui/src/lib/data_fetching.ts index 64a4f68d8..ba8a5fa35 100644 --- a/webui/src/lib/data_fetching.ts +++ b/webui/src/lib/data_fetching.ts @@ -129,8 +129,8 @@ const queryValidationKey = (query?: string, localUdfs?: LocalUdf[]) => { return query != undefined ? { key: 'PipelineGraph', query, localUdfs } : null; }; -const udfValidationKey = (definition: string) => { - return { key: 'UdfValidation', definition }; +const udfValidationKey = (definition: string, language: 'python' | 'rust') => { + return { key: 'UdfValidation', definition, language }; }; const pipelineKey = (pipelineId?: string) => { @@ -377,7 +377,7 @@ const queryValidationFetcher = () => { let udfs: PipelineLocalUdf[] = []; if (params.localUdfs) { udfs = params.localUdfs.map(udf => { - return { definition: udf.definition }; + return { definition: udf.definition, language: udf.language }; }); } @@ -407,12 +407,13 @@ export const useQueryValidation = (query?: string, localUdfs?: LocalUdf[]) => { const udfValidationFetcher = () => { const controller = useRef(); - return async (params: { key: string; definition: string }) => { + return async (params: { key: string; definition: string; language: 'python' | 'rust' }) => { controller.current?.abort(); controller.current = new AbortController(); const { data, error } = await post('/v1/udfs/validate', { body: { definition: params.definition, + language: params.language, }, signal: controller.current?.signal, }); @@ -422,10 +423,11 @@ const udfValidationFetcher = () => { export const useUdfValidation = ( onSuccess: (data: UdfValidationResult, key: any, config: any) => void, - definition: string + definition: string, + language: 'rust' | 'python' ) => { const { data, error, isLoading } = useSWR( - udfValidationKey(definition), + udfValidationKey(definition, language), udfValidationFetcher(), { revalidateOnFocus: false, revalidateIfStale: false, shouldRetryOnError: false, onSuccess } ); @@ -673,11 +675,17 @@ export const useGlobalUdfs = () => { udfsFetcher ); - const createGlobalUdf = async (prefix: string, definition: string, description: string) => { + const createGlobalUdf = async ( + prefix: string, + definition: string, + language: 'python' | 'rust', + description: string + ) => { const { data, error } = await post('/v1/udfs', { body: { prefix, definition, + language, description, }, }); diff --git a/webui/src/routes/pipelines/CreatePipeline.tsx b/webui/src/routes/pipelines/CreatePipeline.tsx index 3063eee38..58effea19 100644 --- a/webui/src/routes/pipelines/CreatePipeline.tsx +++ b/webui/src/routes/pipelines/CreatePipeline.tsx @@ -151,6 +151,7 @@ export function CreatePipeline() { id: name, name, // this gets updated after validation definition: u.definition, + language: u.language!, open: false, errors: [], }; @@ -162,6 +163,7 @@ export function CreatePipeline() { const { data: udfValidation } = await post('/v1/udfs/validate', { body: { definition: udf.definition, + language: udf.language, }, }); @@ -220,6 +222,7 @@ export function CreatePipeline() { const { data: udfsValiation, error: udfsValiationError } = await post('/v1/udfs/validate', { body: { definition: udf.definition, + language: udf.language, }, }); @@ -246,6 +249,7 @@ export function CreatePipeline() { const queryValid = async () => { const udfs: PipelineLocalUdf[] = localUdfs.map(u => ({ definition: u.definition, + language: u.language, })); const { data: queryValidation } = await post('/v1/pipelines/validate_query', { body: { @@ -295,6 +299,7 @@ export function CreatePipeline() { const udfs: PipelineLocalUdf[] = localUdfs.map(u => ({ definition: u.definition, + language: u.language, })); const { data: newPipeline, error } = await post('/v1/pipelines/preview', { @@ -331,6 +336,7 @@ export function CreatePipeline() { console.log('starting'); const udfs: PipelineLocalUdf[] = localUdfs.map(u => ({ definition: u.definition, + language: u.language, })); const { data, error } = await post('/v1/pipelines', { diff --git a/webui/src/routes/udfs/GlobalizeModal.tsx b/webui/src/routes/udfs/GlobalizeModal.tsx index 880d1725a..03ae7778d 100644 --- a/webui/src/routes/udfs/GlobalizeModal.tsx +++ b/webui/src/routes/udfs/GlobalizeModal.tsx @@ -45,7 +45,7 @@ const GlobalizeModal: React.FC = ({ isOpen, onClose, udf }) } const share = async () => { - const { error } = await createGlobalUdf(cleanPrefix, udf.definition, description); + const { error } = await createGlobalUdf(cleanPrefix, udf.definition, udf.language, description); if (error) { setCreateError(formatError(error)); diff --git a/webui/src/routes/udfs/UdfEditTab.tsx b/webui/src/routes/udfs/UdfEditTab.tsx index a220699be..e166410c7 100644 --- a/webui/src/routes/udfs/UdfEditTab.tsx +++ b/webui/src/routes/udfs/UdfEditTab.tsx @@ -69,6 +69,7 @@ const UdfEditTab: React.FC = ({ udf }) => { const { data: udfValiation } = await post('/v1/udfs/validate', { body: { definition: udf.definition, + language: udf.language, }, }); setLoading(false); diff --git a/webui/src/routes/udfs/UdfEditor.tsx b/webui/src/routes/udfs/UdfEditor.tsx index 18bc31f52..526360e67 100644 --- a/webui/src/routes/udfs/UdfEditor.tsx +++ b/webui/src/routes/udfs/UdfEditor.tsx @@ -27,7 +27,7 @@ const UdfEditor: React.FC = ({ udf }) => { } }; - useUdfValidation(updateName, definitionToCheck); + useUdfValidation(updateName, definitionToCheck, udf.language); return ( = ({ udf }) => { setLocalDefinition(s); debounceSetCheck(s); }} - language="rust" + language={udf.language} /> ); }; diff --git a/webui/src/routes/udfs/UdfLabel.tsx b/webui/src/routes/udfs/UdfLabel.tsx index 48a9df9eb..82f9fb54e 100644 --- a/webui/src/routes/udfs/UdfLabel.tsx +++ b/webui/src/routes/udfs/UdfLabel.tsx @@ -1,6 +1,6 @@ import React, { useContext } from 'react'; import { Text, Flex, Icon } from '@chakra-ui/react'; -import { DiRust } from 'react-icons/di'; +import { DiPython, DiRust } from 'react-icons/di'; import { LocalUdf, LocalUdfsContext, nameRoot } from '../../udf_state'; import { GlobalUdf } from '../../lib/data_fetching'; import '@fontsource/ibm-plex-mono'; @@ -15,7 +15,7 @@ const UdfLabel: React.FC = ({ udf }) => { return ( openTab(udf)} cursor={'pointer'} w={'min-content'}> - + = () => { ))} - - - + + ); diff --git a/webui/src/udf_state.ts b/webui/src/udf_state.ts index 2c408173b..0ef1ca9a8 100644 --- a/webui/src/udf_state.ts +++ b/webui/src/udf_state.ts @@ -6,6 +6,7 @@ import { generate_udf_id } from './lib/util'; export interface LocalUdf { name: string; definition: string; + language: 'python' | 'rust'; id: string; open: boolean; errors?: string[]; @@ -26,7 +27,7 @@ export const LocalUdfsContext = React.createContext<{ update: { definition?: string; open?: boolean; name?: string } ) => void; isGlobal: (udf: LocalUdf | GlobalUdf) => boolean; - newUdf: () => void; + newUdf: (language: 'python' | 'rust') => void; editorTab: number; handleEditorTabChange: (index: number) => void; }>({ @@ -41,7 +42,7 @@ export const LocalUdfsContext = React.createContext<{ isOverridden: _ => false, updateLocalUdf: (_, __) => {}, isGlobal: _ => false, - newUdf: () => {}, + newUdf: (language: 'python' | 'rust') => {}, editorTab: 0, handleEditorTabChange: _ => {}, }); @@ -127,21 +128,35 @@ export const getLocalUdfsContextValue = () => { return globalUdfs != undefined && globalUdfs.some(g => g.id === udf.id); }; - const newUdf = () => { + const newUdf = (language: 'python' | 'rust') => { const id = generate_udf_id(); const functionName = `new_udf`; - const definition = - `/*\n` + - `[dependencies]\n\n` + - `*/\n\n` + - `use arroyo_udf_plugin::udf;\n\n` + - `#[udf]\n` + - `fn ${functionName}(x: i64) -> i64 {\n` + - ' // Write your function here\n' + - ' // Tip: rename the function to something descriptive\n\n' + - '}'; - - const newUdf = { name: functionName, definition, id, open: true }; + + const defaultRust = ` +/* +[dependencies] + +*/ + +use arroyo_udf_plugin::udf; + +#[udf] +fn ${functionName}(x: i64) -> i64 { + // Write your function here + // Tip: rename the function to something descriptive +}`; + + const defaultPython = ` +from arroyo_udf import udf + +@udf +def ${functionName}(x: int) -> int: + # Write your function here + # Tip: rename the function to something descriptive`; + + let definition = language == 'python' ? defaultPython : defaultRust; + + const newUdf = { name: functionName, definition, id, language, open: true }; const newLocalUdfs = [...localUdfs, newUdf]; setLocalUdfs(newLocalUdfs); setSelectedUdfId(id);