diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..d7e0ee2 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Keep Cargo's workspace output out of `target/` so `mvn clean` (which deletes +# the root `target/`) does not nuke the Rust build cache. 
+[build] +target-dir = "rust-target" diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c5db936..da8e65a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,8 +83,8 @@ jobs: path: | ~/.cargo/registry ~/.cargo/git - native/target - key: ${{ runner.os }}-cargo-${{ hashFiles('native/Cargo.lock') }} + rust-target + key: ${{ runner.os }}-cargo-${{ hashFiles('Cargo.lock') }} restore-keys: ${{ runner.os }}-cargo- - name: Build native and run tests diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4cf628f..952bf34 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -54,7 +54,7 @@ jobs: run: ./mvnw -q spotless:check - name: Check Rust formatting - run: cd native && cargo fmt --all -- --check + run: cargo fmt --all -- --check clippy: name: Clippy @@ -81,9 +81,9 @@ jobs: path: | ~/.cargo/registry ~/.cargo/git - native/target - key: ${{ runner.os }}-clippy-${{ hashFiles('native/Cargo.lock') }} + rust-target + key: ${{ runner.os }}-clippy-${{ hashFiles('Cargo.lock') }} restore-keys: ${{ runner.os }}-clippy- - name: Run clippy - run: cd native && cargo clippy --all-targets -- -D warnings + run: cargo clippy --workspace --all-targets -- -D warnings diff --git a/.gitignore b/.gitignore index 719a2a4..25c9216 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ target/ +rust-target/ *.class .idea/ .vscode/ diff --git a/native/Cargo.lock b/Cargo.lock similarity index 94% rename from native/Cargo.lock rename to Cargo.lock index 8c56280..286f96f 100644 --- a/native/Cargo.lock +++ b/Cargo.lock @@ -98,9 +98,9 @@ dependencies = [ [[package]] name = "ar_archive_writer" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eb93bbb63b9c227414f6eb3a0adfddca591a8ce1e9b60661bb08969b87e340b" +checksum = "4087686b4b0a3427190bae57a1d9a478dbb2d40c5dc1bd6e2b6d797913bdd348" dependencies = [ "object", ] @@ -119,9 +119,9 @@ checksum 
= "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "607e64bb911ee4f90483e044fe78f175989148c2892e659a2cd25429e782ec54" +checksum = "378530e55cd479eda3c14eb345310799717e6f76d0c332041e8487022166b471" dependencies = [ "arrow-arith", "arrow-array", @@ -140,9 +140,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e754319ed8a85d817fe7adf183227e0b5308b82790a737b426c1124626b48118" +checksum = "a0ab212d2c1886e802f51c5212d78ebbcbb0bec980fff9dadc1eb8d45cd0b738" dependencies = [ "arrow-array", "arrow-buffer", @@ -154,9 +154,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841321891f247aa86c6112c80d83d89cb36e0addd020fa2425085b8eb6c3f579" +checksum = "cfd33d3e92f207444098c75b42de99d329562be0cf686b307b097cc52b4e999e" dependencies = [ "ahash", "arrow-buffer", @@ -173,9 +173,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f955dfb73fae000425f49c8226d2044dab60fb7ad4af1e24f961756354d996c9" +checksum = "0c6cd424c2693bcdbc150d843dc9d4d137dd2de4782ce6df491ad11a3a0416c0" dependencies = [ "bytes", "half", @@ -185,9 +185,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca5e686972523798f76bef355145bc1ae25a84c731e650268d31ab763c701663" +checksum = "4c5aefb56a2c02e9e2b30746241058b85f8983f0fcff2ba0c6d09006e1cded7f" dependencies = [ "arrow-array", "arrow-buffer", @@ -207,9 +207,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "58.2.0" 
+version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86c276756867fc8186ec380c72c290e6e3b23a1d4fb05df6b1d62d2e62666d48" +checksum = "e94e8cf7e517657a52b91ea1263acf38c4ca62a84655d72458a3359b12ab97de" dependencies = [ "arrow-array", "arrow-cast", @@ -222,9 +222,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3b5846209775b6dc8056d77ff9a032b27043383dd5488abd0b663e265b9373" +checksum = "3c88210023a2bfee1896af366309a3028fc3bcbd6515fa29a7990ee1baa08ee0" dependencies = [ "arrow-buffer", "arrow-schema", @@ -235,9 +235,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd8907ddd8f9fbabf91ec2c85c1d81fe2874e336d2443eb36373595e28b98dd5" +checksum = "238438f0834483703d88896db6fe5a7138b2230debc31b34c0336c2996e3c64f" dependencies = [ "arrow-array", "arrow-buffer", @@ -251,9 +251,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4518c59acc501f10d7dcae397fe12b8db3d81bc7de94456f8a58f9165d6f502" +checksum = "205ca2119e6d679d5c133c6f30e68f027738d95ed948cf77677ea69c7800036b" dependencies = [ "arrow-array", "arrow-buffer", @@ -276,9 +276,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efa70d9d6b1356f1fb9f1f651b84a725b7e0abb93f188cf7d31f14abfa2f2e6f" +checksum = "1bffd8fd2579286a5d63bac898159873e5094a79009940bcb42bbfce4f19f1d0" dependencies = [ "arrow-array", "arrow-buffer", @@ -289,9 +289,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"faec88a945338192beffbbd4be0def70135422930caa244ac3cec0cd213b26b4" +checksum = "bab5994731204603c73ba69267616c50f80780774c6bb0476f1f830625115e0c" dependencies = [ "arrow-array", "arrow-buffer", @@ -302,9 +302,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18aa020f6bc8e5201dcd2d4b7f98c68f8a410ef37128263243e6ff2a47a67d4f" +checksum = "f633dbfdf39c039ada1bf9e34c694816eb71fbb7dc78f613993b7245e078a1ed" dependencies = [ "bitflags", "serde_core", @@ -313,9 +313,9 @@ dependencies = [ [[package]] name = "arrow-select" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a657ab5132e9c8ca3b24eb15a823d0ced38017fe3930ff50167466b02e2d592c" +checksum = "8cd065c54172ac787cf3f2f8d4107e0d3fdc26edba76fdf4f4cc170258942222" dependencies = [ "ahash", "arrow-array", @@ -327,9 +327,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6de2efbbd1a9f9780ceb8d1ff5d20421b35863b361e3386b4f571f1fc69fcb8" +checksum = "29dd7cda3ab9692f43a2e4acc444d760cc17b12bb6d8232ddf64e9bab7c06b42" dependencies = [ "arrow-array", "arrow-buffer", @@ -393,9 +393,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "base64" @@ -419,9 +419,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.11.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" 
+checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" [[package]] name = "blake2" @@ -457,9 +457,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.9.1" +version = "3.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" +checksum = "b2f04f6fef12d70d42a77b1433c9e0f065238479a6cefc4f5bab105e9873a3c3" dependencies = [ "bon-macros", "rustversion", @@ -467,9 +467,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.9.1" +version = "3.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" +checksum = "7d0bd4c2f75335ad98052a37efb54f428b492f64340257143b3429c8a508fa7b" dependencies = [ "darling", "ident_case", @@ -482,9 +482,9 @@ dependencies = [ [[package]] name = "brotli" -version = "8.0.2" +version = "8.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" +checksum = "8119e4516436f5708bbc474a9d395bf12f1b5395e93a92a56e647ac3388c8610" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -493,9 +493,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "5.0.0" +version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" +checksum = "5962523e1b92ce1b5e793d9169b9943eece10d39f62550bc04bb605d75b94924" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -503,9 +503,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" [[package]] name = "byteorder" @@ -530,9 +530,9 @@ 
dependencies = [ [[package]] name = "cc" -version = "1.2.62" +version = "1.2.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +checksum = "556e016178bb5662a08681bbe0f00f8e17631781a4dfc8c45e466e4b185ec27f" dependencies = [ "find-msvc-tools", "jobserver", @@ -571,9 +571,9 @@ dependencies = [ [[package]] name = "chrono" -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327" dependencies = [ "iana-time-zone", "num-traits", @@ -789,9 +789,9 @@ dependencies = [ [[package]] name = "dashmap" -version = "6.1.0" +version = "6.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +checksum = "e6361d5c062261c78a176addb82d4c821ae42bed6089de0e12603cd25de2059c" dependencies = [ "cfg-if", "crossbeam-utils", @@ -1299,6 +1299,16 @@ dependencies = [ "datafusion-physical-expr-common", ] +[[package]] +name = "datafusion-java-example-bridge" +version = "0.1.0" +dependencies = [ + "arrow", + "datafusion", + "datafusion-spark-bridge", + "tokio", +] + [[package]] name = "datafusion-jni" version = "0.1.0" @@ -1306,6 +1316,7 @@ dependencies = [ "arrow", "async-trait", "datafusion", + "datafusion-jni-common", "datafusion-proto", "datafusion-substrait", "futures", @@ -1319,6 +1330,16 @@ dependencies = [ "url", ] +[[package]] +name = "datafusion-jni-common" +version = "0.1.0" +dependencies = [ + "datafusion", + "futures", + "jni", + "tokio", +] + [[package]] name = "datafusion-macros" version = "53.1.0" @@ -1527,6 +1548,21 @@ dependencies = [ "parking_lot", ] +[[package]] +name = "datafusion-spark-bridge" +version = "0.1.0" +dependencies = [ + "arrow", + "async-trait", + "datafusion", + 
"datafusion-jni-common", + "datafusion-proto", + "futures", + "jni", + "prost", + "tokio", +] + [[package]] name = "datafusion-sql" version = "53.1.0" @@ -1579,9 +1615,9 @@ dependencies = [ [[package]] name = "displaydoc" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +checksum = "1ac70aa55017e108007fbaf5aa0f54b021c98f92ff8af59d42eda9da96e3dd4f" dependencies = [ "proc-macro2", "quote", @@ -1596,9 +1632,9 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "equivalent" @@ -1904,9 +1940,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "http" -version = "1.4.0" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +checksum = "6970f50e31d6fc17d3fa27329444bfa74e196cf62e95052a3f6fee181dba6425" dependencies = [ "bytes", "itoa", @@ -1949,9 +1985,9 @@ checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" [[package]] name = "hyper" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +checksum = "55281c53a1894c864990125767da440a4e630446785086f52523b20033b74498" dependencies = [ "atomic-waker", "bytes", @@ -2241,13 +2277,12 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.98" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" +checksum = "f2025f20d7a4fa7785846e7b63d10a76d3f1cee98ee5cb79ea59703f95e42162" dependencies = [ "cfg-if", "futures-util", - "once_cell", "wasm-bindgen", ] @@ -2316,9 +2351,9 @@ dependencies = [ [[package]] name = "libbz2-rs-sys" -version = "0.2.3" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3a6a8c165077efc8f3a971534c50ea6a1a18b329ef4a66e897a7e3a1494565f" +checksum = "34b357333733e8260735ba5894eb928c02ecc69c78715f01a8019e7fa7f2db4c" [[package]] name = "libc" @@ -2375,9 +2410,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a" [[package]] name = "lru-slab" @@ -2406,9 +2441,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.8.0" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +checksum = "6b947ae49db0d222b1dbc6b113ce7248a3fc3a6ca21b696717bfc000ba4484d8" [[package]] name = "miniz_oxide" @@ -2422,9 +2457,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +checksum = "02bd0af71c67b473010cbbc60715ee815645a4dc942899111f494b4b737d6fda" dependencies = [ "libc", "wasi", @@ -2570,9 +2605,9 @@ dependencies = [ [[package]] name = "parquet" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d7efd3052f7d6ef601085559a246bc991e9a8cc77e02753737df6322ce35f1" +checksum = "5dafa7d01085b62a47dd0c1829550a0a36710ea9c4fe358a05a85477cec8a908" 
dependencies = [ "ahash", "arrow-array", @@ -2734,9 +2769,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +checksum = "528ac67416ff8646872a3c02cad9cc4ee5dc9f9540c9b10771855c95cb2e5ae1" dependencies = [ "bytes", "prost-derive", @@ -2744,9 +2779,9 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" +checksum = "03da047801ff44bb6a4d407d4860c05fd70bb81714e6b2f3812603d5b145b042" dependencies = [ "heck", "itertools", @@ -2763,9 +2798,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +checksum = "b570b25f7617e43d59005d0990ccb79e950a423952cea19671b7a876da390adf" dependencies = [ "anyhow", "itertools", @@ -2776,9 +2811,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +checksum = "f94967dc7688f3054c7fac87473ffae4cc4c3904800e2d9f5b857246d8963b0a" dependencies = [ "prost", ] @@ -3035,9 +3070,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.12.3" +version = "1.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +checksum = "f1292b7759ae1cb9ec195452d1390a074f0cd8541ab7a5a8c31cd6db45d4a6ba" dependencies = [ "aho-corasick", "memchr", @@ -3064,9 +3099,9 @@ checksum = "cab834c73d247e67f4fae452806d17d3c7501756d98c8808d7c9c7aa7d18f973" 
[[package]] name = "regex-syntax" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +checksum = "d6f6ff9a378485b298a5286656da665ba74413d36db0979633275d2e708145d4" [[package]] name = "regress" @@ -3178,9 +3213,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +checksum = "dab5152771c58876a2146916e53e35057e1a4dfa2b9df0f0305b07f611fdea4d" dependencies = [ "openssl-probe", "rustls-pki-types", @@ -3361,9 +3396,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -3422,9 +3457,9 @@ dependencies = [ [[package]] name = "shlex" -version = "1.3.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" [[package]] name = "simd-adler32" @@ -3464,9 +3499,9 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", "windows-sys 0.61.2", @@ -3861,9 +3896,9 @@ checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" 
[[package]] name = "typenum" -version = "1.20.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" +checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" [[package]] name = "typify" @@ -3920,9 +3955,9 @@ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-segmentation" -version = "1.13.2" +version = "1.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" +checksum = "c6f5d3c3b1bf09027a88a6bc961fc00497d651009560b5463668dc81b0fa87a8" [[package]] name = "unicode-width" @@ -3968,9 +4003,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.23.1" +version = "1.23.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7" dependencies = [ "getrandom 0.4.2", "js-sys", @@ -4029,9 +4064,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.121" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +checksum = "a254a4b10c19a76f09a27640e7ffbf9bc30bf67e16a3bf28aaefa4920fe81563" dependencies = [ "cfg-if", "once_cell", @@ -4042,9 +4077,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.71" +version = "0.4.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" +checksum = "54568702fabf5d4849ce2b90fadfa64168a097eaf4b351ce9df8b687a0086aaf" dependencies = [ "js-sys", "wasm-bindgen", @@ -4052,9 +4087,9 @@ dependencies = [ 
[[package]] name = "wasm-bindgen-macro" -version = "0.2.121" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +checksum = "24a40fc75b0ec6f3746ceb10d36f53a93dcd68a93b11b6445983945d79eba0dc" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4062,9 +4097,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.121" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +checksum = "908f34bd9b9ce3d4caf07b72dfab63d61504d156856c6bd3cd87fa350cf3985b" dependencies = [ "bumpalo", "proc-macro2", @@ -4075,9 +4110,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.121" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +checksum = "7acbf7616c27b194bbb550bf77ed0c2c3e5b7fd1260a93082b95fb7f47959b92" dependencies = [ "unicode-ident", ] @@ -4131,9 +4166,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.98" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" +checksum = "6e0871acf327f283dc6da28a1696cdc64fb355ba9f935d052021fa77f35cce69" dependencies = [ "js-sys", "wasm-bindgen", @@ -4541,9 +4576,9 @@ checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" [[package]] name = "yoke" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +checksum = "709fe23a0424b6a435d82152b1bd3fdfb0833487d5fa90d05d42762a9891fef5" dependencies = [ "stable_deref_trait", "yoke-derive", @@ -4564,18 +4599,18 @@ dependencies = [ 
[[package]] name = "zerocopy" -version = "0.8.48" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.48" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930" dependencies = [ "proc-macro2", "quote", @@ -4584,9 +4619,9 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" dependencies = [ "zerofrom-derive", ] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..be906aa --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[workspace] +resolver = "2" +members = [ + "native", + "native-common", + "examples/native", + "spark/bridge", +] + +# Every dependency used by any workspace member is declared here so version +# bumps live in one place and the resolver picks a single version of each +# crate across the workspace. Members reference these via `{ workspace = true }` +# and add per-crate flags (optional, features, default-features) at the use +# site. +[workspace.dependencies] +arrow = { version = "58", features = ["ffi"] } +async-trait = "0.1" +datafusion = { version = "53.1.0" } +datafusion-proto = "53.1.0" +datafusion-substrait = "53.1.0" +futures = "0.3" +jni = "0.21" +# Pinned to the major DataFusion 53.1 pulls in transitively (0.13.x) so we +# share the same `dyn ObjectStore` vtable and don't double-link. +object_store = { version = "0.13", default-features = false } +prost = "0.14" +prost-build = "0.14" +protoc-bin-vendored = "3" +tokio = { version = "1", features = ["rt-multi-thread"] } +# Optional, cfg-gated. See `native/Cargo.toml` for the build-flag dance. +tokio-metrics = "0.5" +url = "2" diff --git a/Makefile b/Makefile index 6d9b0ae..d6bcf2c 100644 --- a/Makefile +++ b/Makefile @@ -20,14 +20,14 @@ all: native jvm native: - cd native && cargo build + cargo build --workspace -# Build the native crate with the `runtime-metrics` Cargo feature enabled. +# Build the JNI crate with the `runtime-metrics` Cargo feature enabled. # Requires `--cfg tokio_unstable` because tokio-metrics gates its API there. # Default `make native` does not pull this in; callers who need # SessionContext.runtimeStats() pick this target explicitly. native-runtime-metrics: - cd native && RUSTFLAGS="--cfg tokio_unstable" cargo build --features runtime-metrics + RUSTFLAGS="--cfg tokio_unstable" cargo build -p datafusion-jni --features runtime-metrics jvm: ./mvnw package -DskipTests @@ -39,10 +39,10 @@ test: native # `:check` form inline in .github/workflows/lint.yml. 
format: ./mvnw -q spotless:apply - cd native && cargo fmt --all + cargo fmt --all clean: - cd native && cargo clean + cargo clean ./mvnw clean tpch-data: diff --git a/core/pom.xml b/core/pom.xml index 5ddf107..1e25736 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -102,8 +102,8 @@ under the License. - + value="${maven.multiModuleProjectDirectory}/rust-target/${datafusion.native.profile}/${datafusion.lib.filename}"/> + diff --git a/core/src/main/java/org/apache/datafusion/SessionContext.java b/core/src/main/java/org/apache/datafusion/SessionContext.java index ffc58dd..27d2b16 100644 --- a/core/src/main/java/org/apache/datafusion/SessionContext.java +++ b/core/src/main/java/org/apache/datafusion/SessionContext.java @@ -113,10 +113,11 @@ public DataFrame fromProto(byte[] planBytes) { * other Substrait-emitting tool — and hand them to DataFusion without round-tripping through SQL. * *

Substrait support is gated behind the {@code substrait} Cargo feature on the native crate - * and is off by default. Rebuild the native crate with {@code cargo build - * --features substrait} (or {@code cargo build --features substrait,protoc} for hermetic builds - * that vendor {@code protoc} via {@code cmake}) to enable it. If invoked against a native binary - * built without the feature, this method throws {@link RuntimeException} pointing at the flag. + * and is off by default. Rebuild the native crate with {@code cargo build -p + * datafusion-jni --features substrait} (or {@code ... --features substrait,protoc} for hermetic + * builds that vendor {@code protoc} via {@code cmake}) to enable it. If invoked against a native + * binary built without the feature, this method throws {@link RuntimeException} pointing at the + * flag. * * @throws IllegalArgumentException if {@code planBytes} is {@code null}. * @throws IllegalStateException if this context is closed. @@ -183,7 +184,7 @@ public MemoryUsage memoryUsage() { * Rebuild with: * *

{@code
-   * RUSTFLAGS="--cfg tokio_unstable" cargo build --features runtime-metrics
+   * RUSTFLAGS="--cfg tokio_unstable" cargo build -p datafusion-jni --features runtime-metrics
    * }
* *

If invoked against a native binary built without the feature, this method throws {@link diff --git a/core/src/test/java/org/apache/datafusion/SessionContextRuntimeStatsTest.java b/core/src/test/java/org/apache/datafusion/SessionContextRuntimeStatsTest.java index 120d179..d567275 100644 --- a/core/src/test/java/org/apache/datafusion/SessionContextRuntimeStatsTest.java +++ b/core/src/test/java/org/apache/datafusion/SessionContextRuntimeStatsTest.java @@ -37,7 +37,7 @@ * #checkFeatureEnabled}. Run * *

{@code
- * (cd native && RUSTFLAGS="--cfg tokio_unstable" cargo build --features runtime-metrics)
+ * RUSTFLAGS="--cfg tokio_unstable" cargo build -p datafusion-jni --features runtime-metrics
  * }
* * before {@code ./mvnw test} to exercise this class. diff --git a/core/src/test/java/org/apache/datafusion/SessionContextSubstraitTest.java b/core/src/test/java/org/apache/datafusion/SessionContextSubstraitTest.java index 34db3b5..a2cfb0a 100644 --- a/core/src/test/java/org/apache/datafusion/SessionContextSubstraitTest.java +++ b/core/src/test/java/org/apache/datafusion/SessionContextSubstraitTest.java @@ -50,7 +50,7 @@ * *

The {@code substrait} Cargo feature is off by default in {@code native/Cargo.toml}; if the * native crate was built without it, every test here is skipped (see {@link #checkFeatureEnabled}). - * Run {@code (cd native && cargo build --features substrait)} before {@code ./mvnw test} to + * Run {@code cargo build -p datafusion-jni --features substrait} before {@code ./mvnw test} to * exercise this class. */ class SessionContextSubstraitTest { diff --git a/docs/source/contributor-guide/development.md b/docs/source/contributor-guide/development.md index 984d77c..fdb00f4 100644 --- a/docs/source/contributor-guide/development.md +++ b/docs/source/contributor-guide/development.md @@ -42,7 +42,7 @@ This builds the native Rust crate and runs the JUnit tests. The steps can be run individually: ```sh -cd native && cargo build +cargo build --workspace ./mvnw test ``` @@ -74,14 +74,25 @@ disk space. The repository is a multi-module Maven build: -- `pom.xml` — parent POM declaring the `core` and `examples` modules and - shared plugin/dependency versions. +- `Cargo.toml` — Rust workspace root declaring the four crate members + (`native`, `native-common`, `examples/native`, `spark/bridge`) and `[workspace.dependencies]` + that pin shared versions in one place. Cargo writes artifacts to + `rust-target/` (overridden in `.cargo/config.toml`) so `mvn clean` at the + repo root does not nuke the Rust build cache. +- `pom.xml` — parent POM declaring the `core`, `spark`, and `examples` + modules and shared plugin/dependency versions. - `core/` — `datafusion-java` library module (Java sources, tests, and generated protobuf classes). +- `spark/` — `datafusion-java-spark` Spark DataSource V2 connector + (Scala + Java, pure JVM) and its `spark/bridge/` Rust SDK crate + (`datafusion-spark-bridge`: widening, scan machinery, `export_bridge!`). 
- `examples/` — `datafusion-java-examples` module containing runnable examples that depend on the library; built alongside the library so they - cannot fall out of sync with the API. -- `native/` — Rust crate (JNI + Arrow C Data Interface). + cannot fall out of sync with the API. Includes `examples/native/`, a + small `export_bridge!` cdylib used by the Spark connector demo + (`ExampleBridgeProviderFactory` + the pyspark script under + `examples/python/`). +- `native/` — `datafusion-jni` Rust crate (JNI + Arrow C Data Interface). - `proto/` — Protobuf definitions shared between Java and Rust. - `Makefile` — top-level build orchestration (`make test`, `make format`, `make tpch-data`). diff --git a/docs/source/contributor-guide/updating-datafusion-version.md b/docs/source/contributor-guide/updating-datafusion-version.md index 56d50dc..ef6cd10 100644 --- a/docs/source/contributor-guide/updating-datafusion-version.md +++ b/docs/source/contributor-guide/updating-datafusion-version.md @@ -21,7 +21,9 @@ under the License. Three things must move together when bumping DataFusion: -1. `native/Cargo.toml` — the `datafusion` crate dependency. +1. `Cargo.toml` (workspace root) — the `datafusion`, `datafusion-ffi`, + `datafusion-proto`, and `datafusion-substrait` entries in + `[workspace.dependencies]`. Members inherit from there. 2. `pom.xml` — the `` Maven property. **Must equal the Cargo version**; a mismatch means JVM-built protobuf plans won't deserialize on the native side. @@ -32,9 +34,9 @@ Three things must move together when bumping DataFusion: ## Recipe ```sh -# 1. Bump the Cargo dep -$EDITOR native/Cargo.toml # set datafusion = "" -(cd native && cargo update -p datafusion) +# 1. Bump the workspace dep +$EDITOR Cargo.toml # set datafusion = "" in [workspace.dependencies] +cargo update -p datafusion # 2. 
Bump the Maven property to match $EDITOR pom.xml # set diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..da9fec7 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,99 @@ +# DataFusion-Java examples + +Small, self-contained programs that each demonstrate one feature of the +DataFusion-Java API. Every example is a Java class with a `main` method that +builds a query against an in-process DataFusion engine and prints its result +(as tab-separated rows) to stdout. They are the fastest way to see what the +library can do and to copy a working starting point. + +## Prerequisites + +- JDK 17+ +- Maven (the repo ships `./mvnw`, no install needed) +- Rust toolchain (`cargo`) — the library calls into a native DataFusion + build, so the Rust side must be compiled once first + +## Build once + +From the repo root: + +```bash +# 1. Compile the native libraries (DataFusion + JNI glue). +cargo build --release + +# 2. Build the Java/Scala modules and install them into your local Maven repo. +./mvnw -B install -DskipTests -Drat.skip=true -Ddatafusion.native.profile=release +``` + +Step 2 must be `install`, not `package`: running an example below starts a +fresh Maven invocation that resolves `datafusion-java` from your local Maven +repository (`~/.m2/repository`), and only `install` publishes the jar there. +If you skip it you'll see +`Could not find artifact org.apache.datafusion:datafusion-java:...` — +that error means "run step 2". + +(If your local Maven repo lives somewhere non-standard, add +`-Dmaven.repo.local=/path/to/repo` to step 2 **and** to every run command.) + +## Run your first example + +```bash +./mvnw -B -pl examples exec:exec \ + -Dexec.mainClass=org.apache.datafusion.examples.SqlQueryExample +``` + +This registers a small CSV file, runs a SQL aggregation over it, and prints +the result rows. Swap `SqlQueryExample` for any class in the table below. + +> Why `exec:exec` and not `exec:java`? 
Each example runs in a fresh `java` +> process so the JVM flag the native Arrow integration needs +> (`--add-opens=java.base/java.nio=ALL-UNNAMED`) actually applies. `exec:java` +> would run inside Maven's own JVM without it. + +## The examples + +| Entry point (`-Dexec.mainClass=org.apache.datafusion.examples.<…>`) | Demonstrates | What you'll see | +| --- | --- | --- | +| `SqlQueryExample` | Register a CSV file, run a SQL aggregation | The aggregated rows printed as TSV | +| `DataFrameExample` | The DataFrame API: filter, group, sort — no SQL strings | The transformed rows | +| `ProtoPlanExample` | Build a DataFusion `LogicalPlanNode` protobuf in Java and execute it via `SessionContext.fromProto` — the wire-format path used by query frontends | The plan's result rows | +| `JdbcExample` | Pull rows from a JDBC source (in-memory H2) into Arrow, register them as a table, query them | Rows that originated in H2, queried through DataFusion | +| `AddOneExample` | Write a scalar UDF in Java and call it from SQL | Each input value, plus one | +| `NestedTypeUdfExample` | A scalar UDF whose input and output are nested Arrow types (`List`) | The transformed list column | + +## The Spark connector example + +One example is not a standalone `main`: +`ExampleBridgeProviderFactory` implements the Spark connector's +`BridgeProviderFactory` interface over a tiny in-memory table built inside +the example bridge cdylib (the `export_bridge!` crate under +[`native/`](native/)). It exists to be loaded *by Spark* — the runnable +end-to-end version is the PySpark demo under [`python/`](python/), and the +guide to building your own connector is +[`../spark/README.md`](../spark/README.md). 
+ +To build its cdylib (workspace member, buildable from anywhere in the tree): + +```bash +cargo build -p datafusion-java-example-bridge --release +``` + +Building the examples jar then bundles the cdylib inside it (under +`org/apache/datafusion/examples/<os>/<arch>/`), and the factory loads it from +there at runtime via the connector's `NativeLibraryLoader` — the same +packaging recipe a real bridge uses (see "Packaging your bridge" in +[`../spark/README.md`](../spark/README.md)). To run against an unpackaged +local build instead, pass +`-Dexample.bridge.lib.path=/abs/path/to/libdatafusion_example_bridge.{so,dylib}`. + +## Troubleshooting + +- **`Could not find artifact org.apache.datafusion:datafusion-java`** — the + parent wasn't installed to your local Maven repo. Re-run build step 2 + (`install`, not `package`). +- **`Native library not found ...`** — the Rust side wasn't built, or was + built in a different profile than Maven expects. Re-run build step 1 and + keep the cargo profile (`--release`) consistent with the Maven flag + (`-Ddatafusion.native.profile=release`). +- **`UnsatisfiedLinkError ... datafusion_example_bridge`** — only the example + bridge cdylib is missing; see "The Spark connector example" above. diff --git a/examples/native/Cargo.toml b/examples/native/Cargo.toml new file mode 100644 index 0000000..1e362cc --- /dev/null +++ b/examples/native/Cargo.toml @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +[package] +name = "datafusion-java-example-bridge" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +name = "datafusion_example_bridge" +# Built as a cdylib so the JVM loads it via NativeLibraryLoader; `rlib` keeps +# the Rust-level unit tests (options decoding, partition layout) runnable. +crate-type = ["cdylib", "rlib"] + +[dependencies] +arrow = { workspace = true } +datafusion = { workspace = true } +datafusion-spark-bridge = { path = "../../spark/bridge" } + +[dev-dependencies] +tokio = { workspace = true } diff --git a/examples/native/src/lib.rs b/examples/native/src/lib.rs new file mode 100644 index 0000000..b0b17e8 --- /dev/null +++ b/examples/native/src/lib.rs @@ -0,0 +1,279 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Example bridge cdylib: a small DataFusion `MemTable` exposed to Spark +//! through the `datafusion-spark-bridge` SDK. `export_bridge!` generates the +//! whole JNI surface for `org.apache.datafusion.examples.ExampleBridgeNative`; +//! this crate only decodes the options blob and builds the provider. +//! +//! 
The same pattern is what domain bridges (HDF5, custom Iceberg, in-house +//! formats) use to expose their TableProviders to Spark via the connector's +//! DataSource V2 plumbing. +//! +//! ## Options wire format +//! +//! The provider builder accepts an opaque `byte[]` that the JVM-side +//! `ExampleBridgeProviderFactory.encodeOptions` produces. Layout (little-endian): +//! +//! ```text +//! [u32 name_prefix_len][name_prefix UTF-8 bytes][u32 num_rows][u32 num_batches] +//! [u32 num_partitions][u8 shared_scan] <- optional trailing fields +//! ``` +//! +//! Empty/`null` bytes decode as all defaults: `name_prefix="row"`, `num_rows=4`, +//! `num_batches=1`, `num_partitions=1`, `shared_scan=false`. The trailing +//! fields are optional so blobs from older encoders keep decoding. The +//! `shared_scan` flag is consumed JVM-side (`ExampleBridgeProviderFactory.sharedScan`); +//! this decoder carries it only so one blob format serves both sides. Real +//! bridges can use the connector's default `OptionsCodec` instead (decoded via +//! `datafusion_spark_bridge::options`); this example hand-rolls the encoding +//! to show a custom wire layer. 
+ +use std::sync::Arc; + +use arrow::array::{Float64Array, Int64Array, RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; +use datafusion::catalog::TableProvider; +use datafusion::datasource::MemTable; +use datafusion_spark_bridge::{export_bridge, BridgeContext, JniResult}; + +#[derive(Debug)] +struct Options { + name_prefix: String, + num_rows: u32, + num_batches: u32, + num_partitions: u32, +} + +impl Default for Options { + fn default() -> Self { + Self { + name_prefix: "row".to_string(), + num_rows: 4, + num_batches: 1, + num_partitions: 1, + } + } +} + +fn decode_options(bytes: &[u8]) -> Result> { + if bytes.is_empty() { + return Ok(Options::default()); + } + if bytes.len() < 4 { + return Err("options blob too short for name_prefix length prefix".into()); + } + let name_len = u32::from_le_bytes(bytes[0..4].try_into().unwrap()) as usize; + let name_end = 4 + name_len; + if bytes.len() < name_end + 8 { + return Err("options blob truncated: missing name_prefix bytes or trailing ints".into()); + } + let name_prefix = std::str::from_utf8(&bytes[4..name_end]) + .map_err(|e| format!("name_prefix is not valid UTF-8: {e}"))? + .to_string(); + let num_rows = u32::from_le_bytes(bytes[name_end..name_end + 4].try_into().unwrap()); + let num_batches = u32::from_le_bytes(bytes[name_end + 4..name_end + 8].try_into().unwrap()); + if num_rows == 0 || num_batches == 0 { + return Err("num_rows and num_batches must both be > 0".into()); + } + // Optional trailing fields (older encoders omit them): num_partitions, + // then the shared_scan flag byte, which only the JVM side interprets. 
+ let num_partitions = if bytes.len() >= name_end + 12 { + u32::from_le_bytes(bytes[name_end + 8..name_end + 12].try_into().unwrap()) + } else { + 1 + }; + if num_partitions == 0 { + return Err("num_partitions must be > 0".into()); + } + Ok(Options { + name_prefix, + num_rows, + num_batches, + num_partitions, + }) +} + +/// Build the example schema + a multi-batch in-memory table sized per `opts`. +/// Row `r` in batch `b` gets `id = b * num_rows + r`, `name = "<name_prefix><id>"`, +/// `value = id * 1.5` (with `value` left null whenever `id % 4 == 3` so the demo +/// still exercises null handling). +fn build_mem_table( + opts: &Options, +) -> Result, Box> { + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + Field::new("value", DataType::Float64, true), + ])); + + let mut batches = Vec::with_capacity(opts.num_batches as usize); + for b in 0..opts.num_batches { + let mut ids = Vec::with_capacity(opts.num_rows as usize); + let mut names: Vec> = Vec::with_capacity(opts.num_rows as usize); + let mut values: Vec> = Vec::with_capacity(opts.num_rows as usize); + for r in 0..opts.num_rows { + let id = (b as i64) * (opts.num_rows as i64) + (r as i64); + ids.push(id); + names.push(Some(format!("{}{}", opts.name_prefix, id))); + values.push(if id % 4 == 3 { + None + } else { + Some(id as f64 * 1.5) + }); + } + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(ids)), + Arc::new(StringArray::from(names)), + Arc::new(Float64Array::from(values)), + ], + )?; + batches.push(batch); + } + + // Distribute the batches round-robin across `num_partitions` MemTable + // partitions. With num_partitions=1 the example stays single-partition; + // larger values give the Spark connector's shared-scan mode real + // DataFusion-native partitions to map tasks onto. Partitions beyond the + // batch count stay empty — DataFusion handles empty partitions fine. 
+ let mut partitions: Vec> = vec![Vec::new(); opts.num_partitions as usize]; + for (i, batch) in batches.into_iter().enumerate() { + partitions[i % opts.num_partitions as usize].push(batch); + } + Ok(Arc::new(MemTable::try_new(schema, partitions)?)) +} + +/// Build the example provider for one scan: decode the options blob, build +/// the `MemTable` accordingly. `partition` is unused — the example reports a +/// single partition (or relies on shared-scan mode), so there is no per-task +/// payload to interpret. +fn build_provider( + _ctx: &BridgeContext, + options: &[u8], + _partition: &[u8], +) -> JniResult> { + let opts = decode_options(options)?; + Ok(build_mem_table(&opts)?) +} + +export_bridge! { + jni_class: "org_apache_datafusion_examples_ExampleBridgeNative", + build_provider: build_provider, +} + +#[cfg(test)] +mod tests { + use super::*; + + fn encode(prefix: &str, num_rows: u32, num_batches: u32) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(&(prefix.len() as u32).to_le_bytes()); + buf.extend_from_slice(prefix.as_bytes()); + buf.extend_from_slice(&num_rows.to_le_bytes()); + buf.extend_from_slice(&num_batches.to_le_bytes()); + buf + } + + #[test] + fn empty_bytes_decodes_to_defaults() { + let o = decode_options(&[]).unwrap(); + assert_eq!(o.name_prefix, "row"); + assert_eq!(o.num_rows, 4); + assert_eq!(o.num_batches, 1); + assert_eq!(o.num_partitions, 1); + } + + #[test] + fn roundtrip_decodes_options() { + let o = decode_options(&encode("user", 5, 3)).unwrap(); + assert_eq!(o.name_prefix, "user"); + assert_eq!(o.num_rows, 5); + assert_eq!(o.num_batches, 3); + } + + #[test] + fn old_blob_without_trailing_fields_defaults_partitions_to_one() { + let o = decode_options(&encode("user", 5, 3)).unwrap(); + assert_eq!(o.num_partitions, 1); + } + + #[test] + fn trailing_fields_decode_num_partitions_and_ignore_flag_byte() { + let mut buf = encode("user", 5, 8); + buf.extend_from_slice(&4u32.to_le_bytes()); + buf.push(1); // shared_scan flag: 
JVM-side only + let o = decode_options(&buf).unwrap(); + assert_eq!(o.num_partitions, 4); + } + + #[test] + fn zero_partitions_rejected() { + let mut buf = encode("user", 5, 8); + buf.extend_from_slice(&0u32.to_le_bytes()); + buf.push(0); + assert!(decode_options(&buf).is_err()); + } + + #[test] + fn batches_distribute_round_robin_across_partitions() { + let opts = Options { + name_prefix: "u".to_string(), + num_rows: 2, + num_batches: 5, + num_partitions: 3, + }; + let table = build_mem_table(&opts).unwrap(); + // MemTable has no partition accessor; verify via scan output partitioning. + use datafusion::catalog::TableProvider; + use datafusion::prelude::SessionContext; + use tokio::runtime::Runtime; + let ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + let plan = rt + .block_on(async { table.scan(&ctx.state(), None, &[], None).await }) + .unwrap(); + assert_eq!(plan.properties().output_partitioning().partition_count(), 3); + } + + #[test] + fn build_table_has_expected_schema() { + let opts = Options { + name_prefix: "user".to_string(), + num_rows: 5, + num_batches: 3, + num_partitions: 1, + }; + let table = build_mem_table(&opts).unwrap(); + let schema = table.schema(); + assert_eq!(schema.fields().len(), 3); + assert_eq!(schema.field(0).name(), "id"); + assert_eq!(schema.field(1).name(), "name"); + assert_eq!(schema.field(2).name(), "value"); + } + + #[test] + fn rejects_zero_counts() { + let mut buf = Vec::new(); + buf.extend_from_slice(&3u32.to_le_bytes()); + buf.extend_from_slice(b"abc"); + buf.extend_from_slice(&0u32.to_le_bytes()); + buf.extend_from_slice(&1u32.to_le_bytes()); + assert!(decode_options(&buf).is_err()); + } +} diff --git a/examples/pom.xml b/examples/pom.xml index 78fcc5c..02888c3 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -37,6 +37,9 @@ under the License. true true + + debug @@ -44,6 +47,15 @@ under the License. 
org.apache.datafusion datafusion-java + + + org.apache.datafusion + datafusion-java-spark_2.13 + ${project.version} + provided + org.apache.arrow arrow-vector @@ -82,6 +94,97 @@ under the License. + + + org.apache.maven.plugins + maven-antrun-plugin + 3.1.0 + + + copy-example-bridge-cdylib + process-classes + run + + + + + + + + + + + + + + + + + + native-linux-amd64 + + unixlinuxamd64 + + + linux + x86_64 + libdatafusion_example_bridge.so + + + + native-linux-x86_64 + + unixlinuxx86_64 + + + linux + x86_64 + libdatafusion_example_bridge.so + + + + native-linux-aarch64 + + unixlinuxaarch64 + + + linux + aarch64 + libdatafusion_example_bridge.so + + + + native-mac-x86_64 + + macx86_64 + + + darwin + x86_64 + libdatafusion_example_bridge.dylib + + + + native-mac-aarch64 + + macaarch64 + + + darwin + aarch64 + libdatafusion_example_bridge.dylib + + + diff --git a/examples/python/README.md b/examples/python/README.md new file mode 100644 index 0000000..c9a335d --- /dev/null +++ b/examples/python/README.md @@ -0,0 +1,132 @@ +# PySpark end-to-end demo + +`bridge_demo.py` proves the full DataFusion → Spark path: + +``` +examples/native (export_bridge! cdylib) <-- in-memory MemTable + scan machinery + ^ byte[] options / FFI_ArrowArrayStream + | +ExampleBridgeProviderFactory <-- implements BridgeProviderFactory + | Class.forName(...) + v +datafusion-java-spark <-- DSv2 plumbing, predicate xlate + | spark.read.format("datafusion") + v +PySpark DataFrame <-- printSchema / show / filter / select +``` + +## Prerequisites + +1. **Java 17.** `JAVA_HOME` must point at a JDK 17 install. + +2. **The example bridge cdylib** built from this repo: + + ```bash + cargo build -p datafusion-java-example-bridge --release + ``` + +3. 
**Maven artifacts installed into a side-loaded local repository.** + + The script reads `arrow-c-data`, `flatbuffers-java`, and `protobuf-java` + jars from `${DATAFUSION_DEMO_M2:-/tmp/m2-datafusion}` (Spark's bundled + versions are too old, so the demo prepends our copies on + `spark.driver/executor.extraClassPath`). Tell Maven to install there: + + ```bash + mvn install -DskipTests \ + -Ddatafusion.native.profile=release \ + -Dmaven.repo.local=/tmp/m2-datafusion + ``` + + If you already use `~/.m2`, point `DATAFUSION_DEMO_M2` at it instead and + skip `-Dmaven.repo.local`. + +4. **A Scala 2.13 Spark distribution.** The PyPI `pyspark` wheel embeds + Scala 2.12 jars; the connector is compiled against 2.13, so we override + `SPARK_HOME` before importing pyspark. Download once: + + ```bash + cd /tmp + curl -L -o spark-2.13.tgz \ + https://archive.apache.org/dist/spark/spark-3.5.7/spark-3.5.7-bin-hadoop3-scala2.13.tgz + tar xzf spark-2.13.tgz + ``` + + The script defaults `SPARK_HOME` to + `/tmp/spark-3.5.7-bin-hadoop3-scala2.13`; set the env var if you put it + elsewhere. + +5. **A self-contained Python venv with `pyspark==3.5.7`** (uv keeps it + isolated from system site-packages): + + ```bash + cd examples/python + uv venv --python 3.11 .venv + uv pip install --python .venv/bin/python "pyspark==3.5.7" + cd ../.. 
+ ``` + +## Run + +```bash +examples/python/.venv/bin/python examples/python/bridge_demo.py +``` + +Expected output (illustrative layout only — the script currently sets `name_prefix=user`, `num_rows=5`, `num_batches=3` and filters on `value > 5.0`, so it prints rows `user0`…`user14`, followed by a shared-scan section): + +``` +=== schema === +root + |-- id: long (nullable = false) + |-- name: string (nullable = true) + |-- value: double (nullable = true) + +=== full scan === ++---+-----+-----+ +|id |name |value| ++---+-----+-----+ +|1 |alice|1.5 | +|2 |bob |2.5 | +|3 |NULL |3.5 | +|4 |dave |NULL | ++---+-----+-----+ + +=== filter pushdown: value > 2.0 === ++---+----+-----+ +|id |name|value| ++---+----+-----+ +|2 |bob |2.5 | +|3 |NULL|3.5 | ++---+----+-----+ + +=== projection: id, name === ++---+-----+ +|id |name | ++---+-----+ +|1 |alice| +|2 |bob | +|3 |NULL | +|4 |dave | ++---+-----+ +``` + +In this sample the filter row count drops from 4 → 2 because the predicate is pushed into the +bridge cdylib as a `LogicalExprNode` proto and applied inside DataFusion +before Arrow batches cross back to Spark. + +## Notes + +- `master("local[2]")` keeps driver + executor in one JVM so the example + cdylib loads once. In cluster mode nothing extra is needed: the bridge + cdylib travels inside the examples jar and `NativeLibraryLoader` extracts + it on every worker. +- `extraClassPath` (not `--packages` / `userClassPathFirst`) is used because + the Spark distro ships Arrow 12, flatbuffers 1.12, and protobuf 2.5, all + of which we need to override; userClassPathFirst splits Netty across two + class loaders and the `arrow-memory-netty-buffer-patch` shim breaks. +- The `datafusion` format short name resolves via the SPI file in + `spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister`. + You can also use the FQCN: `format("io.datafusion.spark.DatafusionSource")`. 
+- To swap in your own bridge, write a `BridgeProviderFactory` against your own + cdylib (mirroring `ExampleBridgeProviderFactory`) and pass its FQCN via + `option("df.factory", ...)`. 
diff --git a/examples/python/bridge_demo.py b/examples/python/bridge_demo.py new file mode 100644 index 0000000..a630224 --- /dev/null +++ b/examples/python/bridge_demo.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""End-to-end PySpark demo of a DataFusion table provider exposed as a Spark data source. + +Wires the in-memory example MemTable produced by ``examples/native`` into a +Spark DataSource V2 scan through the generic connector in ``spark/``. + +Prerequisites (run from the repo root): + + cargo build --release --workspace + mvn install -Ddatafusion.native.profile=release -DskipTests -Dmaven.repo.local=/tmp/m2-datafusion + +Run: + + python3 examples/python/bridge_demo.py +""" + +import glob +import os +import sys +from pathlib import Path + +# The PyPI ``pyspark`` wheel embeds a Scala 2.12 Spark distribution; this +# connector is compiled against Scala 2.13. Override SPARK_HOME (before the +# pyspark import so the wheel honours it) to a side-loaded 2.13 distribution. 
+_SPARK_HOME_2_13 = os.environ.get( + "SPARK_HOME", + "/tmp/spark-3.5.7-bin-hadoop3-scala2.13", +) +if not Path(_SPARK_HOME_2_13, "jars", "scala-library-2.13.8.jar").exists(): + sys.exit( + f"missing Scala 2.13 Spark distribution at {_SPARK_HOME_2_13}. 
" + "Download from https://archive.apache.org/dist/spark/spark-3.5.7/" + "spark-3.5.7-bin-hadoop3-scala2.13.tgz and extract to that path " + "(or set SPARK_HOME to your own 2.13 distro)." + ) +os.environ["SPARK_HOME"] = _SPARK_HOME_2_13 + +from pyspark.sql import SparkSession + + +REPO_ROOT = Path(__file__).resolve().parents[2] +VERSION = "0.2.0-SNAPSHOT" +ARROW_VERSION = "19.0.0" +FLATBUFFERS_VERSION = "25.2.10" +PROTOBUF_VERSION = "3.25.5" +# Local maven repository populated by ``mvn install -Dmaven.repo.local=...``. +M2_REPO = Path(os.environ.get("DATAFUSION_DEMO_M2", "/tmp/m2-datafusion")) + + +def _resolve_jar(module: str, artifact: str) -> str: + candidates = glob.glob(str(REPO_ROOT / module / "target" / f"{artifact}-{VERSION}.jar")) + if not candidates: + sys.exit( + f"missing jar for {artifact} under {module}/target/. " + f"Run 'mvn install -DskipTests' from {REPO_ROOT} first." + ) + return candidates[0] + + +def _m2_jar(group_path: str, artifact: str, version: str) -> str: + path = M2_REPO / group_path / artifact / version / f"{artifact}-{version}.jar" + if not path.exists(): + sys.exit( + f"missing dependency jar {path}. " + f"Re-run 'mvn install -DskipTests -Dmaven.repo.local={M2_REPO}'." + ) + return str(path) + + +def main() -> None: + # Spark 3.5.7 bundles Arrow 12.0.1; datafusion-java is compiled against + # Arrow 19, which needs ArrowArrayStream (added after 12) and a much newer + # flatbuffers runtime. Ship our copies on spark.jars and prepend them via + # extraClassPath so they win over the bundled jars on both driver and executor. 
+ arrow_jars = [ + _m2_jar("org/apache/arrow", "arrow-format", ARROW_VERSION), + _m2_jar("org/apache/arrow", "arrow-vector", ARROW_VERSION), + _m2_jar("org/apache/arrow", "arrow-memory-core", ARROW_VERSION), + _m2_jar("org/apache/arrow", "arrow-memory-netty", ARROW_VERSION), + _m2_jar( + "org/apache/arrow", + "arrow-memory-netty-buffer-patch", + ARROW_VERSION, + ), + _m2_jar("org/apache/arrow", "arrow-c-data", ARROW_VERSION), + _m2_jar( + "com/google/flatbuffers", "flatbuffers-java", FLATBUFFERS_VERSION + ), + # Spark ships protobuf-java 2.5.0 (sans MessageOrBuilder). The proto + # surface in core (LogicalExprNode etc.) needs 3.25.x. + _m2_jar("com/google/protobuf", "protobuf-java", PROTOBUF_VERSION), + ] + app_jars = [ + _resolve_jar("core", "datafusion-java"), + _resolve_jar("spark", "datafusion-java-spark_2.13"), + _resolve_jar("examples", "datafusion-java-examples"), + *arrow_jars, + ] + jars = ",".join(app_jars) + # Prepend the same jars onto the driver/executor JVM classpath so Arrow 19's classes + # are loaded by the system class loader — avoids the + # ``UnsafeDirectLittleEndian cannot access superclass WrappedByteBuf`` + # IllegalAccessError that ChildFirstURLClassLoader produces when the + # buffer-patch class lands in the child loader while Netty stays in the app + # loader. 
+ extra_classpath = ":".join(app_jars) + + spark = ( + SparkSession.builder.appName("datafusion-bridge-demo") + .master("local[2]") + .config("spark.jars", jars) + .config("spark.driver.extraClassPath", extra_classpath) + .config("spark.executor.extraClassPath", extra_classpath) + .config( + "spark.driver.extraJavaOptions", + "--add-opens=java.base/java.nio=ALL-UNNAMED", + ) + .config( + "spark.executor.extraJavaOptions", + "--add-opens=java.base/java.nio=ALL-UNNAMED", + ) + .getOrCreate() + ) + + # The example cdylib is bundled inside the examples jar and extracted by + # NativeLibraryLoader at first use; no working-directory or path setup is + # needed. 
(-Dexample.bridge.lib.path via extraJavaOptions overrides it for + # unpackaged local builds.) + + # `name_prefix`, `num_rows`, `num_batches` are interpreted by + # ExampleBridgeProviderFactory.encodeOptions and decoded on the Rust side + # in examples/native/src/lib.rs. They demonstrate driver-side options + # flowing through to the native MemTable build. + name_prefix = "user" + num_rows = 5 + num_batches = 3 + df = ( + spark.read.format("datafusion") + .option( + "df.factory", + "org.apache.datafusion.examples.ExampleBridgeProviderFactory", + ) + .option("name_prefix", name_prefix) + .option("num_rows", str(num_rows)) + .option("num_batches", str(num_batches)) + .load() + ) + + total_rows = num_rows * num_batches + print(f"=== options: name_prefix={name_prefix} num_rows={num_rows} num_batches={num_batches} ===") + print(f"=== expecting {total_rows} rows across {num_batches} Arrow batches ===") + + print("=== schema ===") + df.printSchema() + + print(f"=== full scan (first {total_rows} rows) ===") + df.show(n=total_rows, truncate=False) + + print("=== filter pushdown: value > 5.0 ===") + df.filter("value > 5.0").show(n=total_rows, truncate=False) + + print("=== projection: id, name ===") + df.select("id", "name").show(n=total_rows, truncate=False) + + legacy_rows = {tuple(r) for r in df.collect()} + + # --- shared-scan mode ------------------------------------------------- + # `shared_scan=true` flips ExampleBridgeProviderFactory.sharedScan: one + # provider + plan cached per executor, one Spark task per MemTable + # partition (num_partitions=4), each task streaming one DataFusion plan + # partition. Results must be identical to the legacy run above. 
+ num_partitions = 4 + shared = ( + spark.read.format("datafusion") + .option( + "df.factory", + "org.apache.datafusion.examples.ExampleBridgeProviderFactory", + ) + .option("name_prefix", name_prefix) + .option("num_rows", str(num_rows)) + .option("num_batches", str(num_batches)) + .option("num_partitions", str(num_partitions)) + .option("shared_scan", "true") + .load() + ) + + print(f"=== shared-scan mode: num_partitions={num_partitions} ===") + shared_partitions = shared.rdd.getNumPartitions() + print(f"=== shared-scan Spark partitions: {shared_partitions} ===") + assert shared_partitions == num_partitions, ( + f"expected {num_partitions} Spark partitions in shared-scan mode, " + f"got {shared_partitions}" + ) + + shared.show(n=total_rows, truncate=False) + shared_rows = {tuple(r) for r in shared.collect()} + assert shared_rows == legacy_rows, ( + "shared-scan rows diverge from legacy mode: " + f"only-legacy={legacy_rows - shared_rows} only-shared={shared_rows - legacy_rows}" + ) + print(f"=== shared-scan returned the same {len(shared_rows)} rows as legacy mode ===") + + print("=== shared-scan filter pushdown: value > 5.0 ===") + shared.filter("value > 5.0").show(n=total_rows, truncate=False) + + # Note on cache scope: the executor cache is keyed by a per-query scanId, + # so sharing happens across the TASKS of one query (4 tasks above -> one + # provider build per executor JVM, in the bridge's native build_provider), + # not across separate actions. Each new + # action plans a new scan with a fresh scanId; its entry simply joins the + # cache until the idle TTL evicts it. 
+ count_again = shared.count() + assert count_again == total_rows, f"expected {total_rows} rows, got {count_again}" + print("=== shared-scan count() as a separate action also succeeded ===") + + spark.stop() + + +if __name__ == "__main__": + main() diff --git a/examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeNative.java b/examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeNative.java new file mode 100644 index 0000000..dff42ee --- /dev/null +++ b/examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeNative.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.examples; + +import io.datafusion.spark.NativeLibraryLoader; + +/** + * JNI surface generated on the Rust side by {@code export_bridge!} in {@code + * examples/native/src/lib.rs} with {@code jni_class = + * "org_apache_datafusion_examples_ExampleBridgeNative"} — the mangled binary name of THIS class. + * Renaming or moving this class requires regenerating the Rust macro invocation to match. + * + *

<p>The cdylib is bundled inside this jar at {@code org/apache/datafusion/examples/<os>/<arch>/}
+ * (see the antrun execution in {@code examples/pom.xml}). For local hacking against an unpackaged
+ * build, {@code -Dexample.bridge.lib.path=/abs/path/to/libdatafusion_example_bridge.dylib} bypasses
+ * the bundled copy.
+ */
+final class ExampleBridgeNative {
+
+  // NOTE(review): the two placeholder segments in the bundled-path Javadoc above are assumed to
+  // be <os>/<arch>; confirm against the resource layout NativeLibraryLoader actually resolves.
+
+  // Static-only holder for the native entry points; never instantiated.
+  private ExampleBridgeNative() {}
+
+  // Loads the cdylib when this class is initialized. A non-empty example.bridge.lib.path system
+  // property is passed straight to System.load; otherwise NativeLibraryLoader.load is called with
+  // this class, the org/apache/datafusion/examples path and the datafusion_example_bridge name.
+  static {
+    String explicit = System.getProperty("example.bridge.lib.path");
+    if (explicit != null && !explicit.isEmpty()) {
+      System.load(explicit);
+    } else {
+      NativeLibraryLoader.load(
+          ExampleBridgeNative.class, "org/apache/datafusion/examples", "datafusion_example_bridge");
+    }
+  }
+
+  // Native entry points. ExampleScanBackend forwards to each of these one-to-one, so names,
+  // parameter order and types here must stay in step with both that class and the native side.
+
+  static native byte[] providerSchemaIpc(byte[] options, byte[] partition);
+
+  static native long createScan(
+      byte[] options,
+      byte[] partition,
+      int targetPartitions,
+      int batchSize,
+      String[] optionKeys,
+      String[] optionValues,
+      String[] projectionColumns,
+      byte[][] filterProtos);
+
+  static native int partitionCount(long scanHandle);
+
+  static native void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr);
+
+  static native void executeStream(long scanHandle, long ffiStreamAddr);
+
+  static native void closeScan(long scanHandle);
+}
diff --git a/examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeProviderFactory.java b/examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeProviderFactory.java new file mode 100644 index 0000000..5b4c921 --- /dev/null +++ b/examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeProviderFactory.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.examples; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.Map; + +import io.datafusion.spark.BridgeProviderFactory; +import io.datafusion.spark.PartitionInfo; +import io.datafusion.spark.ScanBackend; + +/** + * Minimal {@link BridgeProviderFactory} that exposes the example {@code MemTable} built inside the + * example bridge cdylib (see {@code examples/native}) as a Spark DataSource V2 source. + * + *

<p>Wire it into PySpark with:
+ *
+ * <pre>{@code
+ * df = (spark.read.format("datafusion")
+ *         .option("df.factory", "org.apache.datafusion.examples.ExampleBridgeProviderFactory")
+ *         .option("name_prefix", "user")
+ *         .option("num_rows", "5")
+ *         .option("num_batches", "3")
+ *         .load())
+ * }</pre>
+ *
+ * <p>Supported options (all optional):
+ *
+ * <ul>
+ *   <li>{@code name_prefix} — prefix string used for generated {@code name} column values. Default
+ *       {@code "row"}.
+ *   <li>{@code num_rows} — rows per batch. Default {@code 4}.
+ *   <li>{@code num_batches} — number of in-memory {@code RecordBatch}es composing the table.
+ *       Default {@code 1}.
+ *   <li>{@code num_partitions} — number of DataFusion-native MemTable partitions the batches are
+ *       distributed across (round-robin). Default {@code 1}. Mostly interesting together with
+ *       {@code shared_scan}.
+ *   <li>{@code shared_scan} — {@code true} opts into the connector's shared-scan mode: one cached
+ *       provider + plan per executor, one Spark task per MemTable partition. Default {@code false}
+ *       (single task via {@link #listPartitions(byte[])}).
+ * </ul>
+ *
+ * <p>Real bridges (HDF5, custom Iceberg, in-house formats) use a protobuf schema for {@code
+ * optionsBytes}; this example uses a hand-rolled length-prefixed binary format to keep the
+ * wire layer obvious:
+ *
+ * <pre>
+ *   [u32 LE name_prefix_len][name_prefix UTF-8 bytes][u32 LE num_rows][u32 LE num_batches]
+ *       [u32 LE num_partitions][u8 shared_scan]
+ * </pre>
+ *
+ * <p>An empty {@code byte[]} is also accepted by the native side and decoded as all defaults; the
+ * two trailing fields are optional so older blobs keep decoding.
+ *
+ *

In the default mode a single partition (id {@code "p0"}, empty {@code partitionBytes}, no + * preferred host) is reported so Spark spawns one task; the executor hands the options bytes to + * {@code ExampleBridgeNative.createScan}, which builds the {@code MemTable} provider in process and + * streams the resulting Arrow record batches back into the Spark scan. + */ +public final class ExampleBridgeProviderFactory implements BridgeProviderFactory { + + static final String OPT_NAME_PREFIX = "name_prefix"; + static final String OPT_NUM_ROWS = "num_rows"; + static final String OPT_NUM_BATCHES = "num_batches"; + static final String OPT_NUM_PARTITIONS = "num_partitions"; + static final String OPT_SHARED_SCAN = "shared_scan"; + + static final String DEFAULT_NAME_PREFIX = "row"; + static final int DEFAULT_NUM_ROWS = 4; + static final int DEFAULT_NUM_BATCHES = 1; + static final int DEFAULT_NUM_PARTITIONS = 1; + + public ExampleBridgeProviderFactory() {} + + @Override + public byte[] encodeOptions(Map sparkOptions) { + String namePrefix = sparkOptions.getOrDefault(OPT_NAME_PREFIX, DEFAULT_NAME_PREFIX); + int numRows = parsePositiveInt(sparkOptions, OPT_NUM_ROWS, DEFAULT_NUM_ROWS); + int numBatches = parsePositiveInt(sparkOptions, OPT_NUM_BATCHES, DEFAULT_NUM_BATCHES); + int numPartitions = parsePositiveInt(sparkOptions, OPT_NUM_PARTITIONS, DEFAULT_NUM_PARTITIONS); + boolean sharedScan = Boolean.parseBoolean(sparkOptions.getOrDefault(OPT_SHARED_SCAN, "false")); + + byte[] nameBytes = namePrefix.getBytes(StandardCharsets.UTF_8); + ByteBuffer buf = + ByteBuffer.allocate(4 + nameBytes.length + 4 + 4 + 4 + 1).order(ByteOrder.LITTLE_ENDIAN); + buf.putInt(nameBytes.length); + buf.put(nameBytes); + buf.putInt(numRows); + buf.putInt(numBatches); + buf.putInt(numPartitions); + buf.put((byte) (sharedScan ? 1 : 0)); + return buf.array(); + } + + @Override + public PartitionInfo[] listPartitions(byte[] optionsBytes) { + // Single partition; the example MemTable is not actually sliced. 
A real bridge would + // populate `partitionBytes` per slice and `preferredLocations` with the hosts holding it. + return new PartitionInfo[] {new PartitionInfo("p0", new byte[0], new String[0])}; + } + + @Override + public PartitionInfo[] listPartitions(byte[] optionsBytes, byte[][] filterProtoBytes) { + // The example cannot prune its single partition, but a real bridge would inspect the + // pushed predicates here and drop partitions that cannot match. + System.out.println( + "ExampleBridgeProviderFactory.listPartitions received " + + filterProtoBytes.length + + " pushed filter(s)"); + return listPartitions(optionsBytes); + } + + @Override + public boolean sharedScan(byte[] optionsBytes) { + // The flag is the final byte of the options blob (present only when the encoder wrote the + // trailing fields). The bridge owns its wire format, so decoding it here is fair game. + return optionsBytes != null + && optionsBytes.length >= 1 + && hasTrailingFields(optionsBytes) + && optionsBytes[optionsBytes.length - 1] == 1; + } + + private static boolean hasTrailingFields(byte[] bytes) { + if (bytes.length < 4) { + return false; + } + int nameLen = ByteBuffer.wrap(bytes, 0, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(); + // base layout: 4 (len) + name + 4 (num_rows) + 4 (num_batches); trailing adds 4 + 1. 
+ return bytes.length >= 4 + nameLen + 8 + 5; + } + + @Override + public ScanBackend scanBackend() { + return new ExampleScanBackend(); + } + + private static int parsePositiveInt(Map opts, String key, int defaultValue) { + String raw = opts.get(key); + if (raw == null || raw.isEmpty()) { + return defaultValue; + } + int parsed; + try { + parsed = Integer.parseInt(raw.trim()); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "ExampleBridgeProviderFactory: option '" + key + "' must be an integer, got: " + raw); + } + if (parsed <= 0) { + throw new IllegalArgumentException( + "ExampleBridgeProviderFactory: option '" + key + "' must be > 0, got: " + parsed); + } + return parsed; + } +} diff --git a/examples/src/main/java/org/apache/datafusion/examples/ExampleScanBackend.java b/examples/src/main/java/org/apache/datafusion/examples/ExampleScanBackend.java new file mode 100644 index 0000000..9854817 --- /dev/null +++ b/examples/src/main/java/org/apache/datafusion/examples/ExampleScanBackend.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.datafusion.examples;
+
+import io.datafusion.spark.ScanBackend;
+
+/**
+ * Routes the connector's scan calls to the example bridge cdylib. Pure delegation.
+ *
+ * <p>Each method passes its arguments, unchanged and in the same order, to the same-named static
+ * native method on {@link ExampleBridgeNative} and returns that method's result. The class has
+ * no fields of its own.
+ */
+final class ExampleScanBackend implements ScanBackend {
+
+  /** Forwards to {@code ExampleBridgeNative.providerSchemaIpc} and returns its result. */
+  @Override
+  public byte[] providerSchemaIpc(byte[] options, byte[] partitionBytes) {
+    return ExampleBridgeNative.providerSchemaIpc(options, partitionBytes);
+  }
+
+  /** Forwards all eight arguments to {@code ExampleBridgeNative.createScan}; returns its handle. */
+  @Override
+  public long createScan(
+      byte[] options,
+      byte[] partitionBytes,
+      int targetPartitions,
+      int batchSize,
+      String[] optionKeys,
+      String[] optionValues,
+      String[] projectionColumns,
+      byte[][] filterProtos) {
+    // Argument order must mirror the native declaration exactly; several parameters share a type,
+    // so a swapped pair would still compile.
+    return ExampleBridgeNative.createScan(
+        options,
+        partitionBytes,
+        targetPartitions,
+        batchSize,
+        optionKeys,
+        optionValues,
+        projectionColumns,
+        filterProtos);
+  }
+
+  /** Forwards to {@code ExampleBridgeNative.partitionCount} and returns its result. */
+  @Override
+  public int partitionCount(long scanHandle) {
+    return ExampleBridgeNative.partitionCount(scanHandle);
+  }
+
+  /** Forwards to {@code ExampleBridgeNative.executeStreamPartition}. */
+  @Override
+  public void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr) {
+    ExampleBridgeNative.executeStreamPartition(scanHandle, partition, ffiStreamAddr);
+  }
+
+  /** Forwards to {@code ExampleBridgeNative.executeStream}. */
+  @Override
+  public void executeStream(long scanHandle, long ffiStreamAddr) {
+    ExampleBridgeNative.executeStream(scanHandle, ffiStreamAddr);
+  }
+
+  /** Forwards to {@code ExampleBridgeNative.closeScan}. */
+  @Override
+  public void closeScan(long scanHandle) {
+    ExampleBridgeNative.closeScan(scanHandle);
+  }
+}
diff --git a/native-common/Cargo.toml b/native-common/Cargo.toml new file mode 100644 index 0000000..0a797b4 --- /dev/null +++ b/native-common/Cargo.toml @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-jni-common" +version = "0.1.0" +edition = "2021" +publish = false + +[features] +# `datafusion-jni` builds DataFusion with `avro`, which adds the +# `DataFusionError::AvroError` variant our classifier maps to IoException. +# Feature-forwarded so consumers that don't read Avro (the Spark helper) +# don't pull the apache-avro stack into their cdylib. +avro = ["datafusion/avro"] + +[dependencies] +datafusion = { workspace = true } +futures = { workspace = true } +jni = { workspace = true } +tokio = { workspace = true } diff --git a/native/src/errors.rs b/native-common/src/errors.rs similarity index 97% rename from native/src/errors.rs rename to native-common/src/errors.rs index d926544..caa2540 100644 --- a/native/src/errors.rs +++ b/native-common/src/errors.rs @@ -96,8 +96,11 @@ fn classify(err: &DataFusionError) -> &'static str { } DataFusionError::IoError(_) | DataFusionError::ObjectStore(_) - | DataFusionError::ParquetError(_) - | DataFusionError::AvroError(_) => "org/apache/datafusion/IoException", + | DataFusionError::ParquetError(_) => "org/apache/datafusion/IoException", + // The AvroError variant only exists when DataFusion is built with its + // `avro` feature, forwarded by this crate's own `avro` feature. + #[cfg(feature = "avro")] + DataFusionError::AvroError(_) => "org/apache/datafusion/IoException", // ArrowError is a 21-variant grab bag -- only some of those variants // are actually IO-shaped. DivideByZero / ArithmeticOverflow / Compute // / Cast / InvalidArgument / Memory etc. 
are execution-time failures diff --git a/native-common/src/lib.rs b/native-common/src/lib.rs new file mode 100644 index 0000000..f143d43 --- /dev/null +++ b/native-common/src/lib.rs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! JNI plumbing shared by this workspace's native crates (`datafusion-jni` +//! and `datafusion-spark-bridge`, and through the latter every bridge +//! cdylib): the error-to-Java-exception mapping, the per-cdylib Tokio +//! runtime singleton, and the async-stream-to-`FFI_ArrowArrayStream` +//! bridge. +//! +//! Each cdylib statically links its own copy of this rlib, so [`runtime`] is +//! a per-cdylib singleton -- exactly the behaviour each crate had when this +//! code lived inline. Nothing here is exported with `#[no_mangle]`, so +//! linking this crate into several cdylibs loaded in one JVM cannot collide. 
+ +pub mod errors; + +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::sync::OnceLock; + +use datafusion::arrow::array::RecordBatch; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::error::ArrowError; +use datafusion::arrow::record_batch::RecordBatchReader; +use datafusion::execution::SendableRecordBatchStream; +use futures::StreamExt; +use tokio::runtime::{Handle, Runtime}; + +static RT: OnceLock = OnceLock::new(); + +/// The cdylib-wide Tokio runtime. +pub fn runtime() -> &'static Runtime { + runtime_with_init(|_| {}) +} + +/// Same singleton as [`runtime`], with a hook that runs exactly once, when +/// the runtime is created. `datafusion-jni` uses it to install its +/// runtime-metrics accumulator so the sampling baseline coincides with +/// runtime start; every later call (either entry point) returns the existing +/// runtime without invoking the hook. +pub fn runtime_with_init(init: impl FnOnce(&Handle)) -> &'static Runtime { + RT.get_or_init(|| { + let rt = Runtime::new().expect("failed to create Tokio runtime"); + init(rt.handle()); + rt + }) +} + +/// Bridges DataFusion's async [`SendableRecordBatchStream`] to the synchronous +/// [`RecordBatchReader`] interface that `FFI_ArrowArrayStream` (and therefore +/// the Java `ArrowReader`) consumes. Each call to `next()` drives one +/// `runtime().block_on(stream.next())`, so memory pressure stays bounded by the +/// executor pipeline plus a single in-flight batch. +pub struct StreamingReader { + pub schema: SchemaRef, + pub stream: SendableRecordBatchStream, +} + +impl Iterator for StreamingReader { + type Item = Result; + + fn next(&mut self) -> Option { + // Arrow's C ABI invokes this iterator through FFI_ArrowArrayStream's + // vtable, outside the JNI handler's try_unwrap_or_throw guard. A panic + // here (buggy UDF, arrow cast that panics, runtime poison) would + // unwind across C/FFI -- undefined behaviour. 
Catch it and surface as + // an ArrowError so the Java side sees a normal exception instead. + let next = catch_unwind(AssertUnwindSafe(|| runtime().block_on(self.stream.next()))); + match next { + Ok(item) => item.map(|r| r.map_err(|e| ArrowError::ExternalError(Box::new(e)))), + Err(panic) => { + let msg = if let Some(s) = panic.downcast_ref::() { + s.clone() + } else if let Some(s) = panic.downcast_ref::<&str>() { + (*s).to_string() + } else { + "rust panic with non-string payload".to_string() + }; + Some(Err(ArrowError::ExternalError( + format!("panic in DataFrame stream: {msg}").into(), + ))) + } + } + } +} + +impl RecordBatchReader for StreamingReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} diff --git a/native/Cargo.toml b/native/Cargo.toml index c462408..0f4ca83 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -23,8 +23,8 @@ publish = false [lib] # `rlib` alongside `cdylib` so `cargo test` has a Rust-level harness for -# native-only invariants (e.g. error-classification routing through wrapped -# DataFusionError chains). The `cdylib` is still the artifact the JVM loads. +# native-only invariants (the error-classification tests now live in +# `datafusion-jni-common`). The `cdylib` is still the artifact the JVM loads. crate-type = ["cdylib", "rlib"] [features] @@ -69,24 +69,23 @@ protoc = ["datafusion-substrait?/protoc"] runtime-metrics = ["dep:tokio-metrics"] [dependencies] -arrow = { version = "58", features = ["ffi"] } -async-trait = "0.1" -datafusion = { version = "53.1.0", features = ["avro"] } -datafusion-proto = "53.1.0" -datafusion-substrait = { version = "53.1.0", optional = true } -futures = "0.3" -jni = "0.21" -# Pin to the same major as DataFusion 53.1 pulls in transitively (0.13.x) -# so we share the same `dyn ObjectStore` vtable and don't double-link. -object_store = { version = "0.13", default-features = false } -prost = "0.14" -tokio = { version = "1", features = ["rt-multi-thread"] } -# Tokio runtime metrics. 
Optional + cfg-gated: this crate's API surface lives -# behind `--cfg tokio_unstable`, so enabling the `runtime-metrics` feature also -# requires the caller to set `RUSTFLAGS="--cfg tokio_unstable"` at build time. -tokio-metrics = { version = "0.5", optional = true } -url = "2" +arrow = { workspace = true } +async-trait = { workspace = true } +datafusion = { workspace = true, features = ["avro"] } +# Shared JNI plumbing (error->exception mapping, runtime singleton, +# StreamingReader). `avro` keeps the classifier's AvroError->IoException arm +# in sync with the `avro` feature on `datafusion` above. +datafusion-jni-common = { path = "../native-common", features = ["avro"] } +datafusion-proto = { workspace = true } +datafusion-substrait = { workspace = true, optional = true } +futures = { workspace = true } +jni = { workspace = true } +object_store = { workspace = true } +prost = { workspace = true } +tokio = { workspace = true } +tokio-metrics = { workspace = true, optional = true } +url = { workspace = true } [build-dependencies] -prost-build = "0.14" -protoc-bin-vendored = "3" +prost-build = { workspace = true } +protoc-bin-vendored = { workspace = true } diff --git a/native/src/arrow.rs b/native/src/arrow.rs index 2bbe7b0..67e5caf 100644 --- a/native/src/arrow.rs +++ b/native/src/arrow.rs @@ -23,10 +23,10 @@ use jni::sys::jlong; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::proto_gen::ArrowReadOptionsProto; use crate::runtime; use crate::schema::decode_optional_schema; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; fn with_arrow_options( env: &mut JNIEnv, diff --git a/native/src/avro.rs b/native/src/avro.rs index 85d4a07..257ae32 100644 --- a/native/src/avro.rs +++ b/native/src/avro.rs @@ -23,10 +23,10 @@ use jni::sys::jlong; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::proto_gen::AvroReadOptionsProto; use crate::runtime; use 
crate::schema::decode_optional_schema; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; fn with_avro_options( env: &mut JNIEnv, diff --git a/native/src/cache_manager.rs b/native/src/cache_manager.rs index 3b9e286..ec38dc8 100644 --- a/native/src/cache_manager.rs +++ b/native/src/cache_manager.rs @@ -34,8 +34,8 @@ use datafusion::execution::cache::cache_unit::{ }; use datafusion::execution::cache::DefaultListFilesCache; -use crate::errors::JniResult; use crate::proto_gen::CacheManagerOptionsProto; +use datafusion_jni_common::errors::JniResult; /// Build a [`CacheManagerConfig`] from the proto. Returns `Ok(None)` if the /// caller did not set any cache-manager field, so the JNI layer can skip the diff --git a/native/src/csv.rs b/native/src/csv.rs index 3ae4627..b79ed59 100644 --- a/native/src/csv.rs +++ b/native/src/csv.rs @@ -26,12 +26,12 @@ use jni::sys::jlong; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::proto_gen::{ CsvReadOptionsProto, CsvWriteOptionsProto, FileCompressionType as ProtoFileCompressionType, }; use crate::runtime; use crate::schema::decode_optional_schema; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; fn with_csv_options( env: &mut JNIEnv, diff --git a/native/src/json.rs b/native/src/json.rs index 8eea32f..b87be78 100644 --- a/native/src/json.rs +++ b/native/src/json.rs @@ -27,12 +27,12 @@ use jni::sys::jlong; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::proto_gen::{ FileCompressionType as ProtoFileCompressionType, JsonWriteOptionsProto, NdJsonReadOptionsProto, }; use crate::runtime; use crate::schema::decode_optional_schema; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; fn with_json_options( env: &mut JNIEnv, diff --git a/native/src/lib.rs b/native/src/lib.rs index 4fd7a8a..6e1a79f 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -19,7 +19,6 @@ mod arrow; mod 
avro; mod cache_manager; mod csv; -mod errors; mod jni_util; mod json; mod memory; @@ -34,16 +33,13 @@ pub(crate) mod proto_gen { include!(concat!(env!("OUT_DIR"), "/datafusion_java.rs")); } -use std::panic::{catch_unwind, AssertUnwindSafe}; use std::path::PathBuf; use std::sync::{Arc, OnceLock}; -use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::SchemaRef; -use datafusion::arrow::error::ArrowError; use datafusion::arrow::ffi_stream::FFI_ArrowArrayStream; use datafusion::arrow::ipc::writer::StreamWriter; -use datafusion::arrow::record_batch::{RecordBatchIterator, RecordBatchReader}; +use datafusion::arrow::record_batch::RecordBatchIterator; use datafusion::common::{JoinType, UnnestOptions}; use datafusion::config::TableParquetOptions; use datafusion::dataframe::DataFrame; @@ -51,11 +47,9 @@ use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::DataFusionError; use datafusion::execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; -use datafusion::execution::SendableRecordBatchStream; use datafusion::logical_expr::Expr; use datafusion::logical_expr::{col, Partitioning, ScalarUDF, Signature, SortExpr}; use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; -use futures::StreamExt; use jni::objects::{JBooleanArray, JByteArray, JClass, JObject, JObjectArray, JString}; use jni::sys::{jboolean, jbyte, jbyteArray, jint, jlong}; use jni::JNIEnv; @@ -63,7 +57,10 @@ use jni::JavaVM; use prost::Message; use tokio::runtime::Runtime; -use crate::errors::{try_unwrap_or_throw, JniResult}; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; +// Re-exported so sibling modules keep their crate-local `crate::StreamingReader` path. 
+pub(crate) use datafusion_jni_common::StreamingReader; + use crate::proto_gen::ParquetReadOptionsProto; use crate::proto_gen::SessionOptions; use crate::schema::decode_optional_schema; @@ -84,18 +81,15 @@ pub(crate) fn jvm() -> &'static JavaVM { } pub(crate) fn runtime() -> &'static Runtime { - static RT: OnceLock = OnceLock::new(); - RT.get_or_init(|| { - let rt = Runtime::new().expect("failed to create Tokio runtime"); - // Eagerly install the runtime-metrics accumulator (no-op when the - // `runtime-metrics` Cargo feature is off). Initialising here -- not - // lazily on the first `runtimeStats()` call -- means the - // RuntimeMonitor's sampling baseline coincides with runtime start, so - // poll/park/busy totals reflect activity from the first query onward - // rather than from the first observation. - crate::runtime_metrics::init(rt.handle()); - rt - }) + // The singleton itself lives in datafusion-jni-common (shared with the + // datafusion-spark-bridge SDK; each cdylib statically links its own + // copy, so the runtime stays per-library). The init hook eagerly installs the + // runtime-metrics accumulator (no-op when the `runtime-metrics` Cargo + // feature is off). Initialising here -- not lazily on the first + // `runtimeStats()` call -- means the RuntimeMonitor's sampling baseline + // coincides with runtime start, so poll/park/busy totals reflect activity + // from the first query onward rather than from the first observation. + datafusion_jni_common::runtime_with_init(crate::runtime_metrics::init) } /// Wrap the (already-built) `RuntimeEnvBuilder`'s memory pool with a @@ -289,50 +283,6 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_collectDataFrame<'lo }) } -/// Bridges DataFusion's async [`SendableRecordBatchStream`] to the synchronous -/// [`RecordBatchReader`] interface that `FFI_ArrowArrayStream` (and therefore -/// the Java `ArrowReader`) consumes. 
Each call to `next()` drives one -/// `runtime().block_on(stream.next())`, so memory pressure stays bounded by the -/// executor pipeline plus a single in-flight batch. -struct StreamingReader { - schema: SchemaRef, - stream: SendableRecordBatchStream, -} - -impl Iterator for StreamingReader { - type Item = Result; - - fn next(&mut self) -> Option { - // Arrow's C ABI invokes this iterator through FFI_ArrowArrayStream's - // vtable, outside the JNI handler's try_unwrap_or_throw guard. A panic - // here (buggy UDF, arrow cast that panics, runtime poison) would - // unwind across C/FFI -- undefined behaviour. Catch it and surface as - // an ArrowError so the Java side sees a normal exception instead. - let next = catch_unwind(AssertUnwindSafe(|| runtime().block_on(self.stream.next()))); - match next { - Ok(item) => item.map(|r| r.map_err(|e| ArrowError::ExternalError(Box::new(e)))), - Err(panic) => { - let msg = if let Some(s) = panic.downcast_ref::() { - s.clone() - } else if let Some(s) = panic.downcast_ref::<&str>() { - (*s).to_string() - } else { - "rust panic with non-string payload".to_string() - }; - Some(Err(ArrowError::ExternalError( - format!("panic in DataFrame stream: {msg}").into(), - ))) - } - } - } -} - -impl RecordBatchReader for StreamingReader { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - #[no_mangle] pub extern "system" fn Java_org_apache_datafusion_DataFrame_executeStreamDataFrame<'local>( mut env: JNIEnv<'local>, diff --git a/native/src/object_store.rs b/native/src/object_store.rs index eefccf2..985d721 100644 --- a/native/src/object_store.rs +++ b/native/src/object_store.rs @@ -28,9 +28,9 @@ use std::sync::Arc; use datafusion::prelude::SessionContext; use url::Url; -use crate::errors::JniResult; use crate::proto_gen::object_store_registration::Backend; use crate::proto_gen::ObjectStoreRegistration; +use datafusion_jni_common::errors::JniResult; #[cfg(feature = "object-store-gcp")] use crate::proto_gen::GcsOptions; diff --git 
a/native/src/proto.rs b/native/src/proto.rs index 4f187bc..c1315f9 100644 --- a/native/src/proto.rs +++ b/native/src/proto.rs @@ -28,8 +28,8 @@ use jni::sys::{jbyteArray, jlong}; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::runtime; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; #[no_mangle] pub extern "system" fn Java_org_apache_datafusion_SessionContext_createDataFrameFromProto< diff --git a/native/src/runtime_metrics.rs b/native/src/runtime_metrics.rs index e69410e..dd60dcb 100644 --- a/native/src/runtime_metrics.rs +++ b/native/src/runtime_metrics.rs @@ -38,7 +38,7 @@ //! 10 totalOverflowCount #[cfg(not(feature = "runtime-metrics"))] -use crate::errors::JniResult; +use datafusion_jni_common::errors::JniResult; /// Number of i64 values in the snapshot array; kept here so the Java side and /// the feature-off stub agree on the layout. @@ -51,7 +51,7 @@ mod imp { use tokio_metrics::{RuntimeIntervals, RuntimeMonitor}; use super::STATS_FIELD_COUNT; - use crate::errors::JniResult; + use datafusion_jni_common::errors::JniResult; /// `RuntimeMonitor::intervals().next()` returns *delta* metrics covering /// the period since the previous call (or, on the very first call, since @@ -196,7 +196,7 @@ pub fn runtime_stats() -> JniResult<[i64; STATS_FIELD_COUNT]> { Err( "datafusion-jni was built without the `runtime-metrics` Cargo feature; \ rebuild the native crate with \ - `RUSTFLAGS=\"--cfg tokio_unstable\" cargo build --features runtime-metrics` \ + `RUSTFLAGS=\"--cfg tokio_unstable\" cargo build -p datafusion-jni --features runtime-metrics` \ to enable SessionContext.runtimeStats" .into(), ) diff --git a/native/src/schema.rs b/native/src/schema.rs index 968a73a..0c3c7ab 100644 --- a/native/src/schema.rs +++ b/native/src/schema.rs @@ -20,7 +20,7 @@ use datafusion::arrow::ipc::reader::StreamReader; use jni::objects::JByteArray; use jni::JNIEnv; -use crate::errors::JniResult; +use 
datafusion_jni_common::errors::JniResult; /// Decode an optional Arrow-IPC schema byte array passed in from Java. /// Returns `None` if the byte-array reference is null. diff --git a/pom.xml b/pom.xml index 6210841..6baeb94 100644 --- a/pom.xml +++ b/pom.xml @@ -32,6 +32,7 @@ under the License. core + spark examples @@ -95,6 +96,11 @@ under the License. + + org.apache.maven.plugins + maven-compiler-plugin + 3.13.0 + org.apache.maven.plugins maven-surefire-plugin @@ -159,6 +165,7 @@ under the License. README.md CONTRIBUTING.md docs/** + **/*.md .gitignore .idea/** @@ -173,12 +180,17 @@ under the License. .mvn/** **/target/** - native/target/** + rust-target/** tpch-data/** - - native/Cargo.lock + + Cargo.lock + + **/META-INF/services/** dev/release/rat_exclude_files.txt + + spark/scaffold/bridge-template/** diff --git a/spark/README.md b/spark/README.md new file mode 100644 index 0000000..5cc3d3c --- /dev/null +++ b/spark/README.md @@ -0,0 +1,454 @@ +# DataFusion Spark Connector + +This module (`datafusion-java-spark`) lets you expose a [DataFusion +`TableProvider`](https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html) +written in Rust as an [Apache Spark DataSource +V2](https://spark.apache.org/docs/latest/sql-data-sources.html) table. If you +have data that DataFusion can already read — an in-house file format, a custom +catalog, a remote service — this connector is the bridge that makes +`spark.read.format(...)` work against it, with predicate pushdown, column +pruning, and partitioned parallel reads. + +You write two small pieces (a Rust function and a Java class); the connector +supplies everything else. 
+ +## How it fits together + +Two layers, one of which already exists: + +``` + your bridge (you write this) this module (already written) ++--------------------------------+ +----------------------------------+ +| cdylib on datafusion-spark- | | Scala/Java DSv2 plumbing | +| bridge (spark/bridge SDK): | | (spark/src) schema inference, | +| your TableProvider + one |<--| pushdown, task planning, | +| export_bridge! invocation; |-->| shared-scan cache | +| the SDK supplies widening, | | | +| session, filters, planning, | | (pure JVM — all native code | +| partition streams | | ships inside YOUR jar) | ++--------------------------------+ +----------------------------------+ + | + v + spark.read.format("...").load() +``` + +The only things that cross between the JVM and your cdylib are opaque +`byte[]` blobs that *you* define (options and per-partition payloads — the +connector never inspects them) going in, and Arrow C streams coming back. +Everything DataFusion-side (planning, filter application, execution) happens +inside your bridge's native library. There is no DataFusion session on the +JVM side at all, and no `FFI_TableProvider` boundary anywhere — your +concrete provider is linked into the same cdylib as the scan machinery. + +## Getting started: generate a bridge + +Don't hand-assemble the pieces below — stamp them out: + +```bash +python3 spark/scaffold/new_bridge.py --name acme --package com.example.acme +``` + +generates a standalone project (Rust cdylib with a working demo provider, +the four Java classes, service registration, shaded-jar pom with the cdylib +bundled, pyspark smoke test, README with the build commands). Replace the +demo `MemTable` in its `native/src/lib.rs` and you have a connector. The +sections below explain what each generated piece is for. 
+ +## What you implement + +| # | Piece | Language | Contract lives at | Working example | +|---|-------|----------|-------------------|-----------------| +| 1 | A provider builder + one `export_bridge!` invocation | Rust | [`bridge/src/lib.rs`](bridge/src/lib.rs) (macro rustdoc) | [`examples/native/src/lib.rs`](../examples/native/src/lib.rs) | +| 2 | A `BridgeProviderFactory` implementation (one required method) + the JNI/backend boilerplate | Java | [`src/main/java/io/datafusion/spark/BridgeProviderFactory.java`](src/main/java/io/datafusion/spark/BridgeProviderFactory.java) | [`examples/.../ExampleBridgeProviderFactory.java`](../examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeProviderFactory.java) | +| 3 | (optional) A `DatafusionSource` subclass giving your source a short name | Scala/Java | [`src/main/scala/io/datafusion/spark/DatafusionSource.scala`](src/main/scala/io/datafusion/spark/DatafusionSource.scala) | see "Wiring it into Spark" below | + +An end-to-end runnable version of all three — in-memory table, factory, and a +PySpark script that scans, filters, and projects it — lives under +[`examples/python/`](../examples/python/). + +### 1. The Rust side + +Depend on the [`datafusion-spark-bridge`](bridge/) SDK crate and let it +generate the JNI surface; you supply one builder turning your option / +partition bytes into a concrete `TableProvider`: + +```rust +use std::sync::Arc; +use datafusion_spark_bridge::datafusion::catalog::TableProvider; +use datafusion_spark_bridge::{export_bridge, BridgeContext, JniResult}; + +fn build_provider( + ctx: &BridgeContext, + options: &[u8], + partition: &[u8], +) -> JniResult> { + let opts = MyOptions::decode(options)?; + Ok(ctx.block_on(MyProvider::connect(opts, partition))?) +} + +export_bridge! { + // Underscore-mangled name of YOUR Java class declaring the native + // methods (dots -> underscores). Per-bridge names let several bridges + // coexist in one Spark JVM. 
+ jni_class: "com_example_mybridge_BridgeNative", + build_provider: build_provider, +} +``` + +The macro's rustdoc lists the exact `static native` method set the named +Java class must declare; your factory routes the connector to it by +overriding `scanBackend()` (see section 2). One cdylib total: your provider +and the SDK's scan machinery are the same library, so there is no provider +hand-off across a binary boundary and no `datafusion-ffi` anywhere. The +builder receives empty partition bytes for the driver-side schema probe — +schema must not depend on per-partition state. + +[`examples/native/src/lib.rs`](../examples/native/src/lib.rs) +is a complete, commented version of this for a `MemTable`. + +### 2. The Java factory + +`BridgeProviderFactory` is the contract between Spark and your bridge. It +must have a no-arg constructor (executors instantiate it reflectively by +class name). The single required method is `scanBackend()` — Spark options +are encoded with `OptionsCodec` by default (decode them in Rust via +`datafusion_spark_bridge::options::decode_options`), and `listPartitions` +defaults to one whole-dataset partition: + +```java +public final class MyBridgeProviderFactory implements BridgeProviderFactory { + + @Override + public ScanBackend scanBackend() { + return new MyBridgeBackend(); // six one-line delegations to BridgeNative + } +} + +/** Declares the native methods generated by export_bridge! and loads the cdylib. 
*/ +final class BridgeNative { + static { + NativeLibraryLoader.load(BridgeNative.class, "com/example/mybridge", "my_bridge"); + } + static native byte[] providerSchemaIpc(byte[] options, byte[] partition); + static native long createScan(byte[] options, byte[] partition, + int targetPartitions, int batchSize, String[] optionKeys, + String[] optionValues, String[] projectionColumns, byte[][] filterProtos); + static native int partitionCount(long scanHandle); + static native void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr); + static native void executeStream(long scanHandle, long ffiStreamAddr); + static native void closeScan(long scanHandle); +} +``` + +(`MyBridgeBackend implements ScanBackend` forwards each method to +`BridgeNative` — pure boilerplate the scaffold generates.) + +Override `encodeOptions` only if the bridge already has its own options +schema (e.g. a protobuf), and `listPartitions` when the dataset should split +into more than one Spark task: + +```java + @Override + public PartitionInfo[] listPartitions(byte[] optionsBytes) { + MySlice[] slices = MyBridgeNative.listSlices(optionsBytes); + PartitionInfo[] out = new PartitionInfo[slices.length]; + for (int i = 0; i < slices.length; i++) { + out[i] = new PartitionInfo(slices[i].id(), slices[i].payload(), slices[i].hosts()); + } + return out; + } +``` + +The remaining optional methods — `sharedScan`, `reportPartitioning`, and the +filter-aware `listPartitions(opts, filters)` overload — are covered in their +own sections below. Their javadoc in +[`BridgeProviderFactory.java`](src/main/java/io/datafusion/spark/BridgeProviderFactory.java) +is the authoritative contract. + +### 3. 
Wiring it into Spark + +Either pass your factory class per read: + +```python +df = (spark.read.format("datafusion") + .option("df.factory", "com.example.MyBridgeProviderFactory") + .option("url", "...") + .option("table", "my_dataset") + .load()) +``` + +or ship a ~10-line subclass so users get a short format name: + +```scala +class MyDataSource extends DatafusionSource { + override def shortName(): String = "my_format" + override protected def factoryFqcn(opts: CaseInsensitiveStringMap): String = + "com.example.MyBridgeProviderFactory" +} +``` + +registered via a +`META-INF/services/org.apache.spark.sql.sources.DataSourceRegister` file +(this module registers `datafusion` the same way — see +[`src/main/resources/META-INF/services/`](src/main/resources/META-INF/services/)). + +## Packaging your bridge + +The end-user experience to aim for is one artifact: + +```python +# spark.jars (or --packages) gets exactly one jar, then: +df = spark.read.format("my_format").option("url", "...").load() +``` + +Three pieces make that work: + +**Bundle your cdylib inside the jar.** Copy it into your jar's resources at +`///` and load it from your native +class's static initializer with the connector's loader — no hand-rolled +extraction code: + +```java +static { + NativeLibraryLoader.load(BridgeNative.class, "com/example/mybridge", "my_bridge"); +} +``` + +The pom side is one antrun copy execution plus per-host profiles; the +examples module is a complete working copy of the pattern (see the +`copy-example-bridge-cdylib` execution and the `native-*` profiles in +[`examples/pom.xml`](../examples/pom.xml), and the loader call in +[`ExampleBridgeNative.java`](../examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeNative.java)). +For a multi-platform jar, build the cdylib per platform in CI and copy each +into its own `//` directory before `mvn package` — the layout +supports them side by side. 
+ +**Shade your dependencies into one fat jar** with `maven-shade-plugin`, so +users don't assemble a jar list: + +```xml + + org.apache.maven.plugins + maven-shade-plugin + + + package + shade + + + + + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + +``` + +Include in the shaded jar: this connector (`datafusion-java-spark`), the core +jar (`datafusion-java` — exception classes and, if you push predicates, the +generated proto classes), the Arrow Java artifacts you compile against, and +your own classes + cdylib. Keep `spark-sql`/`scala-library` `provided` — the +cluster supplies them. + +**Do NOT relocate JNI-bound or JNI-loading packages.** JNI binds native +methods by the class's fully-qualified name; `arrow-c-data` and the Arrow +memory modules likewise load their own natives. Relocating +`io.datafusion.spark`, `org.apache.arrow`, or your own native class breaks +the symbol lookup at runtime. Practical consequences: + +- Ship a plain (unrelocated) fat jar. Two bridges in one Spark app then share + one copy of the connector classes — fine when they're built against the + same connector version, which is the only configuration we support anyway + (their cdylibs stay distinct via per-bridge JNI class names). +- Spark bundles its own (often older) Arrow. Since yours can't be relocated + away, have users set `spark.executor.userClassPathFirst=true` and + `spark.driver.userClassPathFirst=true` (the pyspark demo under + [`examples/python/`](../examples/python/) shows the working incantation), + or build with Arrow pinned to the cluster's version. + +## Spark tasks vs. DataFusion partitions + +This is the most important design decision when building a connector, so it +gets its own section. + +Spark parallelism and DataFusion parallelism are different things: + +- A **Spark task** is the unit Spark schedules onto an executor core. 
Each + task carries fixed overhead: scheduling on the driver, (de)serializing the + task, instantiating your factory, building a provider, planning a scan. +- A **DataFusion partition** is one output stream of a planned physical + query. A single plan usually has several. + +The connector supports two ways of mapping one onto the other: + +### Default mode: one Spark task per `PartitionInfo` + +`listPartitions` returns N entries → Spark runs N tasks. Each task calls +`backend.createScan(opts, partitionBytes, ...)` with *its own* entry's payload, so each +task plans and scans only its slice. If DataFusion happens to plan that slice +into multiple internal partitions, they are merged into one stream for the +task — within a task there is no extra parallelism, by design (the +parallelism budget belongs to Spark). + +You control the mapping entirely through what you return from +`listPartitions`. Sizing guidance: + +- **Don't emit one `PartitionInfo` per tiny fragment.** A Spark task should + do meaningfully more work than its overhead — as a rule of thumb at least + ~100 ms of scan time, or order-100 MB of data (Spark's own file sources + default to 128 MB per task for the same reason). If your natural unit is a + small chunk (an object-store key, a time slice, a recording segment), + **bin-pack several into one entry**: `partitionBytes` is opaque, so encode + a *list* of chunk ids and have your `build_provider` function materialise all of + them in one provider. +- **Watch the total task count.** The Spark driver schedules and tracks every + task; beyond the low thousands of tasks per stage you pay growing driver + CPU/memory and UI lag for no extra throughput once the cluster's cores are + saturated. A healthy target is roughly 2–3 tasks per available core, and + rarely more than a few thousand per scan. Tens of thousands of + single-digit-megabyte tasks is a smell — bin-pack first. 
+- **Locality and partition keys only exist here.** `preferredLocations` + (host affinity) and `HasPartitionKey`/`reportPartitioning` (shuffle + elision) are properties of `PartitionInfo` entries. If you need either, + use this mode. + +### Shared-scan mode: one Spark task per DataFusion partition + +When provider construction itself is expensive (remote metadata, connection +setup) or the dataset has thousands of small natural partitions, per-task +provider builds dominate. Opting in via + +```java +@Override +public boolean sharedScan(byte[] optionsBytes) { return true; } +``` + +flips the mapping: the provider is built **once per executor JVM per query** +(with empty `partitionBytes`), planned once, and Spark runs one task per +*DataFusion output partition* — task `i` streams plan partition `i` from the +executor-local cached plan. `listPartitions` is not called at all. + +The DataFusion partition count — and therefore the Spark task count — is +pinned by `spark.datafusion.sharedScan.targetPartitions` (default 8). The +value is resolved on the driver and shipped to executors, because +DataFusion's default would vary with each machine's core count and the +partition indices must mean the same thing everywhere. + +Choosing between the modes: + +| Choose | When | +|--------|------| +| Default (per-partition payload) | slices have host affinity, you want partition-key semantics, per-slice provider construction is cheap. Bin-pack small slices before abandoning this mode. | +| Shared-scan | provider construction is expensive, there are thousands of small partitions with no locality story, the workload is scan + filter + projection. Provider builds drop from one-per-task to one-per-executor (plus one driver probe per query). | + +Shared-scan's price of admission is a **determinism contract**: the +provider's schema, partitioning, and per-partition contents must be a pure +function of `optionsBytes`. 
Remote sources must pin a snapshot +(version/timestamp) inside the options. The connector fails tasks when an +executor's partition count diverges from the driver's, but equal counts with +different contents are undetectable by construction. The provider's +`ExecutionPlan` must also tolerate `execute(i)` being called more than once +per plan instance (Spark retries and speculatively re-executes tasks). Full +contract: `BridgeProviderFactory.sharedScan` javadoc. + +Shared-scan operational details: + +- Executor cache ([`SharedScanCache.scala`](src/main/scala/io/datafusion/spark/SharedScanCache.scala)): + entries keyed per query (`scanId`), refcounted by open readers, evicted + after an idle TTL. Build failures are not cached; eviction between task + waves just rebuilds. +- Spark conf (read on the driver at planning time, shipped to executors): + - `spark.datafusion.sharedScan.targetPartitions` (default 8) + - `spark.datafusion.sharedScan.batchSize` (default 8192) + - `spark.datafusion.sharedScan.idleTtlMs` (default 120000) + +## What the connector does for you + +- **Schema inference** — your provider's Arrow schema, widened, becomes the + Spark schema. Driver-side, one probe build with empty `partitionBytes`. +- **Type widening** — Spark's columnar readers reject several Arrow types + DataFusion happily produces. The SDK (inside your bridge's cdylib) + transparently casts + unsigned ints → wider signed, `Float16` → `Float32`, `Time*` → wider ints, + any-unit/tz `Timestamp` → microsecond, recursively through + `List`/`LargeList`/`FixedSizeList` (see + [`bridge/src/widening.rs`](bridge/src/widening.rs)). Caveat: unsigned types + nested inside `Struct`/`Map` are not yet covered. 
+- **Predicate pushdown** — Spark V2 `Predicate`s are translated to DataFusion + expressions ([`SparkPredicateTranslator.scala`](src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala)), + shipped as `datafusion-proto` bytes, and applied inside the native plan, so + your provider's `supports_filters_pushdown`/`scan` sees real Rust `Expr`s. + Anything untranslatable stays in Spark as a residual filter — over-claiming + is impossible by construction. +- **Column pruning** — Spark's required-columns projection becomes a + DataFusion projection on the native plan. +- **Partition-aware joins/aggregations** (default mode, optional) — declare + `reportPartitioning` + per-partition key values and Spark can elide + shuffles. See the javadoc on + [`ReportedPartitioning.java`](src/main/java/io/datafusion/spark/ReportedPartitioning.java) + and [`PartitionInfo.java`](src/main/java/io/datafusion/spark/PartitionInfo.java); + note Spark 3.3+ additionally requires + `spark.sql.sources.v2.bucketing.enabled=true` for storage-partitioned + joins. 
+ +## What runs where + +| Phase | Where | Path | +| ----- | ----- | ---- | +| Schema inference | Driver | `factory.encodeOptions` → `backend.providerSchemaIpc(opts, EMPTY)` — bridge cdylib builds + widens the provider, returns the Arrow schema | +| Scan planning (default mode) | Driver | `factory.listPartitions(opts[, filters])` → one task per entry, with its `partitionBytes` + `preferredLocations` | +| Scan planning (shared-scan) | Driver | probe build (same code path executors use) → plan partition count `N` → `N` tasks | +| Predicate translation | Driver | `SparkPredicateTranslator` → proto bytes per pushed predicate | +| Per-task scan (default mode) | Executor | `backend.createScan(opts, partitionBytes, ...)` (build provider, widen, project, filter, plan) → stream whole plan | +| Per-task scan (shared-scan) | Executor | cache-acquire by `scanId` (first task builds) → stream plan partition `i` → release | + +The native machinery backing all of this is +[`bridge/src/scan.rs`](bridge/src/scan.rs), exported into each bridge's +cdylib by `export_bridge!` and reached through its [`ScanBackend`](src/main/java/io/datafusion/spark/ScanBackend.java). + +## Module layout + +``` +spark/ +├── src/main/java/io/datafusion/spark/ public SPI (Java on purpose: +│ bridge jars stay Scala-free) +│ BridgeProviderFactory.java <- the contract you implement +│ ScanBackend.java <- native scan surface (delegations +│ to your bridge's JNI class) +│ NativeLibraryLoader.java <- bundled-cdylib extraction/loading +│ PartitionInfo.java <- one entry = one Spark task +│ ReportedPartitioning.java <- optional shuffle-elision declaration +├── src/main/scala/io/datafusion/spark/ connector internals (DSv2 wiring, +│ readers, pushdown, shared-scan cache) +└── bridge/ datafusion-spark-bridge SDK rlib: + widening + scan machinery + + export_bridge! 
(the native side of + every bridge cdylib) +``` + +## Caveats + +- Pushed filter expressions are deserialized with DataFusion's default + logical-extension codec, which covers columns, literals, and built-in + functions. Anything the Spark-side translator can't express stays in Spark + as a residual filter, so coverage gaps cost performance, never + correctness. +- The bridge cdylib's DataFusion version is the SDK's: cargo resolves one + `datafusion` for your provider and the scan machinery together, pinned in + this repo's workspace [`Cargo.toml`](../Cargo.toml). Upgrading DataFusion + means rebuilding the bridge against a newer SDK. +- The SDK's Tokio runtime is per-cdylib and `Once`-gated; TLS-using bridges + should `Once`-gate their rustls install the same way. diff --git a/spark/bridge/Cargo.toml b/spark/bridge/Cargo.toml new file mode 100644 index 0000000..8ed4684 --- /dev/null +++ b/spark/bridge/Cargo.toml @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "datafusion-spark-bridge" +version = "0.1.0" +edition = "2021" +publish = false +description = "SDK for building Spark connector bridges over DataFusion TableProviders" + +[dependencies] +arrow = { workspace = true } +async-trait = { workspace = true } +datafusion = { workspace = true } +datafusion-jni-common = { path = "../../native-common" } +datafusion-proto = { workspace = true } +futures = { workspace = true } +jni = { workspace = true } +prost = { workspace = true } +tokio = { workspace = true } diff --git a/spark/bridge/src/lib.rs b/spark/bridge/src/lib.rs new file mode 100644 index 0000000..52ef1c1 --- /dev/null +++ b/spark/bridge/src/lib.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SDK for building Spark connector bridges over DataFusion `TableProvider`s. +//! +//! Everything the Spark connector needs DataFusion-side lives here: the +//! Spark-type [`widening`] layer, and the [`scan`] machinery (session from +//! pinned config, projection, proto filters, planning, partition streams). +//! A bridge cdylib depends on this crate and invokes [`export_bridge!`] with +//! a builder that constructs its concrete `TableProvider` from option / +//! 
partition bytes — one cdylib, no FFI provider boundary; the only foreign +//! interface is JNI plus Arrow's C stream for the results. + +pub mod options; +pub mod scan; +pub mod widening; + +// Re-exported so `export_bridge!` expansions resolve these crates inside the +// bridge author's crate without extra dependencies, and so builder signatures +// can be written against `datafusion_spark_bridge::datafusion::...`. +pub use datafusion; +pub use datafusion_jni_common::errors::JniResult; +pub use jni; + +use tokio::runtime::Handle; + +/// Execution environment handed to a bridge's provider builder. +/// +/// Provider construction frequently needs async IO (remote catalogs, +/// object-store metadata); run it on the bridge runtime via [`block_on`] +/// rather than creating a runtime of your own. +/// +/// [`block_on`]: BridgeContext::block_on +pub struct BridgeContext { + handle: &'static Handle, +} + +impl BridgeContext { + /// Used by `export_bridge!` expansions; not part of the public API. + #[doc(hidden)] + pub fn get() -> Self { + BridgeContext { + handle: runtime_handle(), + } + } + + /// The cdylib-wide Tokio runtime handle (also the runtime scans run on). + pub fn handle(&self) -> &Handle { + self.handle + } + + /// Block the current (JVM) thread on `fut`, driving it on the bridge + /// runtime. + pub fn block_on(&self, fut: F) -> F::Output { + self.handle.block_on(fut) + } +} + +/// Per-cdylib Tokio runtime (the singleton from `datafusion-jni-common`). +pub(crate) fn runtime_handle() -> &'static Handle { + datafusion_jni_common::runtime().handle() +} + +/// Generate the JNI entry points for a bridge cdylib. +/// +/// `jni_class` is the **underscore-mangled** binary name of the Java class +/// declaring the matching `native` methods: dots become underscores +/// (`com.example.mybridge.BridgeNative` → `"com_example_mybridge_BridgeNative"`). +/// If the class or package name itself contains an underscore, JNI mangling +/// requires it written as `_1`. 
Per-bridge class names are what let several +/// bridges coexist in one Spark JVM. +/// +/// `build_provider` is anything callable as +/// `Fn(&BridgeContext, &[u8], &[u8]) -> JniResult>`, +/// receiving the options bytes and partition bytes your JVM factory encoded. +/// The schema probe calls it with empty partition bytes; the scan path passes +/// each task's payload. Return errors boxed from `DataFusionError` to surface +/// as the typed `org.apache.datafusion.*` exception hierarchy. +/// +/// The generated Java-side surface (declare these as `static native` on the +/// class named by `jni_class`): +/// +/// ```java +/// static native byte[] providerSchemaIpc(byte[] options, byte[] partition); +/// static native long createScan(byte[] options, byte[] partition, +/// int targetPartitions, int batchSize, String[] optionKeys, +/// String[] optionValues, String[] projectionColumns, byte[][] filterProtos); +/// static native int partitionCount(long scanHandle); +/// static native void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr); +/// static native void executeStream(long scanHandle, long ffiStreamAddr); +/// static native void closeScan(long scanHandle); +/// ``` +#[macro_export] +macro_rules! export_bridge { + (jni_class: $cls:literal, build_provider: $builder:expr $(,)?) => { + const _: () = { + use $crate::jni::objects::{JByteArray, JClass, JObjectArray}; + use $crate::jni::sys::{jbyteArray, jint, jlong}; + use $crate::jni::JNIEnv; + + fn __df_bridge_build( + env: &mut JNIEnv, + options: &JByteArray, + partition: &JByteArray, + ) -> $crate::JniResult> + { + let opts: Vec = if options.is_null() { + Vec::new() + } else { + env.convert_byte_array(options)? + }; + let part: Vec = if partition.is_null() { + Vec::new() + } else { + env.convert_byte_array(partition)? 
+ }; + let ctx = $crate::BridgeContext::get(); + ($builder)(&ctx, opts.as_slice(), part.as_slice()) + } + + #[export_name = concat!("Java_", $cls, "_providerSchemaIpc")] + extern "system" fn __df_bridge_provider_schema_ipc<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + options: JByteArray<'local>, + partition: JByteArray<'local>, + ) -> jbyteArray { + $crate::scan::provider_schema_ipc(&mut env, |env| { + __df_bridge_build(env, &options, &partition) + }) + } + + #[export_name = concat!("Java_", $cls, "_createScan")] + #[allow(clippy::too_many_arguments)] + extern "system" fn __df_bridge_create_scan<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + options: JByteArray<'local>, + partition: JByteArray<'local>, + target_partitions: jint, + batch_size: jint, + option_keys: JObjectArray<'local>, + option_values: JObjectArray<'local>, + projection_columns: JObjectArray<'local>, + filter_protos: JObjectArray<'local>, + ) -> jlong { + $crate::scan::create_scan( + &mut env, + |env| __df_bridge_build(env, &options, &partition), + target_partitions, + batch_size, + &option_keys, + &option_values, + &projection_columns, + &filter_protos, + ) + } + + #[export_name = concat!("Java_", $cls, "_partitionCount")] + extern "system" fn __df_bridge_partition_count<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + ) -> jint { + $crate::scan::partition_count(&mut env, handle) + } + + #[export_name = concat!("Java_", $cls, "_executeStreamPartition")] + extern "system" fn __df_bridge_execute_stream_partition<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + partition: jint, + ffi_stream_addr: jlong, + ) { + $crate::scan::execute_stream_partition(&mut env, handle, partition, ffi_stream_addr) + } + + #[export_name = concat!("Java_", $cls, "_executeStream")] + extern "system" fn __df_bridge_execute_stream<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + ffi_stream_addr: 
jlong, + ) { + $crate::scan::execute_stream(&mut env, handle, ffi_stream_addr) + } + + #[export_name = concat!("Java_", $cls, "_closeScan")] + extern "system" fn __df_bridge_close_scan<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + ) { + $crate::scan::close_scan(&mut env, handle) + } + }; + }; +} diff --git a/spark/bridge/src/options.rs b/spark/bridge/src/options.rs new file mode 100644 index 0000000..117ca9d --- /dev/null +++ b/spark/bridge/src/options.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decoder for the connector's default options wire format. +//! +//! `BridgeProviderFactory.encodeOptions`'s default (`OptionsCodec` on the JVM +//! side) encodes the Spark options map as length-prefixed UTF-8 pairs, +//! sorted by key: big-endian `i32` entry count, then per entry key length, +//! key bytes, value length, value bytes. Key-sorting makes the bytes a pure +//! function of the map contents — the shared-scan determinism contract uses +//! the options bytes as the scan identity. +//! +//! Bridges using the default JVM encoding read their options here: +//! +//! ```ignore +//! 
let opts = datafusion_spark_bridge::options::decode_options(options_bytes)?;
+//! let url = opts.get("url").ok_or("missing required option 'url'")?;
+//! ```
+//!
+//! The two implementations are pinned to each other by the shared fixture in
+//! the tests below; `OptionsCodecTest` on the JVM side asserts the same
+//! bytes.
+
+use std::collections::BTreeMap;
+
+/// Decode bytes produced by the JVM `OptionsCodec.encode` (or
+/// [`encode_options`]). Empty input decodes as an empty map.
+///
+/// # Errors
+///
+/// Returns a description of the problem when the blob is truncated, carries
+/// a negative length, contains non-UTF-8 text, repeats a key, or has bytes
+/// left over after the declared number of entries.
+pub fn decode_options(bytes: &[u8]) -> Result<BTreeMap<String, String>, String> {
+    let mut out = BTreeMap::new();
+    if bytes.is_empty() {
+        return Ok(out);
+    }
+    let mut cursor = Cursor { bytes, pos: 0 };
+    let count = cursor.read_len("entry count")?;
+    for i in 0..count {
+        let key = cursor.read_string(&format!("key of entry {i}"))?;
+        let value = cursor.read_string(&format!("value of entry {i}"))?;
+        // An encoder working from a map can never emit the same key twice, so
+        // a repeat means the blob is malformed; reject it instead of silently
+        // keeping only the last value.
+        match out.entry(key) {
+            std::collections::btree_map::Entry::Vacant(slot) => {
+                slot.insert(value);
+            }
+            std::collections::btree_map::Entry::Occupied(slot) => {
+                return Err(format!(
+                    "duplicate option key '{}' at entry {i}",
+                    slot.key()
+                ));
+            }
+        }
+    }
+    if cursor.pos != bytes.len() {
+        return Err(format!(
+            "options blob has {} trailing byte(s) after {count} entries",
+            bytes.len() - cursor.pos
+        ));
+    }
+    Ok(out)
+}
+
+/// Encode in the same format (key-sorted via `BTreeMap`). Primarily for
+/// tests and Rust-side tooling; production encoding normally happens on the
+/// JVM driver.
+pub fn encode_options(options: &BTreeMap<String, String>) -> Vec<u8> {
+    // Lengths travel as big-endian i32. A value that does not fit cannot be
+    // represented in the wire format, so panic rather than let `as i32` wrap
+    // silently and produce a corrupt blob.
+    fn be_len(len: usize, what: &str) -> [u8; 4] {
+        i32::try_from(len)
+            .unwrap_or_else(|_| panic!("options {what} {len} exceeds i32::MAX"))
+            .to_be_bytes()
+    }
+
+    let mut out = Vec::new();
+    out.extend_from_slice(&be_len(options.len(), "entry count"));
+    for (key, value) in options {
+        out.extend_from_slice(&be_len(key.len(), "key length"));
+        out.extend_from_slice(key.as_bytes());
+        out.extend_from_slice(&be_len(value.len(), "value length"));
+        out.extend_from_slice(value.as_bytes());
+    }
+    out
+}
+
+/// Read position over the options blob; `pos` never exceeds `bytes.len()`.
+struct Cursor<'a> {
+    bytes: &'a [u8],
+    pos: usize,
+}
+
+impl Cursor<'_> {
+    /// Read one big-endian `i32` and return it as a non-negative length.
+    /// `what` names the field in error messages.
+    fn read_len(&mut self, what: &str) -> Result<usize, String> {
+        if self.bytes.len() - self.pos < 4 {
+            return Err(format!("options blob truncated reading {what}"));
+        }
+        let raw = i32::from_be_bytes(self.bytes[self.pos..self.pos + 4].try_into().unwrap());
+        self.pos += 4;
+        usize::try_from(raw).map_err(|_| format!("negative length for {what}: {raw}"))
+    }
+
+    /// Read a length-prefixed UTF-8 string. `what` names the field in error
+    /// messages.
+    fn read_string(&mut self, what: &str) -> Result<String, String> {
+        let len = self.read_len(&format!("length of {what}"))?;
+        if self.bytes.len() - self.pos < len {
+            return Err(format!("options blob truncated reading {what}"));
+        }
+        let slice = &self.bytes[self.pos..self.pos + len];
+        self.pos += len;
+        String::from_utf8(slice.to_vec()).map_err(|e| format!("{what} is not UTF-8: {e}"))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Shared fixture: must stay byte-identical to the one asserted by the
+    /// JVM-side `OptionsCodecTest`. {"table": "t1", "url": "grpc://h:1"}
+    /// encodes (sorted: table < url) as below.
+    fn fixture_bytes() -> Vec<u8> {
+        let mut b = Vec::new();
+        b.extend_from_slice(&2i32.to_be_bytes());
+        for (k, v) in [("table", "t1"), ("url", "grpc://h:1")] {
+            b.extend_from_slice(&(k.len() as i32).to_be_bytes());
+            b.extend_from_slice(k.as_bytes());
+            b.extend_from_slice(&(v.len() as i32).to_be_bytes());
+            b.extend_from_slice(v.as_bytes());
+        }
+        b
+    }
+
+    #[test]
+    fn decodes_fixture() {
+        let map = decode_options(&fixture_bytes()).unwrap();
+        assert_eq!(map.len(), 2);
+        assert_eq!(map.get("table").map(String::as_str), Some("t1"));
+        assert_eq!(map.get("url").map(String::as_str), Some("grpc://h:1"));
+    }
+
+    #[test]
+    fn encodes_fixture() {
+        let mut map = BTreeMap::new();
+        map.insert("url".to_string(), "grpc://h:1".to_string());
+        map.insert("table".to_string(), "t1".to_string());
+        assert_eq!(encode_options(&map), fixture_bytes());
+    }
+
+    #[test]
+    fn round_trips() {
+        let mut map = BTreeMap::new();
+        map.insert("b".to_string(), "2".to_string());
+        map.insert("a".to_string(), "1".to_string());
+        map.insert("unicode".to_string(), "héllo→world".to_string());
+        let bytes = encode_options(&map);
+        assert_eq!(decode_options(&bytes).unwrap(), map);
+    }
+
+    #[test]
+    fn empty_input_is_empty_map() {
+        assert!(decode_options(&[]).unwrap().is_empty());
+        let empty = encode_options(&BTreeMap::new());
+        assert!(decode_options(&empty).unwrap().is_empty());
+    }
+
+    #[test]
+    fn rejects_truncation_and_trailing_bytes() {
+        let bytes = fixture_bytes();
+        assert!(decode_options(&bytes[..bytes.len() - 1])
+            .unwrap_err()
+            .contains("truncated"));
+        let mut extended = bytes.clone();
+        extended.push(0);
+        assert!(decode_options(&extended).unwrap_err().contains("trailing"));
+    }
+}
diff --git a/spark/bridge/src/scan.rs b/spark/bridge/src/scan.rs
new file mode 100644
index 0000000..ad27dff
--- /dev/null
+++ b/spark/bridge/src/scan.rs
@@ -0,0 +1,325 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Planning and execution of a Spark scan. +//! +//! Every function here is the body of one JNI entry point generated by a +//! bridge's `export_bridge!` expansion, which supplies only how the provider +//! is obtained, as a `make` closure. The provider is wrapped in a +//! [`WideningTableProvider`] here, so every bridge gets identical +//! Spark-compatible Arrow types. +//! +//! [`create_scan`] registers the widened provider on a private +//! `SessionContext` built from the caller-pinned config, applies the pruned +//! projection and the proto-encoded pushed filters, and plans exactly once. +//! The returned handle supports: +//! +//! - [`partition_count`] — output partitions of the physical plan +//! (shared-scan mode probes this on the driver and indexes tasks by it); +//! - [`execute_stream_partition`] — an independent stream over ONE plan +//! partition, concurrently callable from multiple JVM threads +//! (`ExecutionPlan` and `TaskContext` are `Send + Sync`; each call only +//! clones their `Arc`s). Re-executing the same partition index (Spark +//! task retry / speculative execution) opens its own stream, but only +//! succeeds when every operator in that partition's pipeline supports +//! repeated `execute()` — stateless scans do, `RepartitionExec` +//! pipelines do not; +//! - [`execute_stream`] — the whole plan as one stream (per-partition +//! 
mode, where the provider itself is the task's slice);
+//! - [`close_scan`] — drop the plan. The single unsafe interleaving is
+//!   closing a handle that still has an in-flight call; the Java consumer
+//!   (the shared-scan cache) prevents it with a refcount covering every
+//!   open reader.
+//!
+//! Pinned-config determinism: the driver resolves `target_partitions` /
+//! `batch_size` / option overrides once and ships them to every executor, so
+//! a plan that yields N partitions on the driver yields N everywhere. This
+//! module applies whatever it is handed and stays policy-free.
+
+use std::sync::Arc;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::ffi_stream::FFI_ArrowArrayStream;
+use datafusion::arrow::ipc::writer::StreamWriter;
+use datafusion::catalog::TableProvider;
+use datafusion::dataframe::DataFrame;
+use datafusion::execution::TaskContext;
+use datafusion::physical_plan::{execute_stream as df_execute_stream, ExecutionPlan};
+use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult};
+use datafusion_jni_common::StreamingReader;
+use datafusion_proto::logical_plan::from_proto::parse_expr;
+use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
+use datafusion_proto::protobuf::LogicalExprNode;
+use jni::objects::{JByteArray, JObjectArray, JString};
+use jni::sys::{jbyteArray, jint, jlong};
+use jni::JNIEnv;
+use prost::Message;
+
+use crate::runtime_handle;
+use crate::widening::WideningTableProvider;
+
+/// Registration name of the (single) provider on the scan's private context.
+/// Never surfaces in SQL — the plan is built through the DataFrame API — so
+/// no quoting/collision concerns.
+const SCAN_TABLE_NAME: &str = "df_spark_scan";
+
+/// Everything a scan handle owns; boxed and handed to the JVM as a `jlong`.
+struct ScanState {
+    /// Kept alive for the plan's lifetime; the registered provider and the
+    /// runtime env both hang off it.
+    _ctx: SessionContext,
+    plan: Arc<dyn ExecutionPlan>,
+    task_ctx: Arc<TaskContext>,
+}
+
+/// Wrap `provider` in a [`WideningTableProvider`].
+fn widen(provider: Arc<dyn TableProvider>) -> Arc<dyn TableProvider> {
+    Arc::new(WideningTableProvider::new(provider))
+}
+
+/// Copy a Java `String[]` into owned Rust strings. A null array reference
+/// yields an empty vector.
+fn collect_string_array(env: &mut JNIEnv, arr: &JObjectArray) -> JniResult<Vec<String>> {
+    if arr.is_null() {
+        return Ok(Vec::new());
+    }
+    let len = env.get_array_length(arr)?;
+    let mut owned: Vec<String> = Vec::with_capacity(len as usize);
+    for i in 0..len {
+        let elem = env.get_object_array_element(arr, i)?;
+        // AutoLocal deletes the element's local reference when it drops at
+        // the end of the iteration, so a long array does not pile up local
+        // references until the native call returns.
+        let jstr = env.auto_local(JString::from(elem));
+        owned.push(env.get_string(&jstr)?.into());
+    }
+    Ok(owned)
+}
+
+/// Copy a Java `byte[][]` into owned byte vectors. A null array reference
+/// yields an empty vector.
+fn collect_byte_arrays(env: &mut JNIEnv, arr: &JObjectArray) -> JniResult<Vec<Vec<u8>>> {
+    if arr.is_null() {
+        return Ok(Vec::new());
+    }
+    let len = env.get_array_length(arr)?;
+    let mut owned: Vec<Vec<u8>> = Vec::with_capacity(len as usize);
+    for i in 0..len {
+        let elem = env.get_object_array_element(arr, i)?;
+        // Released per iteration, as in `collect_string_array`.
+        let bytes = env.auto_local(JByteArray::from(elem));
+        owned.push(env.convert_byte_array(&*bytes)?);
+    }
+    Ok(owned)
+}
+
+/// Driver-side schema probe: widened Arrow schema of the provider, as IPC
+/// bytes (deserialized JVM-side with `MessageSerializer.deserializeSchema`).
+/// `make` runs once; the provider drops before returning.
+pub fn provider_schema_ipc(
+    env: &mut JNIEnv,
+    make: impl FnOnce(&mut JNIEnv) -> JniResult<Arc<dyn TableProvider>>,
+) -> jbyteArray {
+    try_unwrap_or_throw(env, std::ptr::null_mut(), |env| -> JniResult<jbyteArray> {
+        let widened = widen(make(env)?);
+        let schema = widened.schema();
+        let mut buf: Vec<u8> = Vec::new();
+        {
+            // Scoped so the writer's mutable borrow of `buf` ends before the
+            // bytes are copied into the Java array.
+            let mut writer = StreamWriter::try_new(&mut buf, schema.as_ref())?;
+            writer.finish()?;
+        }
+        let arr = env.byte_array_from_slice(&buf)?;
+        Ok(arr.into_raw())
+    })
+}
+
+/// Build the scan: widen the provider from `make`, register it on a private
+/// context with the pinned config, apply projection + pushed filters, plan
+/// once.
+///
+/// `target_partitions` / `batch_size` <= 0 leave the DataFusion defaults;
+/// `option_keys`/`option_values` are parallel arrays of config overrides;
+/// empty `projection_columns` selects all columns; each element of
+/// `filter_protos` is a serialized `datafusion.LogicalExprNode`.
+///
+/// Returns the scan handle (a boxed `ScanState` pointer) on success. On
+/// failure the error goes through `try_unwrap_or_throw` and `0` is returned.
+#[allow(clippy::too_many_arguments)]
+pub fn create_scan(
+    env: &mut JNIEnv,
+    make: impl FnOnce(&mut JNIEnv) -> JniResult<Arc<dyn TableProvider>>,
+    target_partitions: jint,
+    batch_size: jint,
+    option_keys: &JObjectArray,
+    option_values: &JObjectArray,
+    projection_columns: &JObjectArray,
+    filter_protos: &JObjectArray,
+) -> jlong {
+    try_unwrap_or_throw(env, 0, |env| -> JniResult<jlong> {
+        let widened = widen(make(env)?);
+
+        let keys = collect_string_array(env, option_keys)?;
+        let values = collect_string_array(env, option_values)?;
+        if keys.len() != values.len() {
+            return Err(format!(
+                "option key/value arrays differ in length: {} vs {}",
+                keys.len(),
+                values.len()
+            )
+            .into());
+        }
+        let projection = collect_string_array(env, projection_columns)?;
+        let filters = collect_byte_arrays(env, filter_protos)?;
+
+        let mut config = SessionConfig::new();
+        if target_partitions > 0 {
+            config = config.with_target_partitions(target_partitions as usize);
+        }
+        if batch_size > 0 {
+            config = config.with_batch_size(batch_size as usize);
+        }
+        for (key, value) in keys.iter().zip(values.iter()) {
+            config.options_mut().set(key, value)?;
+        }
+
+        let ctx = SessionContext::new_with_config(config);
+        ctx.register_table(SCAN_TABLE_NAME, widened)?;
+
+        let mut df: DataFrame = runtime_handle().block_on(ctx.table(SCAN_TABLE_NAME))?;
+        if !projection.is_empty() {
+            let refs: Vec<&str> = projection.iter().map(String::as_str).collect();
+            df = df.select_columns(&refs)?;
+        }
+        for bytes in &filters {
+            let node = LogicalExprNode::decode(bytes.as_slice())?;
+            // TaskContext implements FunctionRegistry; the default codec is
+            // enough because the translator only emits column/literal/builtin
+            // expressions.
+            let registry = df.task_ctx();
+            let expr = parse_expr(&node, &registry, &DefaultLogicalExtensionCodec {})?;
+            df = df.filter(expr)?;
+        }
+
+        // task_ctx() borrows; capture before create_physical_plan consumes df.
+        let task_ctx = Arc::new(df.task_ctx());
+        let plan = runtime_handle().block_on(df.create_physical_plan())?;
+
+        let state = ScanState {
+            _ctx: ctx,
+            plan,
+            task_ctx,
+        };
+        // Ownership passes to the JVM as an opaque handle; `close_scan`
+        // reclaims it with `Box::from_raw`.
+        Ok(Box::into_raw(Box::new(state)) as jlong)
+    })
+}
+
+/// Output partition count of the planned physical plan.
+pub fn partition_count(env: &mut JNIEnv, handle: jlong) -> jint {
+    try_unwrap_or_throw(env, 0, |_env| -> JniResult<jint> {
+        if handle == 0 {
+            return Err("scan handle is null".into());
+        }
+        let state = unsafe { &*(handle as *const ScanState) };
+        Ok(state
+            .plan
+            .properties()
+            .output_partitioning()
+            .partition_count() as jint)
+    })
+}
+
+/// Open an independent stream over one plan partition, writing an
+/// `FFI_ArrowArrayStream` into the caller-allocated struct at
+/// `ffi_stream_addr`.
+pub fn execute_stream_partition(
+    env: &mut JNIEnv,
+    handle: jlong,
+    partition: jint,
+    ffi_stream_addr: jlong,
+) {
+    try_unwrap_or_throw(env, (), |_env| -> JniResult<()> {
+        if handle == 0 {
+            return Err("scan handle is null".into());
+        }
+        if ffi_stream_addr == 0 {
+            return Err("ffi stream address is null".into());
+        }
+        let state = unsafe { &*(handle as *const ScanState) };
+
+        let partition_count = state
+            .plan
+            .properties()
+            .output_partitioning()
+            .partition_count();
+        if partition < 0 || partition as usize >= partition_count {
+            return Err(format!(
+                "partition index {partition} out of range: plan has {partition_count} partition(s)"
+            )
+            .into());
+        }
+
+        let plan = Arc::clone(&state.plan);
+        let task_ctx = Arc::clone(&state.task_ctx);
+        let schema: SchemaRef = plan.schema();
+
+        // ExecutionPlan::execute is synchronous, but operators may
+        // tokio::spawn at execute() time (RepartitionExec et al.), which
+        // requires a runtime context to be entered.
+        let stream = {
+            let _guard = runtime_handle().enter();
+            plan.execute(partition as usize, task_ctx)?
+        };
+
+        let reader = StreamingReader { schema, stream };
+        let ffi = FFI_ArrowArrayStream::new(Box::new(reader));
+        // ptr::write overwrites the destination without dropping whatever
+        // bytes were there before.
+        unsafe {
+            std::ptr::write(ffi_stream_addr as *mut FFI_ArrowArrayStream, ffi);
+        }
+        Ok(())
+    })
+}
+
+/// Whole-plan stream for per-partition mode (the provider
+/// itself is the task's slice, so all plan partitions merge into one reader).
+pub fn execute_stream(env: &mut JNIEnv, handle: jlong, ffi_stream_addr: jlong) {
+    try_unwrap_or_throw(env, (), |_env| -> JniResult<()> {
+        if handle == 0 {
+            return Err("scan handle is null".into());
+        }
+        if ffi_stream_addr == 0 {
+            return Err("ffi stream address is null".into());
+        }
+        let state = unsafe { &*(handle as *const ScanState) };
+
+        let plan = Arc::clone(&state.plan);
+        let task_ctx = Arc::clone(&state.task_ctx);
+        let schema: SchemaRef = plan.schema();
+
+        // execute_stream coalesces multi-partition plans behind one stream.
+        let stream = {
+            let _guard = runtime_handle().enter();
+            df_execute_stream(plan, task_ctx)?
+        };
+
+        let reader = StreamingReader { schema, stream };
+        let ffi = FFI_ArrowArrayStream::new(Box::new(reader));
+        unsafe {
+            std::ptr::write(ffi_stream_addr as *mut FFI_ArrowArrayStream, ffi);
+        }
+        Ok(())
+    })
+}
+
+/// Drop the planned scan. Must not race an in-flight stream-open on the same
+/// handle; the Java consumer's refcount enforces this.
+pub fn close_scan(env: &mut JNIEnv, handle: jlong) {
+    try_unwrap_or_throw(env, (), |_env| -> JniResult<()> {
+        // A zero handle is reported as an error rather than treated as a
+        // no-op.
+        if handle == 0 {
+            return Err("scan handle is null".into());
+        }
+        // Reclaims the `ScanState` box that `create_scan` leaked with
+        // `Box::into_raw`. Dropping it releases the plan, the task context
+        // and the session context. The handle must not be used afterwards.
+        drop(unsafe { Box::from_raw(handle as *mut ScanState) });
+        Ok(())
+    })
+}
diff --git a/spark/bridge/src/widening.rs b/spark/bridge/src/widening.rs
new file mode 100644
index 0000000..86c4abf
--- /dev/null
+++ b/spark/bridge/src/widening.rs
@@ -0,0 +1,376 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Kernel-level Arrow type widening for Spark consumption.
+//!
+//! Spark 3.5's `ArrowColumnVector` has no accessor for unsigned ints, Time*,
+//! Float16, or non-microsecond Timestamp. The widening machinery here wraps
+//! an inner `TableProvider` with one that exposes a "widened" schema —
+//! UInt*→Int wider, Float16→Float32, Time*→Int wider, Timestamp(*, tz)→
+//! Timestamp(Microsecond, tz), recursing into List/LargeList/FixedSizeList
+//! children — and applies `arrow::compute::cast` to each produced
+//! RecordBatch column-wise. No SQL, no SessionContext, no view machinery.
+
+use std::any::Any;
+use std::fmt;
+use std::sync::Arc;
+
+use arrow::array::RecordBatch;
+use arrow::compute::cast;
+use arrow::datatypes::{DataType, Field, Schema as ArrowSchema, SchemaRef, TimeUnit};
+use async_trait::async_trait;
+use datafusion::catalog::{Session, TableProvider};
+use datafusion::common::{DataFusionError, Result};
+use datafusion::execution::TaskContext;
+use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType};
+use datafusion::physical_expr::EquivalenceProperties;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion::physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
+};
+use futures::stream::StreamExt;
+
+/// Compute the cast-target DataType for an Arrow type not directly readable
+/// by Spark's `ArrowColumnVector`. Returns `None` if the type passes through.
+///
+/// Only `List`, `LargeList` and `FixedSizeList` are recursed into; every
+/// other nested type (Struct, Map, …) passes through unchanged.
+pub fn arrow_cast_widening(dt: &DataType) -> Option<DataType> {
+    match dt {
+        DataType::UInt8 => Some(DataType::Int16),
+        DataType::UInt16 => Some(DataType::Int32),
+        DataType::UInt32 => Some(DataType::Int64),
+        // UInt64 → Int64: lossy for values ≥ 2^63. Documented in REARCHITECTURE.md.
+        DataType::UInt64 => Some(DataType::Int64),
+        DataType::Float16 => Some(DataType::Float32),
+        DataType::Time32(_) => Some(DataType::Int32),
+        DataType::Time64(_) => Some(DataType::Int64),
+        DataType::Timestamp(unit, tz) => {
+            if *unit == TimeUnit::Microsecond {
+                None
+            } else {
+                Some(DataType::Timestamp(TimeUnit::Microsecond, tz.clone()))
+            }
+        }
+        DataType::List(field) => arrow_cast_widening(field.data_type())
+            .map(|inner| DataType::List(widened_child(field, inner))),
+        DataType::LargeList(field) => arrow_cast_widening(field.data_type())
+            .map(|inner| DataType::LargeList(widened_child(field, inner))),
+        // Spark 3.5's ArrowColumnVector cannot read FixedSizeList at all, so
+        // always convert it to a (variable) List — which Spark maps to
+        // ArrayType — widening the child element type when needed too.
+        DataType::FixedSizeList(field, _size) => {
+            let child = match arrow_cast_widening(field.data_type()) {
+                Some(inner) => widened_child(field, inner),
+                None => Arc::clone(field),
+            };
+            Some(DataType::List(child))
+        }
+        _ => None,
+    }
+}
+
+/// Copy of `field` with its data type replaced by `new_type`. Cloning the
+/// field keeps its name, nullability and metadata; building a fresh `Field`
+/// would drop the metadata.
+fn widened_child(field: &Arc<Field>, new_type: DataType) -> Arc<Field> {
+    Arc::new(field.as_ref().clone().with_data_type(new_type))
+}
+
+/// Build the widened schema by walking inner fields and replacing types.
+/// Returns the widened schema plus per-column target types (None where no cast).
+fn widened_schema(inner: &ArrowSchema) -> (SchemaRef, Vec<Option<DataType>>) {
+    let mut fields = Vec::with_capacity(inner.fields().len());
+    let mut targets = Vec::with_capacity(inner.fields().len());
+    for f in inner.fields() {
+        match arrow_cast_widening(f.data_type()) {
+            Some(target) => {
+                // Clone the field so its metadata survives the type change,
+                // matching the pass-through branch below.
+                fields.push(Arc::new(
+                    f.as_ref().clone().with_data_type(target.clone()),
+                ));
+                targets.push(Some(target));
+            }
+            None => {
+                fields.push(Arc::clone(f));
+                targets.push(None);
+            }
+        }
+    }
+    (Arc::new(ArrowSchema::new(fields)), targets)
+}
+
+/// TableProvider wrapping an inner provider, exposing a widened schema and
+/// emitting RecordBatches whose columns have been cast to the widened types.
+#[derive(Debug)]
+pub struct WideningTableProvider {
+    inner: Arc<dyn TableProvider>,
+    widened: SchemaRef,
+    /// Targets indexed by the inner-schema column position; `None` = pass through.
+    targets: Vec<Option<DataType>>,
+}
+
+impl WideningTableProvider {
+    /// Wrap `inner`, computing the widened schema once from its current schema.
+    pub fn new(inner: Arc<dyn TableProvider>) -> Self {
+        let (widened, targets) = widened_schema(&inner.schema());
+        Self {
+            inner,
+            widened,
+            targets,
+        }
+    }
+}
+
+#[async_trait]
+impl TableProvider for WideningTableProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.widened)
+    }
+
+    fn table_type(&self) -> TableType {
+        self.inner.table_type()
+    }
+
+    // NOTE(review): the filters are forwarded unchanged to the inner
+    // provider, both here and in `scan`. Confirm the inner provider copes
+    // with expressions planned against the widened column types.
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> Result<Vec<TableProviderFilterPushDown>> {
+        self.inner.supports_filters_pushdown(filters)
+    }
+
+    async fn scan(
+        &self,
+        session: &dyn Session,
+        projection: Option<&Vec<usize>>,
+        filters: &[Expr],
+        limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let inner_plan = self.inner.scan(session, projection, filters, limit).await?;
+        // Narrow the widened schema and cast targets to the projected
+        // columns, in projection order.
+        let (projected_widened, projected_targets) = match projection {
+            Some(idxs) => {
+                let fields: Vec<Arc<Field>> = idxs
+                    .iter()
+                    .map(|i| Arc::clone(&self.widened.fields()[*i]))
+                    .collect();
+                let targets: Vec<Option<DataType>> =
+                    idxs.iter().map(|i| self.targets[*i].clone()).collect();
+                (Arc::new(ArrowSchema::new(fields)) as SchemaRef, targets)
+            }
+            None => (Arc::clone(&self.widened), self.targets.clone()),
+        };
+        Ok(Arc::new(WideningExec::new(
+            inner_plan,
+            projected_widened,
+            projected_targets,
+        )))
+    }
+}
+
+/// ExecutionPlan that runs the inner plan and casts each output RecordBatch
+/// column-wise per the supplied targets. Pure stream-map wrapper; no
+/// buffering, no internal state.
+pub struct WideningExec {
+    inner: Arc<dyn ExecutionPlan>,
+    schema: SchemaRef,
+    /// One entry per output column; `None` = pass through.
+    targets: Vec<Option<DataType>>,
+    properties: Arc<PlanProperties>,
+}
+
+impl WideningExec {
+    /// Partitioning, emission type and boundedness are copied from `inner`.
+    /// Equivalence properties start fresh over the widened schema.
+    fn new(
+        inner: Arc<dyn ExecutionPlan>,
+        schema: SchemaRef,
+        targets: Vec<Option<DataType>>,
+    ) -> Self {
+        let inner_props = inner.properties();
+        let properties = Arc::new(PlanProperties::new(
+            EquivalenceProperties::new(Arc::clone(&schema)),
+            inner_props.partitioning.clone(),
+            inner_props.emission_type,
+            inner_props.boundedness,
+        ));
+        Self {
+            inner,
+            schema,
+            targets,
+            properties,
+        }
+    }
+}
+
+impl fmt::Debug for WideningExec {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("WideningExec")
+            .field("schema", &self.schema)
+            .field("targets", &self.targets)
+            .finish()
+    }
+}
+
+impl DisplayAs for WideningExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let cast_count = self.targets.iter().filter(|t| t.is_some()).count();
+        write!(f, "WideningExec: casts={cast_count}")
+    }
+}
+
+impl ExecutionPlan for WideningExec {
+    fn name(&self) -> &str {
+        "WideningExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &Arc<PlanProperties> {
+        &self.properties
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.inner]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        if children.len() != 1 {
+            return Err(DataFusionError::Internal(
+                "WideningExec::with_new_children expects exactly one child".to_string(),
+            ));
+        }
+        Ok(Arc::new(WideningExec::new(
+            children.into_iter().next().unwrap(),
+            Arc::clone(&self.schema),
+            self.targets.clone(),
+        )))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let inner_stream = self.inner.execute(partition, context)?;
+        let schema = Arc::clone(&self.schema);
+        let targets = self.targets.clone();
+        let mapped = inner_stream.map(move |batch_res| match batch_res {
+            Err(e) => Err(e),
+            Ok(batch) => cast_batch(&batch, &schema, &targets),
+        });
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            self.schema.clone(),
+            mapped,
+        )))
+    }
+}
+
+/// Cast the columns of `batch` to `targets` (`None` = keep the column as is)
+/// and rebuild the batch under `out_schema`.
+fn cast_batch(
+    batch: &RecordBatch,
+    out_schema: &SchemaRef,
+    targets: &[Option<DataType>],
+) -> Result<RecordBatch> {
+    if batch.num_columns() != targets.len() {
+        return Err(DataFusionError::Internal(format!(
+            "WideningExec: produced batch has {} columns, expected {}",
+            batch.num_columns(),
+            targets.len()
+        )));
+    }
+    let mut new_cols = Vec::with_capacity(batch.num_columns());
+    for (col, target) in batch.columns().iter().zip(targets.iter()) {
+        match target {
+            Some(t) => new_cols.push(cast(col, t).map_err(DataFusionError::from)?),
+            None => new_cols.push(Arc::clone(col)),
+        }
+    }
+    // Carry the row count explicitly: a batch with zero columns (empty
+    // projection or empty schema) has no column to infer it from, and
+    // `RecordBatch::try_new` rejects that case.
+    let options =
+        arrow::record_batch::RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+    RecordBatch::try_new_with_options(Arc::clone(out_schema), new_cols, &options)
+        .map_err(DataFusionError::from)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn unsigned_ints_widen_to_signed_wider() {
+        assert_eq!(arrow_cast_widening(&DataType::UInt8), Some(DataType::Int16));
+        assert_eq!(
+            arrow_cast_widening(&DataType::UInt16),
+            Some(DataType::Int32)
+        );
+        assert_eq!(
+            arrow_cast_widening(&DataType::UInt32),
+            Some(DataType::Int64)
+        );
+        assert_eq!(
+            arrow_cast_widening(&DataType::UInt64),
+            Some(DataType::Int64)
+        );
+    }
+
+    #[test]
+    fn float16_widens_to_float32() {
+        assert_eq!(
+            arrow_cast_widening(&DataType::Float16),
+            Some(DataType::Float32)
+        );
+    }
+
+    #[test]
+    fn time_widens_to_int() {
+        assert_eq!(
+            arrow_cast_widening(&DataType::Time32(TimeUnit::Millisecond)),
+            Some(DataType::Int32)
+        );
+        assert_eq!(
+            arrow_cast_widening(&DataType::Time64(TimeUnit::Nanosecond)),
+            Some(DataType::Int64)
+        );
+    }
+
+    #[test]
+    fn timestamp_normalizes_unit_preserving_tz() {
+        let ns = DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("UTC")));
+        assert_eq!(
+            arrow_cast_widening(&ns),
+            Some(DataType::Timestamp(
+                TimeUnit::Microsecond,
+                Some(Arc::from("UTC"))
+            ))
+        );
+        let us_no_tz = DataType::Timestamp(TimeUnit::Microsecond, None);
+        assert_eq!(arrow_cast_widening(&us_no_tz), None);
+    }
+
+    #[test]
+    fn list_recurses_into_children() {
+        let inner_field = Arc::new(Field::new("item", DataType::UInt16, true));
+        let list_ty = DataType::List(inner_field);
+        let widened = arrow_cast_widening(&list_ty).expect("should widen");
+        match widened {
+            DataType::List(field) => assert_eq!(field.data_type(), &DataType::Int32),
+            other => panic!("expected List, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn signed_int_passes_through() {
+        assert_eq!(arrow_cast_widening(&DataType::Int32), None);
+        assert_eq!(arrow_cast_widening(&DataType::Utf8), None);
+    }
+
+    #[test]
+    fn widened_fields_keep_metadata() {
+        let md = std::collections::HashMap::from([("k".to_string(), "v".to_string())]);
+        let schema = ArrowSchema::new(vec![
+            Field::new("c", DataType::UInt8, false).with_metadata(md.clone())
+        ]);
+        let (widened, targets) = widened_schema(&schema);
+        assert_eq!(widened.field(0).data_type(), &DataType::Int16);
+        assert_eq!(widened.field(0).metadata(), &md);
+        assert_eq!(targets, vec![Some(DataType::Int16)]);
+    }
+
+    #[test]
+    fn cast_batch_keeps_row_count_without_columns() {
+        let schema: SchemaRef = Arc::new(ArrowSchema::empty());
+        let options = arrow::record_batch::RecordBatchOptions::new().with_row_count(Some(3));
+        let batch =
+            RecordBatch::try_new_with_options(Arc::clone(&schema), vec![], &options).unwrap();
+        let out = cast_batch(&batch, &schema, &[]).unwrap();
+        assert_eq!(out.num_columns(), 0);
+        assert_eq!(out.num_rows(), 3);
+    }
+}
diff --git a/spark/bridge/tests/export_macro.rs b/spark/bridge/tests/export_macro.rs
new file mode 100644
index 0000000..14751c8
--- /dev/null
+++ b/spark/bridge/tests/export_macro.rs
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//!
Compile-level test of `export_bridge!`: the macro must expand to valid +//! `extern "system"` items against a plain builder function. JNI entry +//! points can't be exercised without a live JVM, so the assertion here is +//! that this test crate links with the generated symbols present. + +use std::sync::Arc; + +use datafusion_spark_bridge::datafusion::arrow::datatypes::Schema; +use datafusion_spark_bridge::datafusion::catalog::TableProvider; +use datafusion_spark_bridge::datafusion::datasource::MemTable; +use datafusion_spark_bridge::{export_bridge, BridgeContext, JniResult}; + +fn build_provider( + _ctx: &BridgeContext, + _options: &[u8], + _partition: &[u8], +) -> JniResult> { + let schema = Arc::new(Schema::empty()); + let table = MemTable::try_new(schema, vec![vec![]])?; + Ok(Arc::new(table)) +} + +export_bridge! { + jni_class: "com_example_testbridge_BridgeNative", + build_provider: build_provider, +} + +#[test] +fn builder_contract_runs_outside_jvm() { + // Expansion + linking is the macro test; this additionally runs the + // builder through the same BridgeContext the expansion hands it. + let ctx = BridgeContext::get(); + let provider = build_provider(&ctx, &[], &[]).expect("builder failed"); + assert_eq!(provider.schema().fields().len(), 0); +} diff --git a/spark/pom.xml b/spark/pom.xml new file mode 100644 index 0000000..90e4e6d --- /dev/null +++ b/spark/pom.xml @@ -0,0 +1,150 @@ + + + + 4.0.0 + + + org.apache.datafusion + datafusion-java-parent + 0.2.0-SNAPSHOT + + + datafusion-java-spark_2.13 + jar + + Apache DataFusion Java Spark Connector + + Generic Spark DataSource V2 connector for DataFusion TableProviders. + Domain bridges implement BridgeProviderFactory over a cdylib built + with the datafusion-spark-bridge Rust SDK; this module supplies the + Spark plumbing, predicate translation, Arrow-to-Spark schema + conversion, and the shared-scan cache. Pure JVM artifact — the + native code ships inside each bridge's own jar. 
+ + + + 2.13 + 2.13.14 + 3.5.7 + + + + + org.scala-lang + scala-library + ${scala.version} + + + + org.apache.spark + spark-core_${scala.compat.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql_${scala.compat.version} + ${spark.version} + provided + + + + org.apache.datafusion + datafusion-java + + + + org.apache.arrow + arrow-vector + + + org.apache.arrow + arrow-c-data + + + org.apache.arrow + arrow-memory-netty + runtime + + + + org.scalatest + scalatest_${scala.compat.version} + 3.2.18 + test + + + + + + + net.alchim31.maven + scala-maven-plugin + 4.8.1 + + + + compile + testCompile + + + + + ${scala.version} + + -deprecation + -feature + -unchecked + + all + + + + org.apache.maven.plugins + maven-surefire-plugin + + --add-opens=java.base/java.nio=ALL-UNNAMED + true + + + + org.scalatest + scalatest-maven-plugin + 2.2.0 + + ${project.build.directory}/scalatest-reports + . + WDF TestSuite.txt + --add-opens=java.base/java.nio=ALL-UNNAMED + + + + test + test + + + + + + + diff --git a/spark/scaffold/bridge-template/.gitignore b/spark/scaffold/bridge-template/.gitignore new file mode 100644 index 0000000..e2777a5 --- /dev/null +++ b/spark/scaffold/bridge-template/.gitignore @@ -0,0 +1,3 @@ +target/ +native/target/ +*.class diff --git a/spark/scaffold/bridge-template/README.md b/spark/scaffold/bridge-template/README.md new file mode 100644 index 0000000..8259e53 --- /dev/null +++ b/spark/scaffold/bridge-template/README.md @@ -0,0 +1,54 @@ +# __PREFIX__ Spark Bridge + +A Spark DataSource V2 connector for the `__FORMAT__` format, built on the +[datafusion-java Spark connector](https://github.com/apache/datafusion-java) +and its `datafusion-spark-bridge` Rust SDK. Generated by `spark/scaffold/new_bridge.py`; +the only code you need to touch is marked `TODO`. 
+ +## What's here + +| File | Role | +| --- | --- | +| `native/src/lib.rs` | **Your provider.** `build_provider` turns option/partition bytes into a DataFusion `TableProvider` (demo: an in-memory table). `export_bridge!` generates the whole JNI surface. | +| `src/main/java/.../BridgeNative.java` | Declares the generated native methods and loads the bundled cdylib. Must keep the name/package the Rust macro was generated with. | +| `src/main/java/.../__PREFIX__ScanBackend.java` | Routes the connector's scan calls to `BridgeNative`. Pure delegation. | +| `src/main/java/.../__PREFIX__ProviderFactory.java` | The connector contract. Override `listPartitions` / `sharedScan` / `encodeOptions` here as the bridge grows. | +| `src/main/java/.../__PREFIX__DataSource.java` + `META-INF/services/...` | `spark.read.format("__FORMAT__")`. | +| `pom.xml` | One shaded fat jar with the cdylib bundled inside. | + +## Build + +```bash +# 0. Once: install datafusion-java to your local Maven repo (from its checkout): +# cargo build && ./mvnw install -DskipTests + +# 1. The cdylib: +cargo build --manifest-path native/Cargo.toml + +# 2. The shaded jar (target/__CRATE__-0.1.0-SNAPSHOT.jar): +mvn package +``` + +Release builds: `cargo build --release --manifest-path native/Cargo.toml` and +`mvn package -Dnative.profile=release`. + +## Use + +```python +df = (spark.read.format("__FORMAT__") + .option("rows", "5") # demo option; replace with your own + .load()) +df.show() +``` + +with the shaded jar on `spark.jars`. `python3 smoke_test.py` runs exactly this +against a local Spark (needs `SPARK_HOME` pointing at a Scala 2.13 distro). + +## Where to go next + +- Replace the demo `MemTable` in `native/src/lib.rs` with your real provider. +- Split the dataset into Spark tasks (`listPartitions`) or switch to + shared-scan mode (`sharedScan`) — task-sizing guidance lives in the + connector's `spark/README.md` ("Spark tasks vs. DataFusion partitions"). 
+- Multi-platform jars: build the cdylib per platform in CI and copy each into + `src`-side `<os>/<arch>/` directories before `mvn package`. diff --git a/spark/scaffold/bridge-template/native/Cargo.toml b/spark/scaffold/bridge-template/native/Cargo.toml new file mode 100644 index 0000000..c0d2996 --- /dev/null +++ b/spark/scaffold/bridge-template/native/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "__CRATE__-native" +version = "0.1.0" +edition = "2021" +publish = false + +# Standalone crate: the empty [workspace] table stops cargo from adopting +# this crate into any workspace found in a parent directory. +[workspace] + +[lib] +name = "__LIB__" +crate-type = ["cdylib"] + +[dependencies] +# TODO: replace the path with a git or crates.io dependency once you build +# outside a local datafusion-java checkout. +datafusion-spark-bridge = { path = "__BRIDGE_SDK_PATH__" } + +[profile.release] +strip = "debuginfo" diff --git a/spark/scaffold/bridge-template/native/src/lib.rs b/spark/scaffold/bridge-template/native/src/lib.rs new file mode 100644 index 0000000..8439217 --- /dev/null +++ b/spark/scaffold/bridge-template/native/src/lib.rs @@ -0,0 +1,59 @@ +//! Native side of the `__FORMAT__` Spark bridge. +//! +//! `export_bridge!` generates the whole JNI surface for +//! `__PKG__.BridgeNative`; the only code you own is [`build_provider`], +//! which turns the option/partition bytes your JVM factory encoded into a +//! concrete `TableProvider`. Everything downstream — type widening, session +//! construction, projection, pushed filters, planning, partition streams — +//! is the SDK's job.
+ +use std::sync::Arc; + +use datafusion_spark_bridge::datafusion::arrow::array::{Int64Array, StringArray}; +use datafusion_spark_bridge::datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion_spark_bridge::datafusion::arrow::record_batch::RecordBatch; +use datafusion_spark_bridge::datafusion::catalog::TableProvider; +use datafusion_spark_bridge::datafusion::datasource::MemTable; +use datafusion_spark_bridge::options::decode_options; +use datafusion_spark_bridge::{export_bridge, BridgeContext, JniResult}; + +/// Build the provider for one scan. +/// +/// `options` is whatever the JVM factory's `encodeOptions` produced — with +/// the default factory that is the connector's `OptionsCodec` format, decoded +/// below into a string map. `partition` is the per-task payload from +/// `listPartitions` (empty for the schema probe, for shared-scan mode, and +/// for the default single-partition layout). +/// +/// TODO: replace the demo `MemTable` with your real `TableProvider`. For +/// async construction (remote catalogs, object stores), use +/// `ctx.block_on(...)`. +fn build_provider( + _ctx: &BridgeContext, + options: &[u8], + _partition: &[u8], +) -> JniResult> { + let opts = decode_options(options)?; + let rows: i64 = match opts.get("rows") { + Some(v) => v + .parse() + .map_err(|e| format!("option 'rows' is not an integer: {e}"))?, + None => 3, + }; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("greeting", DataType::Utf8, false), + ])); + let ids = Int64Array::from_iter_values(0..rows); + let greetings = + StringArray::from_iter_values((0..rows).map(|i| format!("hello from __FORMAT__ #{i}"))); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(greetings)])?; + + Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?)) +} + +export_bridge! 
{ + jni_class: "__JNI_CLASS__", + build_provider: build_provider, +} diff --git a/spark/scaffold/bridge-template/pom.xml b/spark/scaffold/bridge-template/pom.xml new file mode 100644 index 0000000..4e8c2b6 --- /dev/null +++ b/spark/scaffold/bridge-template/pom.xml @@ -0,0 +1,174 @@ + + + 4.0.0 + + __PKG__ + __CRATE__ + 0.1.0-SNAPSHOT + jar + + __PREFIX__ Spark Bridge + + + UTF-8 + 17 + 2.13 + 3.5.7 + __DF_JAVA_VERSION__ + + debug + + + + + + org.apache.datafusion + datafusion-java-spark_${scala.compat.version} + ${datafusion.java.version} + + + org.apache.spark + spark-sql_${scala.compat.version} + ${spark.version} + provided + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + 3.1.0 + + + copy-bridge-cdylib + process-classes + run + + + + + + + + + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.6.0 + + + package + shade + + false + + + + org.scala-lang:scala-library + + + + + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + + + + + native-linux-amd64 + + unixlinuxamd64 + + + linux + x86_64 + lib__LIB__.so + + + + native-linux-x86_64 + + unixlinuxx86_64 + + + linux + x86_64 + lib__LIB__.so + + + + native-linux-aarch64 + + unixlinuxaarch64 + + + linux + aarch64 + lib__LIB__.so + + + + native-mac-x86_64 + + macx86_64 + + + darwin + x86_64 + lib__LIB__.dylib + + + + native-mac-aarch64 + + macaarch64 + + + darwin + aarch64 + lib__LIB__.dylib + + + + diff --git a/spark/scaffold/bridge-template/smoke_test.py b/spark/scaffold/bridge-template/smoke_test.py new file mode 100644 index 0000000..ca3925e --- /dev/null +++ b/spark/scaffold/bridge-template/smoke_test.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +"""Smoke test: scan the __FORMAT__ bridge's demo table through PySpark. + +Prerequisites: + - cargo build --manifest-path native/Cargo.toml (the bridge cdylib) + - mvn package (the shaded jar) + - a Scala 2.13 Spark distribution; the PyPI pyspark wheel embeds 2.12, so + point SPARK_HOME at e.g. 
spark-3.5.7-bin-hadoop3-scala2.13. + +Run: python3 smoke_test.py +""" + +import glob +import os +import sys +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parent + +spark_home = os.environ.get("SPARK_HOME") +if not spark_home or not Path(spark_home, "jars").is_dir(): + sys.exit("Set SPARK_HOME to a Scala 2.13 Spark distribution.") +os.environ["SPARK_HOME"] = spark_home + +jars = glob.glob(str(PROJECT_ROOT / "target" / "__CRATE__-*.jar")) +jars = [j for j in jars if not j.endswith(("-sources.jar", "-javadoc.jar"))] +if not jars: + sys.exit("Shaded jar not found under target/. Run 'mvn package' first.") +jar = jars[0] + +from pyspark.sql import SparkSession # noqa: E402 + +spark = ( + SparkSession.builder.appName("__FORMAT__-smoke") + .master("local[2]") + .config("spark.jars", jar) + # extraClassPath PREPENDS, so the fat jar's Arrow wins over Spark's + # bundled (older) copy on both driver and executors. + .config("spark.driver.extraClassPath", jar) + .config("spark.executor.extraClassPath", jar) + .config("spark.driver.extraJavaOptions", "--add-opens=java.base/java.nio=ALL-UNNAMED") + .config("spark.executor.extraJavaOptions", "--add-opens=java.base/java.nio=ALL-UNNAMED") + .getOrCreate() +) + +df = spark.read.format("__FORMAT__").option("rows", "5").load() +df.printSchema() +df.show(truncate=False) +count = df.count() +filtered = df.filter("id >= 2").count() +spark.stop() + +assert count == 5, f"expected 5 rows, got {count}" +assert filtered == 3, f"expected 3 rows with id >= 2, got {filtered}" +print("smoke test OK: 5 rows scanned, filter returned 3") diff --git a/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/BridgeNative.java b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/BridgeNative.java new file mode 100644 index 0000000..7cf02de --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/BridgeNative.java @@ -0,0 +1,40 @@ +package __PKG__; + +import io.datafusion.spark.NativeLibraryLoader; + 
+/** + * JNI surface generated on the Rust side by {@code export_bridge!} with {@code jni_class = + * "__JNI_CLASS__"} — the mangled binary name of THIS class. Renaming or moving this class + * requires regenerating the Rust macro invocation to match. + * + *

The cdylib is bundled in this jar under {@code __PKG_PATH__/<os>/<arch>/} (see the antrun + * execution in pom.xml) and extracted/loaded once per JVM by the connector's loader. + */ +final class BridgeNative { + + private BridgeNative() {} + + static { + NativeLibraryLoader.load(BridgeNative.class, "__PKG_PATH__", "__LIB__"); + } + + static native byte[] providerSchemaIpc(byte[] options, byte[] partition); + + static native long createScan( + byte[] options, + byte[] partition, + int targetPartitions, + int batchSize, + String[] optionKeys, + String[] optionValues, + String[] projectionColumns, + byte[][] filterProtos); + + static native int partitionCount(long scanHandle); + + static native void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr); + + static native void executeStream(long scanHandle, long ffiStreamAddr); + + static native void closeScan(long scanHandle); +} diff --git a/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__DataSource.java b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__DataSource.java new file mode 100644 index 0000000..c888a0d --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__DataSource.java @@ -0,0 +1,21 @@ +package __PKG__; + +import io.datafusion.spark.DatafusionSource; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * Gives the bridge its Spark format name: {@code spark.read.format("__FORMAT__")}. Registered via + * {@code META-INF/services/org.apache.spark.sql.sources.DataSourceRegister}.
+ */ +public class __PREFIX__DataSource extends DatafusionSource { + + @Override + public String shortName() { + return "__FORMAT__"; + } + + @Override + public String factoryFqcn(CaseInsensitiveStringMap options) { + return __PREFIX__ProviderFactory.class.getName(); + } +} diff --git a/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ProviderFactory.java b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ProviderFactory.java new file mode 100644 index 0000000..03498e4 --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ProviderFactory.java @@ -0,0 +1,28 @@ +package __PKG__; + +import io.datafusion.spark.BridgeProviderFactory; +import io.datafusion.spark.ScanBackend; + +/** + * The bridge's contract with the Spark connector: the provider is built inside this bridge's own + * cdylib, and {@link #scanBackend()} is the only required method. + * + *

Useful optional overrides (see their javadoc on {@link BridgeProviderFactory}): + * + *

    + *
  • {@code encodeOptions} — only if you have your own options schema; the default ships the + * Spark options map in the connector's {@code OptionsCodec} format, which the Rust side + * already decodes via {@code datafusion_spark_bridge::options::decode_options}. + *
  • {@code listPartitions} — the default is ONE whole-dataset partition. Override to split + * into more Spark tasks (with optional preferred hosts and partition keys), or… + *
  • {@code sharedScan} — …opt into shared-scan mode: one provider per executor, one Spark + * task per DataFusion output partition. Mind the determinism contract. + *
+ */ +public final class __PREFIX__ProviderFactory implements BridgeProviderFactory { + + @Override + public ScanBackend scanBackend() { + return new __PREFIX__ScanBackend(); + } +} diff --git a/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ScanBackend.java b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ScanBackend.java new file mode 100644 index 0000000..eb78dd1 --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ScanBackend.java @@ -0,0 +1,53 @@ +package __PKG__; + +import io.datafusion.spark.ScanBackend; + +/** Routes the connector's scan calls to this bridge's own cdylib. Pure delegation. */ +public final class __PREFIX__ScanBackend implements ScanBackend { + + @Override + public byte[] providerSchemaIpc(byte[] options, byte[] partitionBytes) { + return BridgeNative.providerSchemaIpc(options, partitionBytes); + } + + @Override + public long createScan( + byte[] options, + byte[] partitionBytes, + int targetPartitions, + int batchSize, + String[] optionKeys, + String[] optionValues, + String[] projectionColumns, + byte[][] filterProtos) { + return BridgeNative.createScan( + options, + partitionBytes, + targetPartitions, + batchSize, + optionKeys, + optionValues, + projectionColumns, + filterProtos); + } + + @Override + public int partitionCount(long scanHandle) { + return BridgeNative.partitionCount(scanHandle); + } + + @Override + public void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr) { + BridgeNative.executeStreamPartition(scanHandle, partition, ffiStreamAddr); + } + + @Override + public void executeStream(long scanHandle, long ffiStreamAddr) { + BridgeNative.executeStream(scanHandle, ffiStreamAddr); + } + + @Override + public void closeScan(long scanHandle) { + BridgeNative.closeScan(scanHandle); + } +} diff --git a/spark/scaffold/bridge-template/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister 
b/spark/scaffold/bridge-template/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..e72a178 --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +__PKG__.__PREFIX__DataSource diff --git a/spark/scaffold/new_bridge.py b/spark/scaffold/new_bridge.py new file mode 100644 index 0000000..03b8de7 --- /dev/null +++ b/spark/scaffold/new_bridge.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Scaffold a new Spark bridge project from spark/scaffold/bridge-template/. + +Stamps out a standalone project (Maven + Cargo) wired to the +datafusion-spark-bridge SDK: a Rust cdylib with `export_bridge!` and a demo +in-memory provider, the four Java classes (native surface, ScanBackend, +factory, DataSource shim), the DataSourceRegister service file, a shaded-jar +pom that bundles the cdylib, a pyspark smoke test, and a README with the +build/run commands. 
+ +Usage: + python3 spark/scaffold/new_bridge.py --name acme --package com.example.acme \ + [--output DIR] [--datafusion-java REPO_ROOT] + +`--name` is the Spark format short name (spark.read.format("acme")); it also +derives the class prefix (acme -> Acme, my_format -> MyFormat), the cargo +crate name, and the cdylib name. Stdlib only; no dependencies. +""" + +import argparse +import re +import sys +from pathlib import Path + +TEMPLATE_DIR = Path(__file__).resolve().parent / "bridge-template" + + +def jni_mangle(binary_class_name: str) -> str: + """JNI symbol mangling for a class's binary name: '_' -> '_1', '.' -> '_'.""" + return binary_class_name.replace("_", "_1").replace(".", "_") + + +def class_prefix(name: str) -> str: + return "".join(part.capitalize() for part in name.split("_")) + + +def validate(name: str, package: str) -> None: + if not re.fullmatch(r"[a-z][a-z0-9_]*", name): + sys.exit(f"--name must match [a-z][a-z0-9_]*, got: {name}") + if not re.fullmatch(r"[a-z][a-z0-9_]*(\.[a-z][a-z0-9_]*)+", package): + sys.exit(f"--package must be a dotted lowercase Java package, got: {package}") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--name", required=True, help="Spark format short name, e.g. acme") + parser.add_argument( + "--package", required=True, help="Java package for the bridge, e.g. 
com.example.acme" + ) + parser.add_argument( + "--output", + help="Directory to create (default: ./-spark-bridge; must not exist)", + ) + parser.add_argument( + "--datafusion-java", + help="datafusion-java repo root providing the spark/bridge SDK crate " + "(default: the repo this script lives in)", + ) + args = parser.parse_args() + + validate(args.name, args.package) + prefix = class_prefix(args.name) + crate = args.name.replace("_", "-") + "-spark-bridge" + lib = args.name + "_spark_bridge" + repo = Path(args.datafusion_java).resolve() if args.datafusion_java else TEMPLATE_DIR.parents[2] + sdk_path = repo / "spark" / "bridge" + if not (sdk_path / "Cargo.toml").is_file(): + sys.exit(f"datafusion-spark-bridge crate not found at {sdk_path}") + out = Path(args.output) if args.output else Path.cwd() / crate + if out.exists(): + sys.exit(f"output directory already exists: {out}") + + tokens = { + "__PKG__": args.package, + "__PKG_PATH__": args.package.replace(".", "/"), + "__JNI_CLASS__": jni_mangle(args.package + ".BridgeNative"), + "__PREFIX__": prefix, + "__FORMAT__": args.name, + "__CRATE__": crate, + "__LIB__": lib, + "__BRIDGE_SDK_PATH__": str(sdk_path), + "__DF_JAVA_VERSION__": read_repo_version(repo), + } + + generated = [] + for src in sorted(TEMPLATE_DIR.rglob("*")): + if not src.is_file(): + continue + rel = str(src.relative_to(TEMPLATE_DIR)) + for token, value in tokens.items(): + rel = rel.replace(token, value) + dst = out / rel + dst.parent.mkdir(parents=True, exist_ok=True) + text = src.read_text() + for token, value in tokens.items(): + text = text.replace(token, value) + dst.write_text(text) + generated.append(rel) + + print(f"Generated {len(generated)} files under {out}:") + for rel in generated: + print(f" {rel}") + print() + print("Next steps (see the generated README.md):") + print(f" 1. cd {out}") + print(" 2. cargo build --release --manifest-path native/Cargo.toml") + print(" 3. mvn package -Dnative.profile=release") + print(f" 4. 
spark.read.format(\"{args.name}\") with the shaded jar on spark.jars") + + +def read_repo_version(repo: Path) -> str: + """datafusion-java's maven version, scraped from the parent pom.""" + pom = (repo / "pom.xml").read_text() + m = re.search(r"([^<]+)", pom) + if not m: + sys.exit(f"could not find in {repo}/pom.xml") + return m.group(1) + + +if __name__ == "__main__": + main() diff --git a/spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java b/spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java new file mode 100644 index 0000000..3bcf7ad --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +import java.util.Map; + +/** + * Bridge interface implemented per domain (HDF5, custom Iceberg, an in-house format, etc.). A + * bridge owns its options encoding and a native scan implementation built with {@code + * datafusion_spark_bridge::export_bridge!}; the connector Spark plumbing is generic — it knows only + * this interface. + * + *

The single required method is {@link #scanBackend()}, returning the delegations to the JNI + * class the bridge named in its {@code export_bridge!} invocation. Everything else has a working + * default: {@link #encodeOptions(Map)} encodes the Spark options via {@link OptionsCodec}, and + * {@link #listPartitions(byte[])} reports a single partition. + * + *

Implementations must be no-arg constructable so the Spark connector can instantiate them + * reflectively via {@link Class#forName(String)} on the executor. + */ +public interface BridgeProviderFactory { + + /** + * The native scan implementation this bridge talks to: delegations to the JNI class named in the + * bridge's {@code export_bridge!} invocation, whose generated {@code createScan} builds the + * provider from the options/partition bytes in process. Called wherever the connector needs + * native work — driver-side schema/plan probes and executor-side streams — always on a factory + * freshly instantiated from its class name, so the returned backend never has to be serializable. + */ + ScanBackend scanBackend(); + + /** + * Convert Spark's flat option map to the bridge's encoded options. Driver-side only; the bytes + * ship verbatim through {@code DatafusionInputPartition} and are the scan's identity in + * shared-scan mode (encode deterministically). + * + *

Default: {@link OptionsCodec#encode(Map)} — the key-sorted length-prefixed pair format that + * {@code datafusion_spark_bridge::options} decodes on the Rust side. Override only if the bridge + * already has its own options schema (e.g. a protobuf). + * + * @throws IllegalArgumentException if required options are missing or invalid + */ + default byte[] encodeOptions(Map sparkOptions) { + return OptionsCodec.encode(sparkOptions); + } + + /** + * Enumerate partitions for this dataset. One Spark task is created per returned {@link + * PartitionInfo}. Driver-side only. + * + *

Each partition's {@code partitionBytes} ships verbatim through {@code + * DatafusionInputPartition} to the executor, where it is passed to {@link + * ScanBackend#createScan}. Use it to encode whatever slice metadata (row range, sub-options, file + * offsets, segment id, …) the bridge needs to materialise *that* partition. + * + *

Each partition's {@code preferredLocations} hostnames are returned from {@code + * InputPartition.preferredLocations()} so Spark co-locates the task with the data; empty array = + * no preference. + * + *

Default: one partition ({@code "p0"}, empty payload, no host preference) — one Spark task + * scans the whole dataset. Fine for small tables and first bring-up; override (or opt into {@link + * #sharedScan(byte[])}) before pointing it at anything large. Size guidance lives in {@code + * spark/README.md}. + */ + default PartitionInfo[] listPartitions(byte[] optionsBytes) { + return new PartitionInfo[] {new PartitionInfo("p0", new byte[0], new String[0])}; + } + + /** + * Filter-aware variant of {@link #listPartitions(byte[])}. The connector calls this overload with + * the pushed-down predicates ({@code LogicalExprNode} proto bytes, one array per predicate, same + * encoding the executor later replays via {@link ScanBackend#createScan}). Bridges that can map + * predicates onto their partition layout (e.g. {@code segment_id = 'x'}) should prune partitions + * that cannot match — pruning here eliminates whole Spark tasks, whereas the per-task filter only + * reduces rows inside a task. + * + *

Pruning must be conservative: only drop a partition when NO row in it can satisfy the + * conjunction of all pushed predicates. The default delegates to the filter-unaware overload (no + * pruning), which is always correct. + */ + default PartitionInfo[] listPartitions(byte[] optionsBytes, byte[][] filterProtoBytes) { + return listPartitions(optionsBytes); + } + + /** + * Opt into shared-scan mode for this dataset. Default {@code false} (per-partition payload mode, + * the {@link #listPartitions(byte[])} path). + * + *

When {@code true}, the connector builds ONE provider per (executor JVM × scan) with empty + * {@code partitionBytes}, plans it once, and runs one Spark task per DataFusion output partition + * — task {@code i} streams plan partition {@code i} from the shared, cached plan. This amortises + * provider construction cost across all tasks on an executor and is the right model when the + * dataset has many small partitions or provider construction is expensive (remote metadata, + * connections). {@link #listPartitions(byte[])} and {@link #reportPartitioning(byte[])} are NOT + * called in this mode, and the scan reports {@code UnknownPartitioning} (DataFusion-native + * partitions carry no key contract). + * + *

Determinism contract. The driver counts partitions by planning once; every executor + * re-plans independently and must arrive at the same result. A bridge returning {@code true} + * guarantees: + * + *

    + *
  • The provider's schema, partitioning, and per-partition row content are a pure function of + * {@code optionsBytes}. Remote sources must pin a snapshot (version, timestamp) inside + * the options; data that compacts or moves between driver planning and executor execution + * otherwise yields wrong results that no runtime check can catch. + *
  • The provider's {@code ExecutionPlan} supports calling {@code execute(i)} more than once + * per plan instance (Spark task retry and speculative execution re-execute a partition + * index, sometimes concurrently). Stateless scans satisfy this; single-shot streams do not. + *
+ * + *

The connector fails tasks with a clear error when the executor's partition count diverges + * from the driver's — but identical counts with different contents cannot be detected. + */ + default boolean sharedScan(byte[] optionsBytes) { + return false; + } + + /** + * Declare how rows are partitioned across the {@link PartitionInfo} entries returned by {@link + * #listPartitions(byte[])}. Driver-side only. + * + *

When non-null, the connector surfaces a {@code KeyGroupedPartitioning(keys, + * listPartitions(...).length)} to Spark via {@code SupportsReportPartitioning} so the optimizer + * can elide shuffles ahead of joins/aggregations on the declared keys. + * + *

Default returns {@code null} — no partitioning guarantees, Spark plans as if the scan's + * output ordering and grouping are unknown. + * + *

If a bridge implements this, it must hold the {@link ReportedPartitioning} contract: every + * row in a given partition evaluates to the same tuple of key values under the declared + * transforms. + * + *

Spark 3.3+ caveat: the reported partitioning only takes effect when every {@link + * PartitionInfo} also carries {@link PartitionInfo#partitionKeyValues()} (surfaced to Spark via + * {@code HasPartitionKey}); without key values Spark ignores the declared {@code + * KeyGroupedPartitioning} entirely. Storage-partitioned joins additionally require {@code + * spark.sql.sources.v2.bucketing.enabled=true}. + */ + default ReportedPartitioning reportPartitioning(byte[] optionsBytes) { + return null; + } +} diff --git a/spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java b/spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java new file mode 100644 index 0000000..eb4766a --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package io.datafusion.spark; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.Locale; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Extracts a cdylib bundled inside a jar to a temp file and loads it via {@link System#load}. + * Expected layout inside the jar: + * + *

+ *   <resourcePrefix>/<os>/<arch>/lib<name>.<ext>
+ * 
+ * + * where {@code <os>} is one of {@code linux}, {@code darwin}, {@code windows} and {@code <arch>} is + * {@code x86_64} or {@code aarch64}. + * + *

Bridges call {@link #load(Class, String, String)} from their native class's static + * initializer, with their own resource prefix, instead of hand-rolling extraction. Bundle the + * cdylib with the antrun-copy pattern shown in "Packaging your bridge" in {@code spark/README.md}. + */ +public final class NativeLibraryLoader { + + /** {@code <resourcePrefix>/<name>} entries already extracted and loaded by this classloader. */ + private static final Set LOADED = ConcurrentHashMap.newKeySet(); + + private NativeLibraryLoader() {} + + /** + * Extract {@code <resourcePrefix>/<os>/<arch>/<mappedName>} from {@code anchor}'s classloader + * and {@link System#load} it. Idempotent per (prefix, name): repeated calls — e.g. one per Spark + * task instantiating the bridge's native class — load once. + * + * @param anchor class whose classloader holds the resource (the bridge's own native class, so the + * lookup works under Spark's per-application classloaders) + * @param resourcePrefix jar-internal directory, no leading or trailing slash (e.g. {@code + * "com/example/mybridge"}) + * @param name unmapped library name (e.g.
{@code "my_bridge"} for {@code libmy_bridge.so}) + * @throws UnsatisfiedLinkError if the resource is missing or extraction fails + */ + public static void load(Class anchor, String resourcePrefix, String name) { + String key = resourcePrefix + "/" + name; + if (!LOADED.add(key)) { + return; + } + String resource = + String.format( + "/%s/%s/%s/%s", + resourcePrefix, currentOs(), currentArch(), System.mapLibraryName(name)); + try (InputStream in = anchor.getResourceAsStream(resource)) { + if (in == null) { + LOADED.remove(key); + throw new UnsatisfiedLinkError("Native library not found on classpath: " + resource); + } + Path tmp = Files.createTempFile("libdatafusion-spark-", "-" + System.mapLibraryName(name)); + tmp.toFile().deleteOnExit(); + Files.copy(in, tmp, StandardCopyOption.REPLACE_EXISTING); + System.load(tmp.toAbsolutePath().toString()); + } catch (IOException e) { + LOADED.remove(key); + throw new UnsatisfiedLinkError( + "Failed to extract native library " + resource + ": " + e.getMessage()); + } catch (RuntimeException | Error e) { + LOADED.remove(key); + throw e; + } + } + + private static String currentOs() { + String os = System.getProperty("os.name", "").toLowerCase(Locale.ROOT); + if (os.contains("linux")) return "linux"; + if (os.contains("mac") || os.contains("darwin")) return "darwin"; + if (os.contains("windows")) return "windows"; + throw new UnsupportedOperationException("Unsupported OS: " + os); + } + + private static String currentArch() { + String arch = System.getProperty("os.arch", "").toLowerCase(Locale.ROOT); + if (arch.equals("amd64") || arch.equals("x86_64")) return "x86_64"; + if (arch.equals("aarch64") || arch.equals("arm64")) return "aarch64"; + throw new UnsupportedOperationException("Unsupported arch: " + arch); + } +} diff --git a/spark/src/main/java/io/datafusion/spark/OptionsCodec.java b/spark/src/main/java/io/datafusion/spark/OptionsCodec.java new file mode 100644 index 0000000..0d16a28 --- /dev/null +++ 
b/spark/src/main/java/io/datafusion/spark/OptionsCodec.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.TreeMap; + +/** + * Default wire format for {@link BridgeProviderFactory#encodeOptions(Map)}: the Spark options map + * as length-prefixed UTF-8 pairs, sorted by key. + * + *

Layout (all integers big-endian {@code int32}): entry count, then per entry key length, key + * bytes, value length, value bytes. Key-sorting makes the bytes a pure function of the map's + * contents regardless of source iteration order — required by the shared-scan determinism contract, + * where the options bytes are the cache/plan identity. + * + *

The Rust decoder lives in {@code datafusion_spark_bridge::options}; bridges using the default + * {@code encodeOptions} read their options there as a {@code BTreeMap}. The two + * implementations are pinned to each other by a shared test fixture. + */ +public final class OptionsCodec { + + private OptionsCodec() {} + + /** Encode {@code options} sorted by key. {@code null} or empty map encodes as count 0. */ + public static byte[] encode(Map options) { + TreeMap sorted = options == null ? new TreeMap<>() : new TreeMap<>(options); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + writeInt(out, sorted.size()); + for (Map.Entry e : sorted.entrySet()) { + if (e.getKey() == null || e.getValue() == null) { + throw new IllegalArgumentException("OptionsCodec does not accept null keys or values"); + } + writeBytes(out, e.getKey().getBytes(StandardCharsets.UTF_8)); + writeBytes(out, e.getValue().getBytes(StandardCharsets.UTF_8)); + } + return out.toByteArray(); + } + + /** Decode bytes produced by {@link #encode(Map)}. Preserves the encoded (sorted) order. 
*/ + public static Map decode(byte[] bytes) { + Map out = new LinkedHashMap<>(); + if (bytes == null || bytes.length == 0) { + return out; + } + ByteBuffer buf = ByteBuffer.wrap(bytes); + int count = readCount(buf, "entry count"); + for (int i = 0; i < count; i++) { + String key = readString(buf, "key of entry " + i); + String value = readString(buf, "value of entry " + i); + out.put(key, value); + } + if (buf.hasRemaining()) { + throw new IllegalArgumentException( + "OptionsCodec: " + buf.remaining() + " trailing byte(s) after " + count + " entries"); + } + return out; + } + + private static void writeInt(ByteArrayOutputStream out, int v) { + out.write((v >>> 24) & 0xFF); + out.write((v >>> 16) & 0xFF); + out.write((v >>> 8) & 0xFF); + out.write(v & 0xFF); + } + + private static void writeBytes(ByteArrayOutputStream out, byte[] bytes) { + writeInt(out, bytes.length); + out.write(bytes, 0, bytes.length); + } + + private static int readCount(ByteBuffer buf, String what) { + if (buf.remaining() < 4) { + throw new IllegalArgumentException("OptionsCodec: truncated " + what); + } + int v = buf.getInt(); + if (v < 0) { + throw new IllegalArgumentException("OptionsCodec: negative " + what + ": " + v); + } + return v; + } + + private static String readString(ByteBuffer buf, String what) { + int len = readCount(buf, "length of " + what); + if (buf.remaining() < len) { + throw new IllegalArgumentException("OptionsCodec: truncated " + what); + } + byte[] bytes = new byte[len]; + buf.get(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } +} diff --git a/spark/src/main/java/io/datafusion/spark/PartitionInfo.java b/spark/src/main/java/io/datafusion/spark/PartitionInfo.java new file mode 100644 index 0000000..e6e061b --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/PartitionInfo.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +/** + * Driver-side descriptor for a single partition produced by {@link + * BridgeProviderFactory#listPartitions(byte[])}. Carries the bridge-specific slice payload that the + * executor passes back into {@link ScanBackend#createScan}, plus + * optional host hints for Spark's scheduler. + * + *

Fields: + * + *

    + *
  • {@code id} — stable, human-readable identifier for this partition (e.g. a segment id). + * Surfaces in Spark UI, logs, and exception messages. Must be non-empty. + *
  • {@code partitionBytes} — opaque per-partition payload. Bridge encodes whatever the executor + * needs to materialise *this* slice (offsets, row ranges, sub-options, etc.). Combined with + * the global {@code optionsBytes} in {@link ScanBackend#createScan}. Empty array = no + * per-partition state (single-partition table). + *
  • {@code preferredLocations} — hostnames where this partition's data lives. Returned from + * {@code InputPartition.preferredLocations()} so Spark can co-locate the task with the data. + * Empty array = no preference. Honoured subject to {@code spark.locality.wait}. + *
  • {@code partitionKeyValues} — optional values of the partitioning keys for every row in this + * partition, in the same order as {@link BridgeProviderFactory#reportPartitioning(byte[])}'s + * declared transforms. {@code null} = no key (the default). When the bridge reports a + * partitioning AND every partition carries key values, the connector exposes them to Spark + * via {@code HasPartitionKey} — required on Spark 3.3+ for the reported {@code + * KeyGroupedPartitioning} to have any effect (and storage-partitioned joins additionally + * require {@code spark.sql.sources.v2.bucketing.enabled=true}). Values must be Java types + * that Spark's {@code CatalystTypeConverters} can convert for the key columns' data types + * (e.g. {@code String}, {@code Long}, {@code Integer}, {@code java.time.Instant}, {@code + * java.time.LocalDate}, {@code java.math.BigDecimal}), and the array length must equal the + * number of declared keys. + *
+ */ +public record PartitionInfo( + String id, byte[] partitionBytes, String[] preferredLocations, Object[] partitionKeyValues) { + + public PartitionInfo { + if (id == null || id.isEmpty()) { + throw new IllegalArgumentException("PartitionInfo: id must be non-empty"); + } + if (partitionBytes == null) { + partitionBytes = new byte[0]; + } + if (preferredLocations == null) { + preferredLocations = new String[0]; + } + // partitionKeyValues stays null when absent: null and "no key" are the same state, + // and DatafusionBatch distinguishes keyed from unkeyed partitions by it. + } + + /** Without partition key values — the common case. */ + public PartitionInfo(String id, byte[] partitionBytes, String[] preferredLocations) { + this(id, partitionBytes, preferredLocations, null); + } +} diff --git a/spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java b/spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java new file mode 100644 index 0000000..639fec9 --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package io.datafusion.spark; + +import java.util.Arrays; + +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.Transform; + +/** + * Driver-side declaration of how a bridge's data is partitioned on the key columns. When supplied + * via {@link BridgeProviderFactory#reportPartitioning(byte[])}, the connector surfaces a {@link + * org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} from {@link + * org.apache.spark.sql.connector.read.SupportsReportPartitioning#outputPartitioning()} — Spark's + * optimizer can then skip the shuffle ahead of joins/aggregations whose grouping keys line up with + * these transforms. + * + *

Contract: for any partition reported by {@link BridgeProviderFactory#listPartitions(byte[])}, + * every row produced by that partition must evaluate to the same tuple of key values under these + * transforms. Different partitions may share key values (Spark will fuse them); they must + * not straddle key values. + * + *

The partition count Spark sees is {@code listPartitions(...).length}; it is not carried here + * to keep a single source of truth. + */ +public final class ReportedPartitioning { + + private final Transform[] keys; + + public ReportedPartitioning(Transform[] keys) { + if (keys == null || keys.length == 0) { + throw new IllegalArgumentException( + "ReportedPartitioning: keys must contain at least one transform"); + } + this.keys = keys; + } + + public Transform[] keys() { + return keys; + } + + /** + * Convenience: declare identity partitioning on one or more columns (a row in partition P has the + * same {@code (col1, col2, …)} values as every other row in P). + */ + public static ReportedPartitioning identity(String... columns) { + if (columns == null || columns.length == 0) { + throw new IllegalArgumentException( + "ReportedPartitioning.identity: at least one column required"); + } + Transform[] ts = Arrays.stream(columns).map(Expressions::identity).toArray(Transform[]::new); + return new ReportedPartitioning(ts); + } + + /** + * Convenience: declare hash-bucket partitioning. Mirrors Spark's {@code bucket(N, cols…)} + * transform — each row is assigned to bucket {@code hash(cols) mod numBuckets}. + */ + public static ReportedPartitioning bucket(int numBuckets, String... 
columns) { + if (numBuckets <= 0) { + throw new IllegalArgumentException( + "ReportedPartitioning.bucket: numBuckets must be > 0, got " + numBuckets); + } + if (columns == null || columns.length == 0) { + throw new IllegalArgumentException( + "ReportedPartitioning.bucket: at least one column required"); + } + return new ReportedPartitioning(new Transform[] {Expressions.bucket(numBuckets, columns)}); + } +} diff --git a/spark/src/main/java/io/datafusion/spark/ScanBackend.java b/spark/src/main/java/io/datafusion/spark/ScanBackend.java new file mode 100644 index 0000000..a994c98 --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/ScanBackend.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +/** + * Native scan surface the connector plumbing talks to: one method per JNI entry point generated by + * the bridge's {@code datafusion_spark_bridge::export_bridge!} invocation. A bridge's + * implementation is six one-line delegations to the JNI class named in that macro, whose {@code + * createScan} builds the provider from {@code options}/{@code partitionBytes} in process. + * + *

Implementations must be stateless or thread-safe: the driver probes schemas and plans through
+ * one instance while executor tasks stream through others, and scan handles are shared across
+ * threads by the shared-scan cache. Handle-based methods accept handles produced by {@code
+ * createScan} on any instance of the same implementation.
+ */
+public interface ScanBackend {
+
+  /**
+   * Driver-side schema probe: the widened Arrow schema of the provider described by {@code options}
+   * + {@code partitionBytes}, serialized as Arrow IPC bytes (deserialize with {@code
+   * MessageSerializer.deserializeSchema}).
+   *
+   * @param options bridge-encoded global options
+   * @param partitionBytes bridge-encoded per-partition payload
+   * @return the schema as Arrow IPC bytes
+   */
+  byte[] providerSchemaIpc(byte[] options, byte[] partitionBytes);
+
+  /**
+   * Build a planned scan and return its handle. {@code targetPartitions}/{@code batchSize} {@code
+   * <= 0} leave DataFusion defaults; {@code optionKeys}/{@code optionValues} are parallel config
+   * override arrays; empty {@code projectionColumns} selects all columns; each {@code filterProtos}
+   * element is a serialized {@code datafusion.LogicalExprNode}.
+   *
+   * <p>The caller owns the handle and must pair it with {@link #closeScan(long)}. Closing while a
+   * stream opened from the handle is in flight is undefined behaviour — the shared-scan cache's
+   * refcount enforces this; any other caller must serialize close itself.
+   *
+   * @param options bridge-encoded global options
+   * @param partitionBytes bridge-encoded per-partition payload
+   * @param targetPartitions requested plan partition count, or {@code <= 0} for the default
+   * @param batchSize requested batch size, or {@code <= 0} for the default
+   * @param optionKeys config override keys, parallel to {@code optionValues}
+   * @param optionValues config override values, parallel to {@code optionKeys}
+   * @param projectionColumns columns to select; empty selects all
+   * @param filterProtos serialized {@code datafusion.LogicalExprNode} filters
+   * @return the scan handle, to be released with {@link #closeScan(long)}
+   */
+  long createScan(
+      byte[] options,
+      byte[] partitionBytes,
+      int targetPartitions,
+      int batchSize,
+      String[] optionKeys,
+      String[] optionValues,
+      String[] projectionColumns,
+      byte[][] filterProtos);
+
+  /** Output partition count of the planned physical plan. */
+  int partitionCount(long scanHandle);
+
+  /**
+   * Open an independent stream over ONE plan partition, writing an {@code FFI_ArrowArrayStream}
+   * into the caller-allocated struct at {@code ffiStreamAddr}. Concurrent-safe across JVM threads.
+   *
+   * @param scanHandle handle returned by {@link #createScan}
+   * @param partition index of the plan partition to stream
+   * @param ffiStreamAddr address of the caller-allocated {@code FFI_ArrowArrayStream} struct
+   */
+  void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr);
+
+  /**
+   * Stream the WHOLE plan (all partitions coalesced) into the caller-allocated {@code
+   * FFI_ArrowArrayStream} at {@code ffiStreamAddr}. Used by per-partition mode.
+   *
+   * @param scanHandle handle returned by {@link #createScan}
+   * @param ffiStreamAddr address of the caller-allocated {@code FFI_ArrowArrayStream} struct
+   */
+  void executeStream(long scanHandle, long ffiStreamAddr);
+
+  /** Drop the planned scan. See {@link #createScan} for the close-vs-in-flight contract. */
+  void closeScan(long scanHandle);
+}
diff --git a/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..3e612e0 --- /dev/null +++ b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +io.datafusion.spark.DatafusionSource diff --git a/spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala b/spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala new file mode 100644 index 0000000..30f62f8 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.arrow.vector.FieldVector
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
+
+/**
+ * Common `next()`/`get()` implementation for the connector's columnar readers. Every successful
+ * `loadNextBatch()` exposes the reader's `VectorSchemaRoot` as a `ColumnarBatch` whose columns
+ * are [[NonClosingArrowColumnVector]]s (the reader owns the vectors; Spark must not close them
+ * per batch).
+ */
+private[spark] trait ArrowColumnarBatchIteration {
+
+  /** The Arrow stream this reader drains. Stable for the reader's lifetime. */
+  protected def arrowReader: ArrowReader
+
+  /** Batch produced by the latest successful `next()`; `null` before the first and after the last. */
+  private var currentBatch: ColumnarBatch = _
+
+  /** Advances to the next batch; returns `false` (and clears the current batch) at end of stream. */
+  def next(): Boolean = {
+    // Forget the previous batch before asking the reader for more data.
+    currentBatch = null
+    if (arrowReader.loadNextBatch()) {
+      val root = arrowReader.getVectorSchemaRoot
+      val fieldVectors: java.util.List[FieldVector] = root.getFieldVectors
+      val columns = Array.tabulate[ColumnVector](fieldVectors.size()) { idx =>
+        new NonClosingArrowColumnVector(fieldVectors.get(idx))
+      }
+      val wrapped = new ColumnarBatch(columns)
+      wrapped.setNumRows(root.getRowCount)
+      currentBatch = wrapped
+      true
+    } else {
+      false
+    }
+  }
+
+  /** The batch loaded by the last `next()` call that returned `true`, otherwise `null`. */
+  def get(): ColumnarBatch = currentBatch
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala b/spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala new file mode 100644 index 0000000..2e8f1a5 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} +import org.apache.spark.sql.types._ + +/** + * Arrow Schema → Spark StructType converter. + * + * The reported Spark schema MUST be one whose runtime ArrowColumnVector accessor Spark can pick + * for the underlying Arrow vector. Spark 3.5's `ArrowColumnVector` supports the following + * accessors: Boolean, Byte, Short, Int, Long, Float, Double, Decimal, Date, Timestamp, + * TimestampNTZ, Duration (DayTimeInterval), String, LargeString, Binary, Array, Map, Struct, + * Null. No unsigned-int or Time accessor exists; we surface a clear error at schema discovery + * for those — the alternative is silent corruption. + * + * The widening layer (datafusion-spark-bridge, compiled into every bridge cdylib) inserts a + * `WideningTableProvider` upstream of the Spark reader that casts unsupported types kernel-side + * (UInt*→signed wider, Float16→Float32, non-µs Timestamp→µs Timestamp, Time→Int) so Spark only + * ever sees compatible Arrow types. 
+ */
+object ArrowToSparkSchema {
+
+  /** Converts every top-level Arrow field; throws `UnsupportedOperationException` on the first unsupported one. */
+  def toSparkSchema(schema: Schema): StructType =
+    StructType(schema.getFields.asScala.toSeq.map(toSparkField))
+
+  /** Converts one field (name, type, nullability), rejecting dictionary-encoded fields. */
+  private def toSparkField(f: Field): StructField = {
+    val dt =
+      Option(f.getDictionary) match {
+        case Some(_) =>
+          unsupported(f, "dictionary-encoded fields (need dictionary value schema in JNI)")
+        case None => toSparkType(f)
+      }
+    StructField(f.getName, dt, f.isNullable)
+  }
+
+  /**
+   * Spark type of a nested child (list element, map key/value). Routed through `toSparkField` so
+   * the dictionary-encoding rejection applies to nested fields exactly as it does to top-level
+   * and struct fields.
+   */
+  private def nestedType(child: Field): DataType = toSparkField(child).dataType
+
+  /** Builds the `ArrayType` for any of the list flavours from the field's single child. */
+  private def arrayOf(f: Field): ArrayType = {
+    val child = f.getChildren.get(0)
+    ArrayType(nestedType(child), containsNull = child.isNullable)
+  }
+
+  private def toSparkType(f: Field): DataType = f.getType match {
+    case _: ArrowType.Bool => BooleanType
+
+    case t: ArrowType.Int =>
+      (t.getBitWidth, t.getIsSigned) match {
+        case (8, true) => ByteType
+        case (16, true) => ShortType
+        case (32, true) => IntegerType
+        case (64, true) => LongType
+        case (bits, false) =>
+          unsupported(
+            f,
+            s"unsigned integer UInt$bits (Spark ArrowColumnVector has no unsigned accessor; " +
+              "widening layer casts these before Spark sees them — this branch indicates the " +
+              "WideningTableProvider was bypassed)"
+          )
+        case (bits, signed) => unsupported(f, s"Int(bits=$bits, signed=$signed)")
+      }
+
+    case t: ArrowType.FloatingPoint =>
+      t.getPrecision match {
+        case FloatingPointPrecision.HALF =>
+          unsupported(f, "Float16 (widening layer must cast to Float32 before Spark)")
+        case FloatingPointPrecision.SINGLE => FloatType
+        case FloatingPointPrecision.DOUBLE => DoubleType
+        case other => unsupported(f, s"FloatingPoint($other)")
+      }
+
+    case _: ArrowType.Utf8 => StringType
+    case _: ArrowType.LargeUtf8 => StringType
+    case _: ArrowType.Binary => BinaryType
+    case _: ArrowType.LargeBinary => BinaryType
+    case _: ArrowType.FixedSizeBinary => BinaryType
+
+    case d: ArrowType.Date =>
+      d.getUnit match {
+        case DateUnit.DAY | DateUnit.MILLISECOND => DateType
+        case other => unsupported(f, s"Date($other)")
+      }
+
+    case t: ArrowType.Timestamp =>
+      // Only the presence of a timezone decides the Spark type; the time unit is not inspected.
+      // NOTE(review): this relies on an upstream cast to microseconds — confirm that nested
+      // timestamp fields are covered by it as well.
+      if (t.getTimezone == null) TimestampNTZType else TimestampType
+
+    case ti: ArrowType.Time =>
+      unsupported(
+        f,
+        s"Time(${ti.getUnit}, ${ti.getBitWidth}-bit) — Spark has no time-of-day type"
+      )
+
+    case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
+
+    case _: ArrowType.Null => NullType
+
+    case _: ArrowType.Duration => DayTimeIntervalType()
+
+    case iv: ArrowType.Interval =>
+      iv.getUnit match {
+        case IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
+        case IntervalUnit.DAY_TIME => DayTimeIntervalType()
+        case IntervalUnit.MONTH_DAY_NANO =>
+          unsupported(f, "Interval(MONTH_DAY_NANO) — no clean Spark equivalent")
+      }
+
+    case _: ArrowType.Struct =>
+      StructType(f.getChildren.asScala.toSeq.map(toSparkField))
+
+    case _: ArrowType.List | _: ArrowType.LargeList | _: ArrowType.FixedSizeList =>
+      arrayOf(f)
+
+    case _: ArrowType.Map =>
+      // A map field has one child (the entries struct) whose children are key then value.
+      val entries = f.getChildren.get(0)
+      val keyValue = entries.getChildren
+      val keyField = keyValue.get(0)
+      val valueField = keyValue.get(1)
+      MapType(nestedType(keyField), nestedType(valueField), valueField.isNullable)
+
+    case _: ArrowType.Union =>
+      unsupported(f, "Union (Spark has no equivalent)")
+
+    case other => unsupported(f, s"$other")
+  }
+
+  /** Always throws; the message names the offending column. */
+  private def unsupported(f: Field, detail: String): Nothing =
+    throw new UnsupportedOperationException(
+      s"Column '${f.getName}': $detail"
+    )
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala new file mode 100644 index 0000000..684e9fd --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory} + +/** + * Spark `Batch` for a DataFusion-backed scan. Driver-side partition planning: + * - [[PerPartitionMode]]: one task per `PartitionInfo` (resolved by [[DatafusionScanBuilder]]); when + * the bridge reported a partitioning and every entry carries key values, tasks implement + * `HasPartitionKey` so Spark can actually use the `KeyGroupedPartitioning`. + * - [[SharedScanMode]]: one task per DataFusion plan partition index. 
+ */
+class DatafusionBatch(val scan: DatafusionScan) extends Batch {
+
+  /** Builds one task payload per `PartitionInfo` (per-partition mode) or per plan partition index (shared-scan mode). */
+  override def planInputPartitions(): Array[InputPartition] = {
+    val projection = scan.prunedSchema.fieldNames
+    val filterBytes: Array[Array[Byte]] = scan.pushedPredicateBytes
+
+    scan.mode match {
+      case PerPartitionMode(partitions, reported) =>
+        // Validated once up front: either every task is keyed or none is.
+        val keyed = DatafusionBatch.validateKeyedState(scan.factoryFqcn, partitions, reported)
+        val tasks: Array[InputPartition] = partitions.map { info =>
+          val plain = DatafusionInputPartition(
+            factoryFqcn = scan.factoryFqcn,
+            optionsBytes = scan.optionsBytes,
+            projectionColumnNames = projection,
+            filterProtoBytes = filterBytes,
+            partitionId = info.id,
+            partitionBytes = info.partitionBytes,
+            preferredLocs = info.preferredLocations
+          )
+          val task: InputPartition =
+            if (keyed) {
+              DatafusionKeyedInputPartition(
+                plain,
+                DatafusionBatch.toKeyRow(info.id, info.partitionKeyValues, reported))
+            } else {
+              plain
+            }
+          task
+        }
+        tasks
+
+      case SharedScanMode(scanId, numPartitions, pinnedConfig, idleTtlMs) =>
+        Array.tabulate[InputPartition](numPartitions) { index =>
+          DatafusionSharedScanPartition(
+            factoryFqcn = scan.factoryFqcn,
+            optionsBytes = scan.optionsBytes,
+            projectionColumnNames = projection,
+            filterProtoBytes = filterBytes,
+            scanId = scanId,
+            partitionIndex = index,
+            numPartitions = numPartitions,
+            pinnedConfig = pinnedConfig,
+            idleTtlMs = idleTtlMs
+          )
+        }
+    }
+  }
+
+  override def createReaderFactory(): PartitionReaderFactory =
+    new DatafusionPartitionReaderFactory(scan.prunedSchema)
+}
+
+private[spark] object DatafusionBatch {
+
+  /**
+   * Decides whether the partitions are keyed. Returns `false` when no partitioning was reported
+   * or no partition carries key values, `true` when all of them do, and throws
+   * `IllegalStateException` for a mix — the bridge broke its own contract, and failing on the
+   * driver beats Spark silently planning without the declared grouping.
+   */
+  def validateKeyedState(
+      factoryFqcn: String,
+      partitions: Array[PartitionInfo],
+      reported: ReportedPartitioning): Boolean = {
+    if (reported == null) {
+      false
+    } else {
+      val keyedCount = partitions.count(_.partitionKeyValues != null)
+      if (keyedCount == 0) {
+        false
+      } else if (keyedCount == partitions.length) {
+        true
+      } else {
+        throw new IllegalStateException(
+          s"BridgeProviderFactory '$factoryFqcn' reported a partitioning but only $keyedCount of " +
+            s"${partitions.length} PartitionInfo entries carry partitionKeyValues; either all " +
+            "partitions must carry key values or none")
+      }
+    }
+  }
+
+  /**
+   * Turns the bridge-supplied `Object[]` of key values into a Spark internal row via
+   * `CatalystTypeConverters` (String → UTF8String, Instant → micros, LocalDate → days,
+   * BigDecimal → Decimal, ...). Throws `IllegalStateException` when the number of values
+   * differs from the number of reported keys.
+   */
+  def toKeyRow(
+      partitionId: String,
+      values: Array[AnyRef],
+      reported: ReportedPartitioning): InternalRow = {
+    val declared = reported.keys().length
+    if (values.length != declared) {
+      throw new IllegalStateException(
+        s"PartitionInfo '$partitionId' carries ${values.length} partitionKeyValues but the " +
+          s"reported partitioning declares $declared key(s)")
+    }
+    val catalystValues: Array[Any] = values.map(raw => CatalystTypeConverters.convertToCatalyst(raw))
+    new GenericInternalRow(catalystValues)
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala new file mode 100644 index 0000000..96b7548 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Per-task columnar reader for the per-partition path. Lifecycle: + * + * 1. Reflectively instantiate the bridge's `BridgeProviderFactory` (no-arg) and take its + * [[ScanBackend]]. + * 2. `backend.createScan(options, partitionBytes, ...)` — builds the provider for the slice + * described by `partitionBytes` and does the rest natively: widening wrap, private + * `SessionContext`, projection, pushed proto filters, physical plan. + * 3. `backend.executeStream` streams the whole plan (the provider already IS the task's + * slice); batches surface through [[ArrowColumnarBatchIteration]]. 
+ */
+class DatafusionColumnarPartitionReader(
+    partition: DatafusionInputPartition,
+    readSchema: StructType
+) extends PartitionReader[ColumnarBatch]
+    with ArrowColumnarBatchIteration {
+
+  // Resolved FIRST, before any closeable resource exists: reflective instantiation can throw
+  // (missing class, no no-arg constructor, wrong type), and this initializer has no cleanup
+  // handler. With the allocator created earlier, such a failure would leave it unclosed.
+  private val backend: ScanBackend = instantiateFactory(partition.factoryFqcn).scanBackend()
+
+  private val allocator = new RootAllocator(Long.MaxValue)
+
+  // Plans the scan natively; on failure the allocator (the only resource so far) is released
+  // and any error from that release is attached as suppressed.
+  private val scanHandle: Long =
+    try {
+      backend.createScan(
+        partition.optionsBytes,
+        partition.partitionBytes,
+        /* targetPartitions = */ -1,
+        /* batchSize = */ -1,
+        Array.empty[String],
+        Array.empty[String],
+        partition.projectionColumnNames,
+        partition.filterProtoBytes
+      )
+    } catch {
+      case t: Throwable =>
+        try allocator.close()
+        catch { case suppressed: Throwable => t.addSuppressed(suppressed) }
+        throw t
+    }
+
+  // Opens the whole-plan stream; on failure unwinds what was acquired so far (scan handle, then
+  // allocator), in that order.
+  override protected val arrowReader: ArrowReader =
+    try {
+      FfiStream.importReader(allocator) { addr =>
+        backend.executeStream(scanHandle, addr)
+      }
+    } catch {
+      case t: Throwable =>
+        try backend.closeScan(scanHandle)
+        catch { case suppressed: Throwable => t.addSuppressed(suppressed) }
+        try allocator.close()
+        catch { case suppressed: Throwable => t.addSuppressed(suppressed) }
+        throw t
+    }
+
+  /**
+   * Releases the stream, the native scan and the allocator, in that order. Every step runs even
+   * if an earlier one throws; the first failure is rethrown with later ones attached as
+   * suppressed.
+   */
+  override def close(): Unit = {
+    var first: Throwable = null
+    def safe(f: => Unit): Unit =
+      try f
+      catch { case t: Throwable => if (first == null) first = t else first.addSuppressed(t) }
+    safe(arrowReader.close())
+    safe(backend.closeScan(scanHandle))
+    safe(allocator.close())
+    if (first != null) throw first
+  }
+
+  /** Loads `fqcn` and instantiates it through its no-arg constructor. */
+  private def instantiateFactory(fqcn: String): BridgeProviderFactory = {
+    val cls = Class.forName(fqcn)
+    cls.getDeclaredConstructor().newInstance().asInstanceOf[BridgeProviderFactory]
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala new file mode 100644 index 0000000..5255644 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala @@
/**
 * Marker for the connector's task payloads, shipped driver → executor via Java serialization.
 * [[DatafusionPartitionReaderFactory]] dispatches on the concrete type.
 */
sealed trait DatafusionPartition extends InputPartition

/**
 * Per-task payload for the per-partition read path.
 *
 * @param factoryFqcn           fully-qualified class name of the bridge's
 *                              `BridgeProviderFactory`; the executor instantiates it
 *                              reflectively and calls `scanBackend().createScan(...)`
 * @param optionsBytes          bridge-encoded global connection options, opaque to the
 *                              connector; the same bytes ride along on every partition
 * @param projectionColumnNames pruned column list (post-`pruneColumns`)
 * @param filterProtoBytes      V2 `Predicate` → DataFusion `LogicalExprNode` proto bytes, one
 *                              entry per pushed predicate
 * @param partitionId           stable identifier (e.g. a segment or file id) for UI, logs and
 *                              error messages
 * @param partitionBytes        opaque per-partition payload from `PartitionInfo.partitionBytes`,
 *                              handed back to `ScanBackend.createScan` so the bridge
 *                              materialises this slice
 * @param preferredLocs         hostnames where this partition's data lives; may be null when
 *                              the bridge reports none
 */
final case class DatafusionInputPartition(
    factoryFqcn: String,
    optionsBytes: Array[Byte],
    projectionColumnNames: Array[String],
    filterProtoBytes: Array[Array[Byte]],
    partitionId: String,
    partitionBytes: Array[Byte],
    preferredLocs: Array[String]
) extends DatafusionPartition {

  /**
   * Hosts Spark should prefer when scheduling this task. Never null: a missing host list is
   * reported as "no preference" (empty array) rather than passed through as null.
   */
  override def preferredLocations(): Array[String] =
    if (preferredLocs == null) Array.empty[String] else preferredLocs
}

/**
 * Per-partition payload that additionally carries this partition's key values, precomputed
 * driver-side into an [[InternalRow]]. Emitted by [[DatafusionBatch]] when the bridge reported a
 * partitioning AND every `PartitionInfo` carries `partitionKeyValues` — implementing
 * [[HasPartitionKey]] is what makes the reported `KeyGroupedPartitioning` visible to Spark 3.3+
 * (`DataSourceV2ScanExecBase.groupPartitions` ignores it otherwise).
 *
 * @param base   the plain per-partition payload this one wraps
 * @param keyRow partition key values returned from `partitionKey()`
 */
final case class DatafusionKeyedInputPartition(
    base: DatafusionInputPartition,
    keyRow: InternalRow
) extends DatafusionPartition
    with HasPartitionKey {

  /** Delegates to the wrapped payload. */
  override def preferredLocations(): Array[String] = base.preferredLocations()

  override def partitionKey(): InternalRow = keyRow
}
/**
 * Task payload for shared-scan mode. The task with index `partitionIndex` streams that
 * DataFusion plan partition out of the executor's cached entry (see [[SharedScanCache]]).
 *
 * No preferred locations are reported: in this mode there is no per-slice host mapping.
 *
 * @param factoryFqcn           class name of the bridge's `BridgeProviderFactory`
 * @param optionsBytes          bridge-encoded connection options
 * @param projectionColumnNames pruned column list
 * @param filterProtoBytes      pushed filters as proto bytes
 * @param scanId                driver-minted UUID identifying this scan; the executor cache key
 * @param partitionIndex        DataFusion output partition this task drives
 * @param numPartitions         partition count observed by the driver probe; executors fail
 *                              fast when their own re-plan disagrees
 * @param pinnedConfig          DataFusion session knobs resolved once on the driver
 * @param idleTtlMs             cache-entry idle eviction window, resolved from driver conf
 */
final case class DatafusionSharedScanPartition(
    factoryFqcn: String,
    optionsBytes: Array[Byte],
    projectionColumnNames: Array[String],
    filterProtoBytes: Array[Array[Byte]],
    scanId: String,
    partitionIndex: Int,
    numPartitions: Int,
    pinnedConfig: PinnedSessionConfig,
    idleTtlMs: Long
) extends DatafusionPartition {

  /**
   * The build inputs of this payload as a [[SharedScanSpec]]. The task-specific fields
   * (`partitionIndex`, `numPartitions`, `idleTtlMs`) are not part of the spec.
   */
  def toSpec: SharedScanSpec = {
    val spec = SharedScanSpec(
      factoryFqcn = factoryFqcn,
      optionsBytes = optionsBytes,
      scanId = scanId,
      pinnedConfig = pinnedConfig,
      projectionColumnNames = projectionColumnNames,
      filterProtoBytes = filterProtoBytes
    )
    spec
  }
}
/**
 * Executor-side factory turning a shipped [[DatafusionPartition]] into a `PartitionReader`.
 *
 * Only columnar reads are offered: `supportColumnarReads` is always true and the row-based
 * entry point throws. A row path would mean converting Arrow → `InternalRow` per row, which
 * defeats the zero-copy design.
 *
 * @param readSchema pruned Spark schema handed to per-partition readers
 */
class DatafusionPartitionReaderFactory(val readSchema: StructType) extends PartitionReaderFactory {

  override def supportColumnarReads(partition: InputPartition): Boolean = true

  /** Always throws `UnsupportedOperationException`; this connector has no row-based path. */
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
    val message =
      "DatafusionPartitionReaderFactory: row-based read not supported; consumers must opt into columnar"
    throw new UnsupportedOperationException(message)
  }

  /**
   * Picks the reader for the payload's concrete type: plain and keyed per-partition payloads
   * share [[DatafusionColumnarPartitionReader]] (the keyed one via its `base`), shared-scan
   * payloads go through the global [[SharedScanCache]].
   *
   * @throws IllegalArgumentException for any other `InputPartition` implementation
   */
  override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
    def perPartitionReader(payload: DatafusionInputPartition): PartitionReader[ColumnarBatch] =
      new DatafusionColumnarPartitionReader(payload, readSchema)

    partition match {
      case plain: DatafusionInputPartition =>
        perPartitionReader(plain)
      case keyed: DatafusionKeyedInputPartition =>
        perPartitionReader(keyed.base)
      case shared: DatafusionSharedScanPartition =>
        new SharedScanPartitionReader(shared, SharedScanCache.global)
      case unknown =>
        val typeName = unknown.getClass.getName
        throw new IllegalArgumentException(s"unexpected InputPartition type: $typeName")
    }
  }
}
/**
 * The way a scan is split into Spark tasks. Decided once on the driver, in
 * [[DatafusionScanBuilder.build]].
 */
sealed trait DatafusionScanMode extends Serializable

/**
 * One task per [[PartitionInfo]]; every task builds its own provider from that entry's
 * `partitionBytes`.
 *
 * @param partitions the partitions listed by the bridge
 * @param reported   the bridge's optional partitioning declaration; null when none was given
 */
final case class PerPartitionMode(
    partitions: Array[PartitionInfo],
    reported: ReportedPartitioning
) extends DatafusionScanMode

/**
 * One cached provider + plan per (executor × scan), with `numPartitions` tasks each driving a
 * single DataFusion output partition. The determinism contract is described on
 * [[BridgeProviderFactory#sharedScan]].
 *
 * @param scanId        identifier of this scan, minted on the driver
 * @param numPartitions output partition count of the driver-side probe plan
 * @param pinnedConfig  session knobs shipped to every executor
 * @param idleTtlMs     idle eviction window for the executor cache entry
 */
final case class SharedScanMode(
    scanId: String,
    numPartitions: Int,
    pinnedConfig: PinnedSessionConfig,
    idleTtlMs: Long
) extends DatafusionScanMode
/**
 * Read plan for a DataFusion-backed scan. It carries the pruned schema, the pushed predicates
 * (shown by `description()` / `explain(true)`), the matching `LogicalExprNode` proto bytes that
 * executors apply natively through `ScanBackend.createScan`, and the driver-resolved
 * [[DatafusionScanMode]].
 *
 * Partitioning report:
 *   - per-partition mode with a bridge-declared [[ReportedPartitioning]] →
 *     `KeyGroupedPartitioning`; Spark 3.3+ only consumes it when the input partitions also
 *     implement `HasPartitionKey` (see [[DatafusionBatch]]);
 *   - per-partition mode without a declaration, and shared-scan mode → `UnknownPartitioning`.
 */
class DatafusionScan(
    val factoryFqcn: String,
    val optionsBytes: Array[Byte],
    val fullSchema: StructType,
    val prunedSchema: StructType,
    val pushedPredicates: Array[Predicate],
    val pushedPredicateBytes: Array[Array[Byte]],
    val mode: DatafusionScanMode
) extends Scan
    with SupportsReportPartitioning {

  override def readSchema(): StructType = prunedSchema

  /** One-line summary: factory, projected columns, pushed-predicate count and mode details. */
  override def description(): String = {
    val modeDesc = mode match {
      case perPartition: PerPartitionMode =>
        val partitioningLabel = if (perPartition.reported == null) "unknown" else "key-grouped"
        s"mode=per-partition, partitions=${perPartition.partitions.length}," +
          s" reportedPartitioning=$partitioningLabel"
      case shared: SharedScanMode =>
        s"mode=shared-scan, scanId=${shared.scanId}, partitions=${shared.numPartitions}"
    }
    val projection = prunedSchema.fieldNames.mkString(",")
    s"DatafusionScan(factory=$factoryFqcn, projection=$projection," +
      s" pushedPredicates=${pushedPredicates.length}, $modeDesc)"
  }

  override def toBatch: Batch = new DatafusionBatch(this)

  /** Key-grouped only when the bridge declared a partitioning; unknown in every other case. */
  override def outputPartitioning(): Partitioning = mode match {
    case perPartition: PerPartitionMode =>
      val taskCount = perPartition.partitions.length
      perPartition.reported match {
        case null => new UnknownPartitioning(taskCount)
        case declared => new KeyGroupedPartitioning(declared.keys().toArray, taskCount)
      }
    case shared: SharedScanMode =>
      new UnknownPartitioning(shared.numPartitions)
  }
}
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import java.util.UUID + +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * ScanBuilder with V2 Predicate pushdown + column pruning. Every translatable predicate is + * marked Exact and dropped from Spark's post-scan Filter; the rest stay residual. + * + * Pushdown discipline: over-claiming Exact = wrong results, under-claiming = full scans. The + * translator (see [[SparkPredicateTranslator]]) only emits proto for predicates it can encode + * losslessly — anything else returns `None` and lands in residuals. + * + * `build()` resolves the driver-side facts the optimizer needs *before* it starts asking the + * [[DatafusionScan]] about its output partitioning. 
/**
 * `ScanBuilder` offering V2 predicate pushdown and column pruning.
 *
 * A predicate is pushed only when [[SparkPredicateTranslator]] can encode it; everything else
 * is handed back to Spark as a residual. Claiming too much would give wrong results, claiming
 * too little only costs a wider scan, so the translator decides conservatively.
 *
 * `build()` resolves the driver-side facts Spark needs before it asks the [[DatafusionScan]]
 * about its output partitioning. Predicate pushdown and column pruning have run by then, so
 * both modes see the final projection and filters:
 *   - per-partition mode: `listPartitions(opts, filters)` plus the optional
 *     [[ReportedPartitioning]];
 *   - shared-scan mode: a probe build of provider + plan to count DataFusion output
 *     partitions, a freshly minted scan id, and the pinned session config.
 *
 * @param factoryFqcn  class name of the bridge's `BridgeProviderFactory`
 * @param optionsBytes bridge-encoded connection options
 * @param fullSchema   unpruned table schema; also the projection until `pruneColumns` is called
 */
class DatafusionScanBuilder(
    factoryFqcn: String,
    optionsBytes: Array[Byte],
    fullSchema: StructType
) extends ScanBuilder
    with SupportsPushDownV2Filters
    with SupportsPushDownRequiredColumns {

  private var pushed: Array[Predicate] = Array.empty
  private var pushedBytes: Array[Array[Byte]] = Array.empty
  private var pruned: StructType = fullSchema

  /**
   * Splits `predicates` into the translatable ones (remembered together with their proto
   * bytes, in input order) and the rest.
   *
   * @return the predicates that could not be translated, in input order
   */
  override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
    val accepted = Array.newBuilder[Predicate]
    val encoded = Array.newBuilder[Array[Byte]]
    val leftover = Array.newBuilder[Predicate]

    predicates.foreach { predicate =>
      SparkPredicateTranslator.translate(predicate) match {
        case Some(node) =>
          accepted += predicate
          encoded += node.toByteArray
        case None =>
          leftover += predicate
      }
    }

    // State is replaced only after the whole input has been processed.
    pushed = accepted.result()
    pushedBytes = encoded.result()
    leftover.result()
  }

  override def pushedPredicates(): Array[Predicate] = pushed

  override def pruneColumns(requiredSchema: StructType): Unit = {
    pruned = requiredSchema
  }

  /** Resolves the scan mode through the bridge factory and assembles the [[DatafusionScan]]. */
  override def build(): Scan = {
    val factory = instantiateFactory(factoryFqcn)
    val resolvedMode: DatafusionScanMode =
      if (!factory.sharedScan(optionsBytes)) buildPerPartitionMode(factory)
      else buildSharedScanMode()
    new DatafusionScan(
      factoryFqcn,
      optionsBytes,
      fullSchema,
      pruned,
      pushed,
      pushedBytes,
      resolvedMode
    )
  }

  /**
   * Lists the bridge's partitions for the pushed filters.
   *
   * @throws IllegalStateException when the bridge returns null or an empty list
   */
  private def buildPerPartitionMode(factory: BridgeProviderFactory): PerPartitionMode = {
    val listed: Array[PartitionInfo] = factory.listPartitions(optionsBytes, pushedBytes)
    val nothingToScan = listed == null || listed.isEmpty
    if (nothingToScan) {
      throw new IllegalStateException(
        s"BridgeProviderFactory '$factoryFqcn' returned no partitions to scan"
      )
    }
    PerPartitionMode(listed, factory.reportPartitioning(optionsBytes))
  }

  /**
   * Driver plan probe: builds provider + plan through [[NativeSharedScanResources]] (the same
   * entry point executors use), reads the plan's output partition count and closes the probe
   * again. All Spark conf is resolved here; executors only receive the shipped copies.
   *
   * @throws IllegalStateException when the probe plan reports no partitions
   */
  private def buildSharedScanMode(): SharedScanMode = {
    val sessionConf = SQLConf.get
    val sessionPins = PinnedSessionConfig.fromConf(sessionConf)
    val ttlMs = PinnedSessionConfig.idleTtlMs(sessionConf)
    val mintedScanId = UUID.randomUUID().toString

    val probe = NativeSharedScanResources.build(
      SharedScanSpec(
        scanId = mintedScanId,
        factoryFqcn = factoryFqcn,
        optionsBytes = optionsBytes,
        projectionColumnNames = pruned.fieldNames,
        filterProtoBytes = pushedBytes,
        pinnedConfig = sessionPins
      )
    )
    // The probe is only needed for its partition count; release it whatever happens.
    val probedPartitions =
      try probe.partitionCount
      finally probe.close()

    if (probedPartitions <= 0) {
      throw new IllegalStateException(
        s"shared-scan probe for factory '$factoryFqcn' produced a plan with no partitions")
    }
    SharedScanMode(mintedScanId, probedPartitions, sessionPins, ttlMs)
  }

  /** Loads `fqcn` and instantiates it through its no-arg constructor. */
  private def instantiateFactory(fqcn: String): BridgeProviderFactory =
    Class
      .forName(fqcn)
      .getDeclaredConstructor()
      .newInstance()
      .asInstanceOf[BridgeProviderFactory]
}
b/spark/src/main/scala/io/datafusion/spark/DatafusionSource.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import java.io.ByteArrayInputStream +import java.nio.channels.Channels +import java.util + +import org.apache.arrow.vector.ipc.ReadChannel +import org.apache.arrow.vector.ipc.message.MessageSerializer +import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Generic Spark DataSource V2 entry point. Concrete bridges either: + * - Subclass and override [[shortName]] + [[factoryFqcn]] (the short-name shim pattern), or + * - Use this class directly with `option("df.factory", "fully.qualified.FactoryClass")`. + * + * Schema discovery happens driver-side inside the bridge's native scan backend + * (`ScanBackend.providerSchemaIpc`), which widens the provider and returns its Arrow schema as + * IPC bytes. 
/**
 * Generic Spark DataSource V2 entry point. A concrete bridge either
 *   - subclasses this and overrides [[shortName]] + [[factoryFqcn]] (short-name shim), or
 *   - uses it directly with `option("df.factory", "fully.qualified.FactoryClass")`.
 *
 * The schema is discovered on the driver through the bridge's native scan backend
 * (`ScanBackend.providerSchemaIpc`), which returns the provider's Arrow schema as IPC bytes.
 * The encoded options and the factory class name are then handed to [[DatafusionTable]].
 */
class DatafusionSource extends TableProvider with DataSourceRegister {

  override def shortName(): String = "datafusion"

  /** Spark option key carrying the BridgeProviderFactory FQCN when no override is provided. */
  protected val FactoryOptionKey: String = "df.factory"

  /**
   * Resolves the bridge factory class name from the Spark options. Subclasses override this to
   * return a fixed FQCN so users need not set `df.factory`.
   *
   * @throws IllegalArgumentException when the option is absent or empty
   */
  protected def factoryFqcn(options: CaseInsensitiveStringMap): String =
    Option(options.get(FactoryOptionKey)).filter(_.nonEmpty).getOrElse {
      throw new IllegalArgumentException(
        s"DatafusionSource: option '$FactoryOptionKey' is required when no subclass override is set"
      )
    }

  /**
   * Asks the bridge for its Arrow schema (as IPC bytes) and converts it to a Spark schema.
   *
   * The probe passes an empty `partitionBytes`: bridges must honour an empty payload here, so
   * the schema cannot depend on per-partition state.
   */
  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
    val factory = instantiateFactory(factoryFqcn(options))
    val encodedOptions = factory.encodeOptions(options.asCaseSensitiveMap())
    val schemaIpc = factory.scanBackend().providerSchemaIpc(encodedOptions, Array.emptyByteArray)
    val channel = new ReadChannel(Channels.newChannel(new ByteArrayInputStream(schemaIpc)))
    ArrowToSparkSchema.toSparkSchema(MessageSerializer.deserializeSchema(channel))
  }

  /** Builds the table from the given schema and the factory-encoded options. */
  override def getTable(
      schema: StructType,
      partitioning: Array[Transform],
      properties: util.Map[String, String]
  ): Table = {
    val options = new CaseInsensitiveStringMap(properties)
    val resolvedFqcn = factoryFqcn(options)
    val encodedOptions =
      instantiateFactory(resolvedFqcn).encodeOptions(options.asCaseSensitiveMap())
    new DatafusionTable(resolvedFqcn, encodedOptions, schema)
  }

  override def supportsExternalMetadata(): Boolean = false

  /** Loads `fqcn` and instantiates it through its no-arg constructor. */
  private def instantiateFactory(fqcn: String): BridgeProviderFactory =
    Class
      .forName(fqcn)
      .getDeclaredConstructor()
      .newInstance()
      .asInstanceOf[BridgeProviderFactory]
}
/**
 * Read-only table backed by a DataFusion bridge; `BATCH_READ` is its only capability.
 *
 * @param factoryFqcn  class name of the bridge's `BridgeProviderFactory`
 * @param optionsBytes bridge-encoded connection options
 * @param sparkSchema  schema reported to Spark
 */
class DatafusionTable(
    val factoryFqcn: String,
    val optionsBytes: Array[Byte],
    val sparkSchema: StructType
) extends Table
    with SupportsRead {

  /** `datafusion.` followed by the last dot-separated segment of the factory class name. */
  override def name(): String = {
    val simpleName = factoryFqcn.split('.').last
    s"datafusion.$simpleName"
  }

  override def schema(): StructType = sparkSchema

  /** A fresh set on every call, containing `BATCH_READ` only. */
  override def capabilities(): util.Set[TableCapability] =
    new util.HashSet[TableCapability](util.Collections.singleton(TableCapability.BATCH_READ))

  /** The per-scan options are not consulted; the builder uses the table-level state only. */
  override def newScanBuilder(scanOpts: CaseInsensitiveStringMap): ScanBuilder =
    new DatafusionScanBuilder(factoryFqcn, optionsBytes, sparkSchema)
}
/**
 * Arrow C-data import of a native-produced `FFI_ArrowArrayStream`: allocate the empty struct,
 * let the native side write into it, then hand it to Arrow Java.
 */
private[spark] object FfiStream {

  /**
   * Allocates an `ArrowArrayStream` from `allocator`, passes its memory address to
   * `writeStream`, and imports the result as an `ArrowReader`.
   *
   * If `writeStream` or the import throws, the struct is closed and the original error is
   * rethrown; a failure of that close is attached as a suppressed exception instead of
   * replacing the original error.
   *
   * @param allocator   allocator for the struct and for the imported reader
   * @param writeStream callback receiving the struct's address; expected to populate it
   * @return the reader over the imported stream
   */
  def importReader(allocator: BufferAllocator)(writeStream: Long => Unit): ArrowReader = {
    val stream = ArrowArrayStream.allocateNew(allocator)
    try {
      writeStream(stream.memoryAddress())
      Data.importArrayStream(allocator, stream)
    } catch {
      case t: Throwable =>
        // NOTE(review): close() is what the original code relied on for cleanup; confirm
        // whether a populated-but-not-imported stream additionally needs its release callback
        // invoked.
        try stream.close()
        catch { case suppressed: Throwable => t.addSuppressed(suppressed) }
        throw t
    }
  }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.internal.Logging + +/** + * JNI-backed shared-scan entry: one provider, one planned scan handle inside the bridge's native + * scan backend. + * + * The build sequence is the single code path for BOTH the driver-side partition-count probe and + * every executor's cache entry — identical widening, registration, projection, filters, and + * pinned session config are what make the partition count comparable across machines (the + * bridge's determinism contract covers the rest). 
/**
 * JNI-backed shared-scan entry: a planned scan handle inside the bridge's native scan backend,
 * together with the root allocator its task allocators hang off.
 *
 * Instances come from the companion's `build`, which is the one code path used both for the
 * driver-side partition-count probe and for every executor's cache entry.
 *
 * @param allocator  root allocator owned by this entry; closed by `close()`
 * @param backend    the bridge's scan backend
 * @param scanHandle handle returned by `backend.createScan`; released by `close()`
 */
private[spark] final class NativeSharedScanResources(
    allocator: RootAllocator,
    backend: ScanBackend,
    scanHandle: Long
) extends SharedScanResources {

  override def partitionCount: Int = backend.partitionCount(scanHandle)

  /** Child of this entry's root allocator, with no reservation and no limit below `Long.MaxValue`. */
  override def newTaskAllocator(name: String): BufferAllocator =
    allocator.newChildAllocator(name, 0, Long.MaxValue)

  /** Streams plan partition `partition`, importing the batches into `taskAllocator`. */
  override def openPartitionStream(
      partition: Int,
      taskAllocator: BufferAllocator): ArrowReader =
    FfiStream.importReader(taskAllocator) { streamAddress =>
      backend.executeStreamPartition(scanHandle, partition, streamAddress)
    }

  /**
   * Releases the scan handle, then the allocator. Both are attempted; the first failure is
   * rethrown with the second attached as suppressed.
   */
  override def close(): Unit = {
    val releaseSteps: Seq[() => Unit] = Seq(
      () => backend.closeScan(scanHandle),
      () => allocator.close()
    )
    var failure: Throwable = null
    releaseSteps.foreach { step =>
      try step()
      catch {
        case t: Throwable => if (failure == null) failure = t else failure.addSuppressed(t)
      }
    }
    if (failure != null) throw failure
  }
}

private[spark] object NativeSharedScanResources extends Logging {

  /**
   * Instantiates the bridge factory named in `spec`, creates the scan through its backend and
   * wraps handle + allocator. The allocator is closed again if `createScan` fails.
   *
   * Shared mode builds the dataset-wide provider, so `partitionBytes` is empty — as in the
   * driver-side schema probe.
   */
  def build(spec: SharedScanSpec): SharedScanResources = {
    logInfo(
      s"Building shared-scan entry for scanId=${spec.scanId} " +
        s"(factory=${spec.factoryFqcn}, filters=${spec.filterProtoBytes.length})")

    val factoryClass = Class.forName(spec.factoryFqcn)
    val factory =
      factoryClass.getDeclaredConstructor().newInstance().asInstanceOf[BridgeProviderFactory]
    val backend = factory.scanBackend()

    val rootAllocator = new RootAllocator(Long.MaxValue)
    try {
      val pinned = spec.pinnedConfig
      // Option keys and values travel as two parallel arrays.
      val (optionKeys, optionValues) = pinned.options.unzip
      val handle = backend.createScan(
        spec.optionsBytes,
        Array.emptyByteArray,
        pinned.targetPartitions,
        pinned.batchSize,
        optionKeys.toArray,
        optionValues.toArray,
        spec.projectionColumnNames,
        spec.filterProtoBytes
      )
      new NativeSharedScanResources(rootAllocator, backend, handle)
    } catch {
      case t: Throwable =>
        try rootAllocator.close()
        catch { case suppressed: Throwable => t.addSuppressed(suppressed) }
        throw t
    }
  }
}
/**
 * An `ArrowColumnVector` that ignores `close()`.
 *
 * The vectors wrapped here belong to the `ArrowReader`'s `VectorSchemaRoot`, which reuses them
 * across `loadNextBatch()` calls; closing them together with each Spark batch would break the
 * next read. They are released when the owning partition reader is closed.
 *
 * @param vec the Arrow vector to expose to Spark
 */
final class NonClosingArrowColumnVector(vec: FieldVector) extends ArrowColumnVector(vec) {

  /** Deliberately does nothing; see the class comment. */
  override def close(): Unit = ()
}
+ *
+ * DataFusion's default `SessionConfig` derives `target_partitions` from the machine's core count,
+ * so a plan that yields N partitions on the driver could yield M ≠ N on a differently-sized
+ * executor — and partition-indexed execution would silently drop or duplicate data. The driver
+ * resolves these values once in `DatafusionScanBuilder.build()`, ships them inside every
+ * [[DatafusionSharedScanPartition]], and both the driver probe and the executors hand the same
+ * values to `ScanBackend.createScan`, which builds the native `SessionContext` from them.
+ *
+ * `options` additionally disables the optimizer's plan-reshaping repartition passes so the
+ * physical partitioning is exactly what the provider's `scan()` reports, on every machine.
+ *
+ * @param targetPartitions value handed to the native session as its target partition count
+ * @param batchSize        value handed to the native session as its batch size
+ * @param options          extra (key, value) session options applied verbatim, in order
+ */
+final case class PinnedSessionConfig(
+    targetPartitions: Int,
+    batchSize: Int,
+    options: Vector[(String, String)]
+) extends Serializable
+
+object PinnedSessionConfig {
+
+  val TargetPartitionsConf = "spark.datafusion.sharedScan.targetPartitions"
+  val BatchSizeConf = "spark.datafusion.sharedScan.batchSize"
+  val IdleTtlConf = "spark.datafusion.sharedScan.idleTtlMs"
+
+  val DefaultTargetPartitions = 8
+  val DefaultBatchSize = 8192
+  val DefaultIdleTtlMs = 120000L
+
+  /**
+   * Optimizer knobs that must not vary with the host. Round-robin repartition and file-scan
+   * repartition would let the optimizer change the plan's output partition count based on
+   * `target_partitions` heuristics; statistics collection could steer per-host plan differences.
+   */
+  private val DeterminismOptions: Vector[(String, String)] = Vector(
+    "datafusion.optimizer.enable_round_robin_repartition" -> "false",
+    "datafusion.optimizer.repartition_file_scans" -> "false",
+    "datafusion.execution.collect_statistics" -> "false"
+  )
+
+  /**
+   * Resolve the pinned config from the driver's session conf. Called exactly once per scan, on
+   * the driver; executors never read Spark conf for these values — they use the shipped copy.
+   *
+   * Both numeric settings must be strictly positive. A zero or negative value is rejected here,
+   * on the driver, instead of being shipped to every executor: the whole point of pinning is
+   * that the native session never falls back to a host-derived default.
+   *
+   * @throws NumberFormatException    if a setting is not a valid integer
+   * @throws IllegalArgumentException if a setting is zero or negative
+   */
+  def fromConf(conf: SQLConf): PinnedSessionConfig =
+    PinnedSessionConfig(
+      targetPartitions = positiveInt(conf, TargetPartitionsConf, DefaultTargetPartitions),
+      batchSize = positiveInt(conf, BatchSizeConf, DefaultBatchSize),
+      options = DeterminismOptions
+    )
+
+  /**
+   * Idle TTL, in milliseconds, read from `IdleTtlConf` (default `DefaultIdleTtlMs`).
+   *
+   * @throws NumberFormatException if the setting is not a valid integer; the message names the key
+   */
+  def idleTtlMs(conf: SQLConf): Long = {
+    val raw = conf.getConfString(IdleTtlConf, DefaultIdleTtlMs.toString)
+    try raw.trim.toLong
+    catch {
+      case _: NumberFormatException =>
+        throw new NumberFormatException(
+          s"$IdleTtlConf must be an integer number of milliseconds, got '$raw'")
+    }
+  }
+
+  /**
+   * Read `key` as a strictly positive Int, falling back to `default` when unset.
+   * Parse failures keep the original exception type (NumberFormatException) but name the key.
+   */
+  private def positiveInt(conf: SQLConf, key: String, default: Int): Int = {
+    val raw = conf.getConfString(key, default.toString)
+    val parsed =
+      try raw.trim.toInt
+      catch {
+        case _: NumberFormatException =>
+          throw new NumberFormatException(s"$key must be a positive integer, got '$raw'")
+      }
+    require(parsed > 0, s"$key must be a positive integer, got $parsed")
+    parsed
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala b/spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala
new file mode 100644
index 0000000..a134746
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.util.concurrent.{ConcurrentHashMap, Executors, ScheduledExecutorService, TimeUnit}
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
+
+/**
+ * Everything the driver resolved that an executor needs to rebuild the shared scan: identity
+ * (scanId) plus the exact build inputs (factory, options, projection, filters, pinned config).
+ */
+final case class SharedScanSpec(
+    scanId: String,
+    factoryFqcn: String,
+    optionsBytes: Array[Byte],
+    projectionColumnNames: Array[String],
+    filterProtoBytes: Array[Array[Byte]],
+    pinnedConfig: PinnedSessionConfig
+)
+
+/**
+ * What one cached shared-scan entry exposes to readers. Implemented by
+ * [[NativeSharedScanResources]] (JNI-backed) and by fakes in tests.
+ */
+trait SharedScanResources extends AutoCloseable {
+
+  /** Output partition count of the planned physical plan. */
+  def partitionCount: Int
+
+  /** Child allocator for one task's reader; closed by the task, before release. */
+  def newTaskAllocator(name: String): BufferAllocator
+
+  /** Open an independent stream over one plan partition. Concurrent-safe. */
+  def openPartitionStream(partition: Int, taskAllocator: BufferAllocator): ArrowReader
+}
+
+/**
+ * Executor-JVM cache of shared-scan entries, keyed by the driver-minted scanId.
+ *
+ * Semantics:
+ *   - `acquire` builds the entry exactly once per attempt wave: the first caller builds under
+ *     the entry's lock, concurrent callers block and share the result. Each successful acquire
+ *     increments a refcount that the caller MUST pair with `release(scanId)`.
+ *   - Build failures propagate to the builder AND all waiters of that attempt, and are not
+ *     cached: the failed slot is poisoned and dropped from the map, so the next acquire starts
+ *     from a fresh slot and rebuilds.
+ *   - Eviction closes entries with refcount 0 that have been idle longer than their TTL. The
+ *     refcount covers every open reader, so native close never races an in-flight stream.
+ *     Acquire after eviction rebuilds — correct, just slower.
+ *
+ * The cache itself is JNI-free: the entry builder is injected, so tests run without native libs.
+ *
+ * @param buildEntry builds the resources for a spec; invoked under the slot's monitor
+ * @param nanoClock  monotonic clock used for idle-time accounting (injectable for tests)
+ */
+final class SharedScanCache(
+    buildEntry: SharedScanSpec => SharedScanResources,
+    nanoClock: () => Long = () => System.nanoTime()
+) {
+
+  /**
+   * Per-scanId slot. All state transitions are guarded by `this` (the holder's monitor); the
+   * build itself also runs under the monitor, which is what blocks concurrent acquirers of the
+   * same scan until the entry exists.
+   */
+  private final class EntryHolder(spec: SharedScanSpec, idleTtlMs: Long) {
+    private var resources: SharedScanResources = _
+    private var refCount: Int = 0
+    private var lastReleaseNanos: Long = nanoClock()
+    private var closed: Boolean = false
+    // Non-null once the build threw. The slot is then dead for good: threads that were blocked
+    // on the monitor while the build ran must observe the failure rather than rebuild on a
+    // holder the failing thread is about to remove from the map (their entry would be
+    // unreachable by release() and by the evictor).
+    private var buildFailure: Throwable = _
+
+    /**
+     * Returns the resources with refcount incremented, or None if this holder was evicted.
+     * Throws the build failure if this slot's build failed, for the builder and every waiter.
+     */
+    def acquire(): Option[SharedScanResources] = synchronized {
+      if (buildFailure != null) throw buildFailure
+      if (closed) return None
+      if (resources == null) {
+        try {
+          resources = buildEntry(spec)
+        } catch {
+          case t: Throwable =>
+            // Poison the slot before leaving the monitor; the caller removes it from the map.
+            buildFailure = t
+            closed = true
+            throw t
+        }
+      }
+      refCount += 1
+      Some(resources)
+    }
+
+    /** Drops one reference. Throws if there is no outstanding reference to drop. */
+    def release(): Unit = synchronized {
+      if (refCount <= 0) {
+        throw new IllegalStateException(
+          s"release(${spec.scanId}) with refcount $refCount: unbalanced acquire/release")
+      }
+      refCount -= 1
+      lastReleaseNanos = nanoClock()
+    }
+
+    /** Close if idle past TTL; returns true when this holder is now closed. */
+    def closeIfIdle(nowNanos: Long): Boolean = synchronized {
+      if (closed) return true
+      val idle = refCount == 0 &&
+        (nowNanos - lastReleaseNanos) >= TimeUnit.MILLISECONDS.toNanos(idleTtlMs)
+      if (idle) forceCloseLocked()
+      closed
+    }
+
+    def forceClose(): Unit = synchronized { forceCloseLocked() }
+
+    // Caller holds the monitor. `closed` is set before the resources are closed, so the holder
+    // counts as closed even when `close()` throws.
+    private def forceCloseLocked(): Unit = {
+      if (!closed) {
+        closed = true
+        if (resources != null) {
+          val r = resources
+          resources = null
+          r.close()
+        }
+      }
+    }
+  }
+
+  private val entries = new ConcurrentHashMap[String, EntryHolder]()
+
+  /**
+   * Returns the shared resources for `spec.scanId`, building them if no live entry exists.
+   * Every successful call must be paired with exactly one `release(spec.scanId)`.
+   * A build failure is rethrown to the caller and the failed slot is dropped.
+   */
+  def acquire(spec: SharedScanSpec, idleTtlMs: Long): SharedScanResources = {
+    while (true) {
+      val holder =
+        entries.computeIfAbsent(spec.scanId, _ => new EntryHolder(spec, idleTtlMs))
+      val acquired =
+        try {
+          holder.acquire()
+        } catch {
+          case t: Throwable =>
+            // Build failed: drop the holder so the next acquire rebuilds, then propagate.
+            // remove(key, value) is a no-op when another waiter already dropped it.
+            entries.remove(spec.scanId, holder)
+            throw t
+        }
+      acquired match {
+        case Some(resources) => return resources
+        case None =>
+          // Holder was evicted between map lookup and acquire; retry with a fresh one.
+          entries.remove(spec.scanId, holder)
+      }
+    }
+    throw new IllegalStateException("unreachable")
+  }
+
+  /** Gives back one reference taken by `acquire`. Throws if no entry is cached for `scanId`. */
+  def release(scanId: String): Unit = {
+    val holder = entries.get(scanId)
+    if (holder == null) {
+      throw new IllegalStateException(
+        s"release($scanId) without a cached entry: unbalanced acquire/release")
+    }
+    holder.release()
+  }
+
+  /**
+   * Close and remove every idle-past-TTL entry. Called by the evictor daemon and by tests.
+   * A failing close does not stop the sweep: the entry is still removed, the remaining entries
+   * are still visited, and the first failure (later ones suppressed) is rethrown at the end.
+   */
+  private[spark] def evictIdleNow(): Unit = {
+    val now = nanoClock()
+    var failure: Throwable = null
+    entries.forEach { (scanId, holder) =>
+      try {
+        if (holder.closeIfIdle(now)) {
+          entries.remove(scanId, holder)
+        }
+      } catch {
+        case t: Throwable =>
+          // closeIfIdle only throws from the resources' close(), after the holder was marked
+          // closed, so the slot is dead either way.
+          entries.remove(scanId, holder)
+          if (failure == null) failure = t else failure.addSuppressed(t)
+      }
+    }
+    if (failure != null) throw failure
+  }
+
+  /**
+   * Close everything regardless of refcounts. JVM-shutdown path only. Every entry is closed
+   * even if an earlier close throws; the first failure is rethrown once the map is cleared.
+   */
+  def shutdown(): Unit = {
+    var failure: Throwable = null
+    entries.forEach { (_, holder) =>
+      try holder.forceClose()
+      catch {
+        case t: Throwable => if (failure == null) failure = t else failure.addSuppressed(t)
+      }
+    }
+    entries.clear()
+    if (failure != null) throw failure
+  }
+}
+
+object SharedScanCache {
+
+  /** Evictor period. Short relative to any sane TTL; cheap when the map is empty. */
+  private val EvictorPeriodMs = 5000L
+
+  /** JVM singleton used by executor tasks. Lazily started together with its evictor daemon. */
+  lazy val global: SharedScanCache = {
+    val cache = new SharedScanCache(NativeSharedScanResources.build)
+    val evictor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor { r =>
+      val t = new Thread(r, "datafusion-shared-scan-evictor")
+      t.setDaemon(true)
+      t
+    }
+    // An exception escaping a scheduleWithFixedDelay task cancels every later run, which would
+    // silently stop eviction for the life of the JVM. Report the failure through the thread's
+    // uncaught-exception handler and keep the schedule alive.
+    val sweep: Runnable = () =>
+      try cache.evictIdleNow()
+      catch {
+        case scala.util.control.NonFatal(t) =>
+          val self = Thread.currentThread()
+          self.getUncaughtExceptionHandler.uncaughtException(self, t)
+      }
+    evictor.scheduleWithFixedDelay(
+      sweep,
+      EvictorPeriodMs,
+      EvictorPeriodMs,
+      TimeUnit.MILLISECONDS)
+    Runtime.getRuntime.addShutdownHook(new Thread(() => cache.shutdown()))
+    cache
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala b/spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala
new file mode 100644
index 0000000..4f0c9c1
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Shared-scan task reader: acquires the executor's cached (provider, plan) entry and streams ONE
+ * DataFusion plan partition from it. The acquire/release refcount pair brackets the reader's
+ * whole lifetime, so the cache can never close the native plan under an open stream.
+ *
+ * Construction either yields a fully usable reader holding exactly one cache reference, or
+ * throws after giving that reference back. `close()` releases the reference exactly once, no
+ * matter how many times it is called.
+ *
+ * @param partition the shared-scan partition this task reads (scan identity + partition index)
+ * @param cache     the executor-side cache the scan entry is acquired from and released to
+ */
+class SharedScanPartitionReader(
+    partition: DatafusionSharedScanPartition,
+    cache: SharedScanCache
+) extends PartitionReader[ColumnarBatch]
+    with ArrowColumnarBatchIteration {
+
+  // Set by the first close(); a second close() must not release the cache reference again.
+  private var readerClosed: Boolean = false
+
+  private val resources: SharedScanResources = cache.acquire(partition.toSpec, partition.idleTtlMs)
+
+  // Determinism guard: the driver counted partitions by planning once; if this executor's
+  // re-plan disagrees, partition indices are meaningless and every task of the scan must fail
+  // rather than silently drop or duplicate data. The reference taken above is returned on any
+  // failure here, including a failure of `partitionCount` itself.
+  try {
+    val executorCount = resources.partitionCount
+    if (executorCount != partition.numPartitions) {
+      throw new IllegalStateException(
+        s"shared-scan determinism violation for scanId=${partition.scanId}: driver planned " +
+          s"${partition.numPartitions} partition(s) but this executor planned $executorCount. " +
+          "The provider's partitioning must be a pure function of optionsBytes; pin your " +
+          "source snapshot (see BridgeProviderFactory.sharedScan).")
+    }
+  } catch {
+    case t: Throwable =>
+      releaseAfterFailure(t)
+      throw t
+  }
+
+  private val taskAllocator: BufferAllocator =
+    try {
+      val attempt = Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+      resources.newTaskAllocator(
+        s"shared-${partition.scanId}-p${partition.partitionIndex}-attempt$attempt")
+    } catch {
+      case t: Throwable =>
+        // No allocator exists yet; only the cache reference has to be given back.
+        releaseAfterFailure(t)
+        throw t
+    }
+
+  override protected val arrowReader: ArrowReader =
+    try {
+      resources.openPartitionStream(partition.partitionIndex, taskAllocator)
+    } catch {
+      case t: Throwable =>
+        try taskAllocator.close()
+        catch { case suppressed: Throwable => t.addSuppressed(suppressed) }
+        releaseAfterFailure(t)
+        throw t
+    }
+
+  /**
+   * Closes the stream, then the task allocator, then releases the cache reference. All three
+   * steps are attempted even if an earlier one throws; the first failure is rethrown with the
+   * later ones attached as suppressed. Calls after the first are no-ops.
+   */
+  override def close(): Unit = {
+    if (readerClosed) return
+    readerClosed = true
+    var first: Throwable = null
+    def safe(f: => Unit): Unit =
+      try f
+      catch { case t: Throwable => if (first == null) first = t else first.addSuppressed(t) }
+    safe(arrowReader.close())
+    safe(taskAllocator.close())
+    // Release LAST: the refcount must cover the open stream and the task allocator.
+    safe(cache.release(partition.scanId))
+    if (first != null) throw first
+  }
+
+  /**
+   * Construction-failure path: give the cache reference back without masking `cause`. A failure
+   * of the release itself is attached to `cause` as suppressed.
+   */
+  private def releaseAfterFailure(cause: Throwable): Unit =
+    try cache.release(partition.scanId)
+    catch { case suppressed: Throwable => cause.addSuppressed(suppressed) }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala b/spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala
new file mode 100644
index 0000000..3092914
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import datafusion_common.DatafusionCommon.{Column, ScalarValue} +import org.apache.datafusion.protobuf.{ + BinaryExprNode, + InListNode, + IsNotNull, + IsNull, + LikeNode, + LogicalExprNode, + Not => NotNode +} +import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} +import org.apache.spark.sql.connector.expressions.filter.Predicate + +/** + * Translate Spark V2 `Predicate` → DataFusion `LogicalExprNode` proto. Only emits expressions + * that the producer can apply EXACTLY — anything else returns `None` and the caller marks the + * predicate as residual so Spark re-applies it above the scan. 
+ */
+object SparkPredicateTranslator {
+
+  /**
+   * Translate one predicate, dispatching on its Spark V2 name.
+   *
+   * @return the equivalent expression, or `None` when the predicate (or any part of it) has no
+   *         exact translation
+   */
+  def translate(p: Predicate): Option[LogicalExprNode] = p.name() match {
+    case "=" => binary(p, "Eq")
+    case "<>" => binary(p, "NotEq")
+    case "<" => binary(p, "Lt")
+    case "<=" => binary(p, "LtEq")
+    case ">" => binary(p, "Gt")
+    case ">=" => binary(p, "GtEq")
+    case "IS_NULL" => unary(p, "IsNull")
+    case "IS_NOT_NULL" => unary(p, "IsNotNull")
+    case "AND" => combine(p, "And")
+    case "OR" => combine(p, "Or")
+    case "NOT" => translateNot(p)
+    case "IN" => translateIn(p)
+    // STARTS_WITH -> "needle%", ENDS_WITH -> "%needle", CONTAINS -> "%needle%"
+    case "STARTS_WITH" => like(p, leadingWildcard = false, trailingWildcard = true)
+    case "ENDS_WITH" => like(p, leadingWildcard = true, trailingWildcard = false)
+    case "CONTAINS" => like(p, leadingWildcard = true, trailingWildcard = true)
+    case _ => None
+  }
+
+  /** Two-operand comparison; both children must be a column reference or a literal. */
+  private def binary(p: Predicate, op: String): Option[LogicalExprNode] = {
+    val cs = p.children()
+    if (cs.length != 2) return None
+    val left = expr(cs(0))
+    val right = expr(cs(1))
+    if (left.isEmpty || right.isEmpty) return None
+    Some(
+      LogicalExprNode
+        .newBuilder()
+        .setBinaryExpr(
+          BinaryExprNode
+            .newBuilder()
+            .addOperands(left.get)
+            .addOperands(right.get)
+            .setOp(op)
+            .build()
+        )
+        .build()
+    )
+  }
+
+  /** `IS NULL` / `IS NOT NULL` over a single column reference or literal. */
+  private def unary(p: Predicate, op: String): Option[LogicalExprNode] = {
+    val cs = p.children()
+    if (cs.length != 1) return None
+    val inner = expr(cs(0))
+    if (inner.isEmpty) return None
+    val builder = LogicalExprNode.newBuilder()
+    op match {
+      case "IsNull" => builder.setIsNullExpr(IsNull.newBuilder().setExpr(inner.get).build())
+      case "IsNotNull" =>
+        builder.setIsNotNullExpr(IsNotNull.newBuilder().setExpr(inner.get).build())
+      case _ => return None
+    }
+    Some(builder.build())
+  }
+
+  /** `AND` / `OR`: translated only when BOTH sides translate. */
+  private def combine(p: Predicate, op: String): Option[LogicalExprNode] = {
+    val cs = p.children()
+    if (cs.length != 2) return None
+    val (l, r) = (cs(0), cs(1))
+    if (!l.isInstanceOf[Predicate] || !r.isInstanceOf[Predicate]) return None
+    val ln = translate(l.asInstanceOf[Predicate])
+    val rn = translate(r.asInstanceOf[Predicate])
+    if (ln.isEmpty || rn.isEmpty) return None
+    Some(
+      LogicalExprNode
+        .newBuilder()
+        .setBinaryExpr(
+          BinaryExprNode
+            .newBuilder()
+            .addOperands(ln.get)
+            .addOperands(rn.get)
+            .setOp(op)
+            .build()
+        )
+        .build()
+    )
+  }
+
+  /** `NOT` over a single translatable predicate. */
+  private def translateNot(p: Predicate): Option[LogicalExprNode] = {
+    val cs = p.children()
+    if (cs.length != 1 || !cs(0).isInstanceOf[Predicate]) return None
+    val inner = translate(cs(0).asInstanceOf[Predicate])
+    if (inner.isEmpty) return None
+    Some(LogicalExprNode.newBuilder().setNotExpr(NotNode.newBuilder().setExpr(inner.get).build()).build())
+  }
+
+  /** `IN`: child 0 is the tested expression, children 1..n the candidate values. */
+  private def translateIn(p: Predicate): Option[LogicalExprNode] = {
+    val cs = p.children()
+    if (cs.length < 2) return None
+    val target = expr(cs(0))
+    if (target.isEmpty) return None
+    val values = new java.util.ArrayList[LogicalExprNode]()
+    var i = 1
+    while (i < cs.length) {
+      val v = expr(cs(i))
+      // One untranslatable value makes the whole list untranslatable.
+      if (v.isEmpty) return None
+      values.add(v.get)
+      i += 1
+    }
+    val node = InListNode
+      .newBuilder()
+      .setExpr(target.get)
+      .addAllList(values)
+      .setNegated(false)
+      .build()
+    Some(LogicalExprNode.newBuilder().setInList(node).build())
+  }
+
+  /**
+   * String-match predicates expressed as `LIKE` with `\` as the escape character.
+   *
+   * @param leadingWildcard  put `%` before the needle (ENDS_WITH, CONTAINS)
+   * @param trailingWildcard put `%` after the needle (STARTS_WITH, CONTAINS)
+   */
+  private def like(
+      p: Predicate,
+      leadingWildcard: Boolean,
+      trailingWildcard: Boolean): Option[LogicalExprNode] = {
+    val cs = p.children()
+    if (cs.length != 2) return None
+    val target = expr(cs(0))
+    val pat = cs(1) match {
+      case lit: Literal[_] =>
+        val raw = lit.value() match {
+          case s: String => Some(s)
+          case u: org.apache.spark.unsafe.types.UTF8String => Some(u.toString)
+          case _ => None
+        }
+        raw.map { r =>
+          // Escape the escape character first, then the two LIKE wildcards, so the needle
+          // only ever matches literally.
+          val escaped = r.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+          (if (leadingWildcard) "%" else "") + escaped + (if (trailingWildcard) "%" else "")
+        }
+      case _ => None
+    }
+    if (target.isEmpty || pat.isEmpty) return None
+    val patternExpr = stringLiteral(pat.get)
+    val like = LikeNode
+      .newBuilder()
+      .setExpr(target.get)
+      .setPattern(patternExpr)
+      .setNegated(false)
+      .setEscapeChar("\\")
+      .build()
+    Some(LogicalExprNode.newBuilder().setLike(like).build())
+  }
+
+  /**
+   * Leaf operand: a single-part column reference, or a literal whose declared Spark type is one
+   * of the plain primitives / string. Anything else is untranslatable.
+   */
+  private def expr(e: Expression): Option[LogicalExprNode] = e match {
+    case nr: NamedReference =>
+      val parts = nr.fieldNames()
+      // Nested-field references (more than one name part) are not translated.
+      if (parts.length != 1) None
+      else
+        Some(
+          LogicalExprNode
+            .newBuilder()
+            .setColumn(Column.newBuilder().setName(parts(0)).build())
+            .build()
+        )
+    case lit: Literal[_] if hasPlainEncoding(lit.dataType()) => literal(lit.value())
+    case _ => None
+  }
+
+  /**
+   * True only for Spark types whose literal value IS the plain boxed primitive / string that
+   * `literal` emits. The runtime class of the value alone is not enough to go by: a literal can
+   * carry an `Integer` or `Long` value while being declared as a date, timestamp or interval
+   * type, and emitting that as a bare Int32/Int64 scalar would not be an exact translation.
+   */
+  private def hasPlainEncoding(dt: org.apache.spark.sql.types.DataType): Boolean = {
+    import org.apache.spark.sql.types.{
+      BooleanType,
+      ByteType,
+      DoubleType,
+      FloatType,
+      IntegerType,
+      LongType,
+      ShortType,
+      StringType
+    }
+    dt match {
+      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType |
+          DoubleType | StringType =>
+        true
+      case _ => false
+    }
+  }
+
+  /** Scalar literal for a boxed primitive or string value; `None` for null or any other class. */
+  private def literal(v: Any): Option[LogicalExprNode] = {
+    val sv = ScalarValue.newBuilder()
+    val ok: Boolean = v match {
+      case b: java.lang.Boolean => sv.setBoolValue(b.booleanValue()); true
+      case b: java.lang.Byte => sv.setInt8Value(b.intValue()); true
+      case s: java.lang.Short => sv.setInt16Value(s.intValue()); true
+      case i: java.lang.Integer => sv.setInt32Value(i.intValue()); true
+      case l: java.lang.Long => sv.setInt64Value(l.longValue()); true
+      case f: java.lang.Float => sv.setFloat32Value(f.floatValue()); true
+      case d: java.lang.Double => sv.setFloat64Value(d.doubleValue()); true
+      case s: String => sv.setUtf8Value(s); true
+      case u: org.apache.spark.unsafe.types.UTF8String => sv.setUtf8Value(u.toString); true
+      case _ => false
+    }
+    if (!ok) None
+    else Some(LogicalExprNode.newBuilder().setLiteral(sv.build()).build())
+  }
+
+  /** UTF-8 string scalar literal. */
+  private def stringLiteral(s: String): LogicalExprNode =
+    LogicalExprNode.newBuilder().setLiteral(ScalarValue.newBuilder().setUtf8Value(s).build()).build()
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala b/spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala
new file mode 100644
index 0000000..2b59601
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.util.Collections
+
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.spark.sql.types._
+import org.scalatest.funsuite.AnyFunSuite
+
+/**
+ * Unit tests for `ArrowToSparkSchema.toSparkSchema`: the Spark type chosen for signed ints,
+ * timestamps, decimals and lists, and the rejection of unsigned ints, `Time` and half floats.
+ */
+class ArrowToSparkSchemaTest extends AnyFunSuite {
+
+  /** Childless Arrow field of type `t` with no dictionary encoding. */
+  private def primField(name: String, t: ArrowType, nullable: Boolean = true): Field =
+    new Field(name, new FieldType(nullable, t, /*dict=*/ null), Collections.emptyList())
+
+  test("signed ints map to matching Spark int types") {
+    val arrow = new Schema(
+      java.util.Arrays.asList(
+        primField("i8", new ArrowType.Int(8, true)),
+        primField("i16", new ArrowType.Int(16, true)),
+        primField("i32", new ArrowType.Int(32, true)),
+        primField("i64", new ArrowType.Int(64, true))
+      )
+    )
+    val s = ArrowToSparkSchema.toSparkSchema(arrow)
+    assert(s.fields(0).dataType == ByteType)
+    assert(s.fields(1).dataType == ShortType)
+    assert(s.fields(2).dataType == IntegerType)
+    assert(s.fields(3).dataType == LongType)
+  }
+
+  test("unsigned ints are rejected with a clear error") {
+    val arrow = new Schema(
+      java.util.Arrays.asList(primField("u32", new ArrowType.Int(32, false)))
+    )
+    val ex = intercept[UnsupportedOperationException](ArrowToSparkSchema.toSparkSchema(arrow))
+    // The message must name the offending field and say why it was rejected.
+    assert(ex.getMessage.contains("u32"))
+    assert(ex.getMessage.toLowerCase.contains("unsigned"))
+  }
+
+  test("timestamps split on timezone presence") {
+    val withTz = primField("t_utc", new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"))
+    val noTz = primField("t_local", new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))
+    val s = ArrowToSparkSchema.toSparkSchema(
+      new Schema(java.util.Arrays.asList(withTz, noTz))
+    )
+    assert(s.fields(0).dataType == TimestampType)
+    assert(s.fields(1).dataType == TimestampNTZType)
+  }
+
+  test("decimal preserves precision and scale") {
+    val s = ArrowToSparkSchema.toSparkSchema(
+      new Schema(java.util.Arrays.asList(primField("d", new ArrowType.Decimal(18, 4, 128))))
+    )
+    assert(s.fields(0).dataType == DecimalType(18, 4))
+  }
+
+  test("Time and Float16 are rejected (no Spark accessor)") {
+    intercept[UnsupportedOperationException] {
+      ArrowToSparkSchema.toSparkSchema(
+        new Schema(java.util.Arrays.asList(primField("t", new ArrowType.Time(TimeUnit.MICROSECOND, 64))))
+      )
+    }
+    intercept[UnsupportedOperationException] {
+      ArrowToSparkSchema.toSparkSchema(
+        new Schema(java.util.Arrays.asList(primField("h", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF))))
+      )
+    }
+  }
+
+  test("list element nullability propagates") {
+    val child =
+      new Field(
+        "el",
+        new FieldType(/*nullable=*/ true, new ArrowType.Int(32, true), null),
+        Collections.emptyList()
+      )
+    val listField = new Field(
+      "xs",
+      new FieldType(true, new ArrowType.List(), null),
+      java.util.Arrays.asList(child)
+    )
+    val s = ArrowToSparkSchema.toSparkSchema(
+      new Schema(java.util.Arrays.asList(listField))
+    )
+    assert(s.fields(0).dataType == ArrayType(IntegerType, containsNull = true))
+  }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala b/spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala
new file mode 100644
index 0000000..0b94eee
--- /dev/null
+++
b/spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.scalatest.funsuite.AnyFunSuite
+
+/**
+ * Pins the behaviour a `BridgeProviderFactory` gets without overriding anything beyond
+ * `scanBackend()`, plus the two `PartitionInfo` constructors.
+ */
+class BridgeProviderFactoryDefaultsTest extends AnyFunSuite {
+
+  /**
+   * Backend stub: the defaults under test never touch native code. Every method throws, so a
+   * default that did reach the backend would fail its test.
+   */
+  private object StubBackend extends ScanBackend {
+    def providerSchemaIpc(options: Array[Byte], partitionBytes: Array[Byte]): Array[Byte] =
+      throw new UnsupportedOperationException
+    def createScan(
+        options: Array[Byte],
+        partitionBytes: Array[Byte],
+        targetPartitions: Int,
+        batchSize: Int,
+        optionKeys: Array[String],
+        optionValues: Array[String],
+        projectionColumns: Array[String],
+        filterProtos: Array[Array[Byte]]): Long = throw new UnsupportedOperationException
+    def partitionCount(scanHandle: Long): Int = throw new UnsupportedOperationException
+    def executeStreamPartition(scanHandle: Long, partition: Int, ffiStreamAddr: Long): Unit =
+      throw new UnsupportedOperationException
+    def executeStream(scanHandle: Long, ffiStreamAddr: Long): Unit =
+      throw new UnsupportedOperationException
+    def closeScan(scanHandle: Long): Unit = throw new UnsupportedOperationException
+  }
+
+  /** Factory overriding only listPartitions (to spy on its inputs). */
+  private class MinimalFactory extends BridgeProviderFactory {
+    // Last options array handed to the filter-unaware overload; null until it is called.
+    var lastListPartitionsOpts: Array[Byte] = _
+
+    override def scanBackend(): ScanBackend = StubBackend
+
+    override def listPartitions(optionsBytes: Array[Byte]): Array[PartitionInfo] = {
+      lastListPartitionsOpts = optionsBytes
+      Array(new PartitionInfo("p0", Array.emptyByteArray, Array.empty[String]))
+    }
+  }
+
+  /** Only the required method implemented — the literal minimum a bridge can ship. */
+  private class EmptyFactory extends BridgeProviderFactory {
+    override def scanBackend(): ScanBackend = StubBackend
+  }
+
+  test("sharedScan defaults to false") {
+    assert(!new MinimalFactory().sharedScan(Array[Byte](1, 2, 3)))
+  }
+
+  test("default encodeOptions uses OptionsCodec") {
+    val opts = new java.util.HashMap[String, String]()
+    opts.put("url", "grpc://h:1")
+    val bytes = new EmptyFactory().encodeOptions(opts)
+    assert(java.util.Arrays.equals(bytes, OptionsCodec.encode(opts)))
+    assert(OptionsCodec.decode(bytes).get("url") == "grpc://h:1")
+  }
+
+  test("default listPartitions reports a single whole-dataset partition") {
+    val partitions = new EmptyFactory().listPartitions(Array[Byte](1))
+    assert(partitions.length == 1)
+    assert(partitions(0).id == "p0")
+    assert(partitions(0).partitionBytes().isEmpty)
+    assert(partitions(0).preferredLocations().isEmpty)
+  }
+
+  test("filter-aware listPartitions delegates to the filter-unaware overload") {
+    val factory = new MinimalFactory
+    val opts = Array[Byte](7, 8)
+    val filters = Array(Array[Byte](1), Array[Byte](2))
+    val partitions = factory.listPartitions(opts, filters)
+    assert(partitions.length == 1)
+    assert(partitions(0).id == "p0")
+    // Reference equality: the very same options array must be passed through.
+    assert(factory.lastListPartitionsOpts eq opts)
+  }
+
+  test("reportPartitioning defaults to null") {
+    assert(new MinimalFactory().reportPartitioning(Array.emptyByteArray) == null)
+  }
+
+  test("PartitionInfo 3-arg constructor leaves partitionKeyValues null") {
+    val p = new PartitionInfo("p0", Array.emptyByteArray, Array.empty[String])
+    assert(p.partitionKeyValues() == null)
+  }
+
+  test("PartitionInfo 4-arg constructor carries key values") {
+    val p = new PartitionInfo(
+      "p0",
+      Array.emptyByteArray,
+      Array.empty[String],
+      Array[AnyRef]("segment-a", Long.box(42L)))
+    assert(p.partitionKeyValues().length == 2)
+    assert(p.partitionKeyValues()(0) == "segment-a")
+  }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala
b/spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala
new file mode 100644
index 0000000..59f6c8f
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.io.ByteArrayOutputStream
+import java.nio.charset.StandardCharsets
+
+import org.scalatest.funsuite.AnyFunSuite
+
+/**
+ * Tests for `OptionsCodec.encode` / `OptionsCodec.decode`: key-sorted encoding against a fixed
+ * byte fixture, round-tripping, and rejection of malformed input.
+ */
+class OptionsCodecTest extends AnyFunSuite {
+
+  /**
+   * Shared fixture: must stay byte-identical to the one asserted by the Rust-side
+   * `datafusion_spark_bridge::options` tests. {"table": "t1", "url": "grpc://h:1"} encodes
+   * (sorted: table < url) as below.
+   *
+   * Layout as built here: a 4-byte big-endian entry count, then for each entry the key followed
+   * by the value, each written as a 4-byte big-endian byte length plus its UTF-8 bytes.
+   */
+  private def fixtureBytes(): Array[Byte] = {
+    val out = new ByteArrayOutputStream()
+    // Big-endian: most significant byte first.
+    def writeInt(v: Int): Unit = {
+      out.write((v >>> 24) & 0xFF); out.write((v >>> 16) & 0xFF)
+      out.write((v >>> 8) & 0xFF); out.write(v & 0xFF)
+    }
+    def writeString(s: String): Unit = {
+      val b = s.getBytes(StandardCharsets.UTF_8)
+      writeInt(b.length)
+      out.write(b, 0, b.length)
+    }
+    writeInt(2)
+    Seq("table" -> "t1", "url" -> "grpc://h:1").foreach { case (k, v) =>
+      writeString(k); writeString(v)
+    }
+    out.toByteArray
+  }
+
+  test("encodes the cross-language fixture byte-identically, sorted by key") {
+    // Insertion order deliberately unsorted; encoding must sort.
+    val opts = new java.util.LinkedHashMap[String, String]()
+    opts.put("url", "grpc://h:1")
+    opts.put("table", "t1")
+    assert(java.util.Arrays.equals(OptionsCodec.encode(opts), fixtureBytes()))
+  }
+
+  test("round-trips including unicode values") {
+    val opts = new java.util.HashMap[String, String]()
+    opts.put("a", "1")
+    opts.put("unicode", "héllo→world")
+    val decoded = OptionsCodec.decode(OptionsCodec.encode(opts))
+    assert(decoded.size() == 2)
+    assert(decoded.get("unicode") == "héllo→world")
+  }
+
+  test("null and empty maps encode to a zero count and decode back empty") {
+    assert(OptionsCodec.decode(OptionsCodec.encode(null)).isEmpty)
+    // A zero-length input (no count field at all) is also accepted as "no options".
+    assert(OptionsCodec.decode(Array.emptyByteArray).isEmpty)
+  }
+
+  test("rejects truncation and trailing bytes") {
+    val bytes = fixtureBytes()
+    // copyOf with a shorter length drops the last byte ...
+    intercept[IllegalArgumentException] {
+      OptionsCodec.decode(java.util.Arrays.copyOf(bytes, bytes.length - 1))
+    }
+    // ... and with a longer length pads one zero byte after the last entry.
+    intercept[IllegalArgumentException] {
+      OptionsCodec.decode(java.util.Arrays.copyOf(bytes, bytes.length + 1))
+    }
+  }
+
+  test("rejects null keys or values") {
+    val opts = new java.util.HashMap[String, String]()
+    opts.put("k", null)
+    intercept[IllegalArgumentException] { OptionsCodec.encode(opts) }
+  }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala
b/spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala new file mode 100644 index 0000000..e2f876d --- /dev/null +++ b/spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package io.datafusion.spark
+
+import org.apache.spark.unsafe.types.UTF8String
+import org.scalatest.funsuite.AnyFunSuite
+
+class PartitionKeyConversionTest extends AnyFunSuite {
+  // Partition fixtures: one built with key values, one built without any.
+  private def keyedPartition(id: String, keys: Array[AnyRef]): PartitionInfo =
+    new PartitionInfo(id, Array.emptyByteArray, Array.empty[String], keys)
+
+  private def unkeyedPartition(id: String): PartitionInfo =
+    new PartitionInfo(id, Array.emptyByteArray, Array.empty[String])
+
+  test("String and Long key values convert to catalyst representations") {
+    val declared = ReportedPartitioning.identity("segment_id", "bucket")
+    val keyValues = Array[AnyRef]("segment-a", Long.box(42L))
+    val keyRow = DatafusionBatch.toKeyRow("p0", keyValues, declared)
+    val expectedString = UTF8String.fromString("segment-a")
+    assert(keyRow.numFields == 2)
+    assert(keyRow.get(0, org.apache.spark.sql.types.StringType) == expectedString)
+    assert(keyRow.getLong(1) == 42L)
+  }
+
+  test("arity mismatch between key values and declared keys throws") {
+    val declared = ReportedPartitioning.identity("segment_id", "bucket")
+    val keys = Array[AnyRef]("only-one")
+    val err = intercept[IllegalStateException](DatafusionBatch.toKeyRow("p0", keys, declared))
+    assert(err.getMessage.contains("declares 2 key(s)"))
+  }
+
+  test("keyed state requires reported partitioning") {
+    val keyedOnly = Array(keyedPartition("p0", Array[AnyRef]("a")))
+    assert(!DatafusionBatch.validateKeyedState("F", keyedOnly, null))
+  }
+
+  test("no partitions with keys means unkeyed, even with reported partitioning") {
+    val declared = ReportedPartitioning.identity("segment_id")
+    val allUnkeyed = Array("p0", "p1").map(unkeyedPartition)
+    assert(!DatafusionBatch.validateKeyedState("F", allUnkeyed, declared))
+  }
+
+  test("all partitions with keys means keyed") {
+    val declared = ReportedPartitioning.identity("segment_id")
+    val p0 = keyedPartition("p0", Array[AnyRef]("a"))
+    val p1 = keyedPartition("p1", Array[AnyRef]("b"))
+    assert(DatafusionBatch.validateKeyedState("F", Array(p0, p1), declared))
+  }
+
+  test("mixed keyed and unkeyed partitions throw driver-side") {
+    val declared = ReportedPartitioning.identity("segment_id")
+    val mixed = Array(keyedPartition("p0", Array[AnyRef]("a")), unkeyedPartition("p1"))
+    val err = intercept[IllegalStateException] {
+      DatafusionBatch.validateKeyedState("F", mixed, declared)
+    }
+    assert(err.getMessage.contains("only 1 of 2"))
+  }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala b/spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala
new file mode 100644
index 0000000..dae49eb
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
+import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.scalatest.funsuite.AnyFunSuite
+
+class SharedScanCacheTest extends AnyFunSuite { // SharedScanCache exercised with fakes
+
+  private def spec(scanId: String): SharedScanSpec = // minimal spec for the cache under test
+    SharedScanSpec(
+      scanId = scanId, // the only field that varies between tests
+      factoryFqcn = "test.Factory",
+      optionsBytes = Array.emptyByteArray,
+      projectionColumnNames = Array.empty,
+      filterProtoBytes = Array.empty,
+      pinnedConfig = PinnedSessionConfig(8, 8192, Vector.empty)
+    )
+
+  /** JNI-free fake entry; records close. Allocator and stream calls throw (unused here). */
+  private final class FakeResources extends SharedScanResources {
+    @volatile var closed = false // flipped by close(); asserted on by the tests
+    override def partitionCount: Int = 3
+    override def newTaskAllocator(name: String): BufferAllocator =
+      throw new UnsupportedOperationException("not used in cache tests")
+    override def openPartitionStream(p: Int, a: BufferAllocator): ArrowReader =
+      throw new UnsupportedOperationException("not used in cache tests")
+    override def close(): Unit = closed = true
+  }
+
+  private final class Fixture { // per-test harness around one SharedScanCache
+    val clock = new AtomicLong(0L) // fake time in nanoseconds, read through nanoClock
+    val buildCount = new AtomicInteger(0) // number of buildEntry invocations, failed ones included
+    var failBuilds = false // when true, buildEntry throws instead of building
+    var lastBuilt: FakeResources = _ // most recent successfully built entry
+
+    val cache = new SharedScanCache(
+      buildEntry = _ => {
+        buildCount.incrementAndGet() // counted before the failure check
+        if (failBuilds) throw new RuntimeException("synthetic build failure")
+        lastBuilt = new FakeResources
+        lastBuilt
+      },
+      nanoClock = () => clock.get()
+    )
+
+    def advanceMillis(ms: Long): Unit = clock.addAndGet(TimeUnit.MILLISECONDS.toNanos(ms))
+  }
+
+  test("acquire builds once, second acquire reuses, refcount pairs with release") {
+    val f = new Fixture
+    val r1 = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    val r2 = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    assert(f.buildCount.get() == 1)
+    assert(r1 eq r2) // reference identity, not equality
+    f.cache.release("s1") // one release per acquire
+    f.cache.release("s1")
+  }
+
+  test("concurrent acquires build exactly once") {
+    val f = new Fixture
+    val n = 8
+    val pool = Executors.newFixedThreadPool(n)
+    val ready = new CountDownLatch(n) // counted down once by each worker
+    val go = new CountDownLatch(1) // single gate opened by the test thread
+    try {
+      val futures = (0 until n).map { _ =>
+        pool.submit { () =>
+          ready.countDown()
+          go.await() // park until the test thread opens the gate
+          f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+        }
+      }
+      ready.await() // wait until all n workers have checked in
+      go.countDown() // release every worker together
+      val results = futures.map(_.get(10, TimeUnit.SECONDS)) // bounded wait: fail, don't hang
+      assert(f.buildCount.get() == 1)
+      assert(results.forall(_ eq results.head))
+      (0 until n).foreach(_ => f.cache.release("s1")) // balance the n acquires
+    } finally {
+      pool.shutdownNow()
+    }
+  }
+
+  test("build failure propagates and is not cached") {
+    val f = new Fixture
+    f.failBuilds = true
+    val e = intercept[RuntimeException](f.cache.acquire(spec("s1"), idleTtlMs = 1000))
+    assert(e.getMessage == "synthetic build failure")
+    f.failBuilds = false
+    val r = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    assert(f.buildCount.get() == 2) // the failed attempt plus the successful retry
+    assert(r eq f.lastBuilt)
+    f.cache.release("s1")
+  }
+
+  test("idle entry past TTL is evicted and closed") {
+    val f = new Fixture
+    f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    f.cache.release("s1")
+    val built = f.lastBuilt
+    f.advanceMillis(999) // 999 ms since release: still under the 1000 ms TTL
+    f.cache.evictIdleNow()
+    assert(!built.closed)
+    f.advanceMillis(2) // 1001 ms since release: past the TTL
+    f.cache.evictIdleNow()
+    assert(built.closed)
+  }
+
+  test("entry in use is never evicted, regardless of idle time") {
+    val f = new Fixture
+    f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    val built = f.lastBuilt
+    f.advanceMillis(100000) // far beyond the 1000 ms TTL, but not yet released
+    f.cache.evictIdleNow()
+    assert(!built.closed)
+    f.cache.release("s1")
+    f.advanceMillis(100000) // now released and far beyond the TTL
+    f.cache.evictIdleNow()
+    assert(built.closed)
+  }
+
+  test("release then reacquire within TTL resets idleness") {
+    val f = new Fixture
+    f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    f.cache.release("s1")
+    f.advanceMillis(900)
+    // Next task wave lands before TTL: same entry, no rebuild.
+    val r = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    assert(f.buildCount.get() == 1)
+    assert(r eq f.lastBuilt)
+    f.cache.release("s1")
+    f.advanceMillis(900) // 1800 ms after the first release, 900 ms after the second
+    f.cache.evictIdleNow()
+    assert(!f.lastBuilt.closed, "idle clock must restart at the last release")
+  }
+
+  test("acquire after eviction rebuilds") {
+    val f = new Fixture
+    f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    f.cache.release("s1")
+    val first = f.lastBuilt
+    f.advanceMillis(2000) // twice the TTL
+    f.cache.evictIdleNow()
+    assert(first.closed)
+    val r = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    assert(f.buildCount.get() == 2)
+    assert(r ne first) // a fresh instance, not the closed one
+    f.cache.release("s1")
+  }
+
+  test("distinct scanIds get distinct entries") {
+    val f = new Fixture
+    val r1 = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    val r2 = f.cache.acquire(spec("s2"), idleTtlMs = 1000)
+    assert(f.buildCount.get() == 2)
+    assert(r1 ne r2)
+    f.cache.release("s1")
+    f.cache.release("s2")
+  }
+
+  test("unbalanced release throws") {
+    val f = new Fixture
+    intercept[IllegalStateException](f.cache.release("never-acquired")) // no prior acquire
+  }
+
+  test("shutdown closes everything, even entries in use") {
+    val f = new Fixture
+    f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+    val built = f.lastBuilt
+    f.cache.shutdown() // no release was issued before this call
+    assert(built.closed)
+  }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala b/spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala
new file mode 100644
index 0000000..b7faac1
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.datafusion.protobuf.LogicalExprNode
+import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NamedReference}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.scalatest.funsuite.AnyFunSuite
+
+class SparkPredicateTranslatorTest extends AnyFunSuite {
+  // Builders for the Spark connector expressions fed to the translator.
+  private def col(name: String): NamedReference = Expressions.column(name)
+  private def litInt(value: Int): Expression = Expressions.literal(Int.box(value))
+  private def litLong(value: Long): Expression = Expressions.literal(Long.box(value))
+  private def litStr(text: String): Expression =
+    Expressions.literal(org.apache.spark.unsafe.types.UTF8String.fromString(text))
+  private def pred(name: String, operands: Expression*): Predicate =
+    new Predicate(name, operands.toArray)
+
+  test("LessThan(timeline, 1_000_000) translates to a non-empty proto") {
+    val lessThan = pred("<", col("timeline"), litLong(1000000L))
+    val translated = SparkPredicateTranslator.translate(lessThan).getOrElse(fail("expected Some"))
+    val serialized = translated.toByteArray
+    assert(serialized.nonEmpty)
+    val decoded = LogicalExprNode.parseFrom(serialized)
+    assert(decoded.hasBinaryExpr)
+    assert(decoded.getBinaryExpr.getOp == "Lt")
+  }
+
+  test("AND of two translatable predicates round-trips through binary op 'And'") {
+    val lessThanTen = pred("<", col("a"), litInt(10))
+    val equalsX = pred("=", col("b"), litStr("x"))
+    val andPred = pred("AND", lessThanTen, equalsX)
+    val translated = SparkPredicateTranslator.translate(andPred).getOrElse(fail("expected Some"))
+    val decoded = LogicalExprNode.parseFrom(translated.toByteArray)
+    assert(decoded.hasBinaryExpr)
+    assert(decoded.getBinaryExpr.getOp == "And")
+  }
+
+  test("AND becomes residual when an operand is untranslatable") {
+    val nullSafeEq = pred("<=>", col("a"), litInt(1))
+    val plainEq = pred("=", col("b"), litInt(2))
+    val andPred = pred("AND", nullSafeEq, plainEq)
+    assert(SparkPredicateTranslator.translate(andPred).isEmpty)
+  }
+
+  test("IS_NULL and IS_NOT_NULL emit the dedicated proto variants") {
+    val nullCheck = pred("IS_NULL", col("x"))
+    val notNullCheck = pred("IS_NOT_NULL", col("x"))
+    val nullBytes = SparkPredicateTranslator.translate(nullCheck).getOrElse(fail()).toByteArray
+    val notNullBytes = SparkPredicateTranslator.translate(notNullCheck).getOrElse(fail()).toByteArray
+    val decodedNull = LogicalExprNode.parseFrom(nullBytes)
+    val decodedNotNull = LogicalExprNode.parseFrom(notNullBytes)
+    assert(decodedNull.hasIsNullExpr)
+    assert(decodedNotNull.hasIsNotNullExpr)
+  }
+
+  test("STARTS_WITH translates to a LIKE with a '%' suffix") {
+    val startsWith = pred("STARTS_WITH", col("name"), litStr("foo"))
+    val translated = SparkPredicateTranslator.translate(startsWith).getOrElse(fail())
+    val decoded = LogicalExprNode.parseFrom(translated.toByteArray)
+    assert(decoded.hasLike)
+    assert(decoded.getLike.getPattern.getLiteral.getUtf8Value == "foo%")
+  }
+
+  test("unknown predicate name returns None (becomes residual)") {
+    val unknown = pred("UNKNOWN_OP", col("x"), litInt(1))
+    assert(SparkPredicateTranslator.translate(unknown).isEmpty)
+  }
+}